Compare commits
138 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b149b8a52 | ||
|
|
ad0c06851d | ||
|
|
49359b1416 | ||
|
|
911d950c15 | ||
|
|
b2a3ac60e0 | ||
|
|
dea092e1bb | ||
|
|
4154dd5268 | ||
|
|
db12937495 | ||
|
|
81e1456631 | ||
|
|
58e78d8700 | ||
|
|
5e80139afa | ||
|
|
60ebeaa582 | ||
|
|
758052dcff | ||
|
|
1628eacf2b | ||
|
|
2bea057fb1 | ||
|
|
9e54f16e56 | ||
|
|
96e6400bd8 | ||
|
|
6c7b72f130 | ||
|
|
027ebfc332 | ||
|
|
c2466ab401 | ||
|
|
7828d35e06 | ||
|
|
6b07e62f00 | ||
|
|
0d2005ddcb | ||
|
|
dfa75e682e | ||
|
|
22ecb5e989 | ||
|
|
2ab69f8561 | ||
|
|
95342cc94d | ||
|
|
f6194b3e19 | ||
|
|
6bb376a336 | ||
|
|
cd7a9ccbdf | ||
|
|
953af52d0e | ||
|
|
e6e98d4ed1 | ||
|
|
ca5f5e3383 | ||
|
|
d0fc7f37ff | ||
|
|
18d717e996 | ||
|
|
f482559e15 | ||
|
|
6e8b0b022a | ||
|
|
746fb7b181 | ||
|
|
caf283bed2 | ||
|
|
520c06175e | ||
|
|
065e43c5a9 | ||
|
|
c8b88dadc3 | ||
|
|
015f2de6c6 | ||
|
|
f36bfb3781 | ||
|
|
ef659cd72d | ||
|
|
728edd1453 | ||
|
|
498c0a0e94 | ||
|
|
e5975fa5d0 | ||
|
|
731a188a76 | ||
|
|
fe2104822e | ||
|
|
664415111a | ||
|
|
acd18ff694 | ||
|
|
da5affd613 | ||
|
|
a79d923dc1 | ||
|
|
c72f6aa2f9 | ||
|
|
4f24cebf11 | ||
|
|
e0739a786c | ||
|
|
64576da7dc | ||
|
|
4a55bd63a3 | ||
|
|
a78b903f5a | ||
|
|
c7b2c82700 | ||
|
|
50b865b23b | ||
|
|
6f5dd58b54 | ||
|
|
0ceee8545e | ||
|
|
62aea06e0d | ||
|
|
24f1cc637e | ||
|
|
8b6cca5d4d | ||
|
|
c9700f760e | ||
|
|
6f509e71ce | ||
|
|
f5a86953c6 | ||
|
|
246d2a6752 | ||
|
|
36ab7069cf | ||
|
|
a4c91cb8c3 | ||
|
|
a7ba0f9bd8 | ||
|
|
f3fb4ecbeb | ||
|
|
5c35702caf | ||
|
|
7280b182bd | ||
|
|
06b2491c1f | ||
|
|
b8265783f3 | ||
|
|
63066c50ba | ||
|
|
ddf9b5fe25 | ||
|
|
c3b66cccfc | ||
|
|
896f0d92e5 | ||
|
|
2ccaeb23f2 | ||
|
|
04c939d4c2 | ||
|
|
71c94c3b5a | ||
|
|
d71891ac4e | ||
|
|
3492941aec | ||
|
|
81e8d7e73d | ||
|
|
f0b04d53af | ||
|
|
35af7daf90 | ||
|
|
5fab15a11e | ||
|
|
ab913575e1 | ||
|
|
82cb6386a6 | ||
|
|
2d05035c1d | ||
|
|
15d747eb28 | ||
|
|
3d6fa6b791 | ||
|
|
3ea1874638 | ||
|
|
e1657d5ad8 | ||
|
|
83fa51fd4a | ||
|
|
db868c53c6 | ||
|
|
68f1865a1e | ||
|
|
5b1e2852ea | ||
|
|
d0a88d1fd1 | ||
|
|
e85788f79f | ||
|
|
25d42ee2a6 | ||
|
|
e41ceafaef | ||
|
|
43fa69db7d | ||
|
|
29309e5cfd | ||
|
|
cea97afe25 | ||
|
|
b43fa8ace2 | ||
|
|
742ce4c9c8 | ||
|
|
6ea9edf3d1 | ||
|
|
25b8f1723e | ||
|
|
73d10f364c | ||
|
|
2310c8cdfd | ||
|
|
2f7124959d | ||
|
|
2104ae38ec | ||
|
|
2055320058 | ||
|
|
11da0d57a8 | ||
|
|
acfda1e9a9 | ||
|
|
3c24a8c522 | ||
|
|
ec111f9ce6 | ||
|
|
520a4d60fb | ||
|
|
6e645835dc | ||
|
|
fcda8f0f96 | ||
|
|
d6db6af964 | ||
|
|
88cf4e0abc | ||
|
|
f138417486 | ||
|
|
de47d9ee43 | ||
|
|
406b25cda0 | ||
|
|
bd702734c2 | ||
|
|
5594655fba | ||
|
|
ebd307cab4 | ||
|
|
6e3cdebbfb | ||
|
|
a6a336b66e | ||
|
|
9901dc7f51 | ||
|
|
ac64d9505e |
@@ -1,15 +1,22 @@
|
||||
# Common settings
|
||||
PROJECT_NAME=App
|
||||
PROJECT_NAME=Syndarix
|
||||
VERSION=1.0.0
|
||||
|
||||
# Database settings
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=postgres
|
||||
POSTGRES_DB=app
|
||||
POSTGRES_DB=syndarix
|
||||
POSTGRES_HOST=db
|
||||
POSTGRES_PORT=5432
|
||||
DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
|
||||
|
||||
# Redis settings (cache, pub/sub, Celery broker)
|
||||
REDIS_URL=redis://redis:6379/0
|
||||
|
||||
# Celery settings (optional - defaults to REDIS_URL if not set)
|
||||
# CELERY_BROKER_URL=redis://redis:6379/0
|
||||
# CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||
|
||||
# Backend settings
|
||||
BACKEND_PORT=8000
|
||||
# CRITICAL: Generate a secure SECRET_KEY for production!
|
||||
|
||||
460
.gitea/workflows/ci.yaml
Normal file
460
.gitea/workflows/ci.yaml
Normal file
@@ -0,0 +1,460 @@
|
||||
# Syndarix CI/CD Pipeline
|
||||
# Gitea Actions workflow for continuous integration and deployment
|
||||
#
|
||||
# Pipeline Structure:
|
||||
# - lint: Fast feedback (linting and type checking)
|
||||
# - test: Run test suites (depends on lint)
|
||||
# - build: Build Docker images (depends on test)
|
||||
# - deploy: Deploy to production (depends on build, only on main)
|
||||
|
||||
name: CI/CD Pipeline
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
- 'feature/**'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.12"
|
||||
NODE_VERSION: "20"
|
||||
UV_VERSION: "0.4.x"
|
||||
|
||||
jobs:
|
||||
# ===========================================================================
|
||||
# LINT JOB - Fast feedback first
|
||||
# ===========================================================================
|
||||
lint:
|
||||
name: Lint & Type Check
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
component: [backend, frontend]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# ----- Backend Linting -----
|
||||
- name: Set up Python
|
||||
if: matrix.component == 'backend'
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install uv
|
||||
if: matrix.component == 'backend'
|
||||
uses: astral-sh/setup-uv@v4
|
||||
with:
|
||||
version: ${{ env.UV_VERSION }}
|
||||
|
||||
- name: Cache uv dependencies
|
||||
if: matrix.component == 'backend'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cache/uv
|
||||
backend/.venv
|
||||
key: uv-${{ runner.os }}-${{ hashFiles('backend/uv.lock') }}
|
||||
restore-keys: |
|
||||
uv-${{ runner.os }}-
|
||||
|
||||
- name: Install backend dependencies
|
||||
if: matrix.component == 'backend'
|
||||
working-directory: backend
|
||||
run: uv sync --extra dev --frozen
|
||||
|
||||
- name: Run ruff linting
|
||||
if: matrix.component == 'backend'
|
||||
working-directory: backend
|
||||
run: uv run ruff check app
|
||||
|
||||
- name: Run ruff format check
|
||||
if: matrix.component == 'backend'
|
||||
working-directory: backend
|
||||
run: uv run ruff format --check app
|
||||
|
||||
- name: Run mypy type checking
|
||||
if: matrix.component == 'backend'
|
||||
working-directory: backend
|
||||
run: uv run mypy app
|
||||
|
||||
# ----- Frontend Linting -----
|
||||
- name: Set up Node.js
|
||||
if: matrix.component == 'frontend'
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: ${{ env.NODE_VERSION }}
|
||||
|
||||
- name: Cache npm dependencies
|
||||
if: matrix.component == 'frontend'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.npm
|
||||
frontend/node_modules
|
||||
key: npm-${{ runner.os }}-${{ hashFiles('frontend/package-lock.json') }}
|
||||
restore-keys: |
|
||||
npm-${{ runner.os }}-
|
||||
|
||||
- name: Install frontend dependencies
|
||||
if: matrix.component == 'frontend'
|
||||
working-directory: frontend
|
||||
run: npm ci
|
||||
|
||||
- name: Run ESLint
|
||||
if: matrix.component == 'frontend'
|
||||
working-directory: frontend
|
||||
run: npm run lint
|
||||
|
||||
- name: Run TypeScript type check
|
||||
if: matrix.component == 'frontend'
|
||||
working-directory: frontend
|
||||
run: npm run type-check
|
||||
|
||||
- name: Run Prettier format check
|
||||
if: matrix.component == 'frontend'
|
||||
working-directory: frontend
|
||||
run: npm run format:check
|
||||
|
||||
# ===========================================================================
|
||||
# TEST JOB - Run test suites
|
||||
# ===========================================================================
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
needs: lint
|
||||
strategy:
|
||||
matrix:
|
||||
component: [backend, frontend]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# ----- Backend Tests -----
|
||||
- name: Set up Python
|
||||
if: matrix.component == 'backend'
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install uv
|
||||
if: matrix.component == 'backend'
|
||||
uses: astral-sh/setup-uv@v4
|
||||
with:
|
||||
version: ${{ env.UV_VERSION }}
|
||||
|
||||
- name: Cache uv dependencies
|
||||
if: matrix.component == 'backend'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cache/uv
|
||||
backend/.venv
|
||||
key: uv-${{ runner.os }}-${{ hashFiles('backend/uv.lock') }}
|
||||
restore-keys: |
|
||||
uv-${{ runner.os }}-
|
||||
|
||||
- name: Install backend dependencies
|
||||
if: matrix.component == 'backend'
|
||||
working-directory: backend
|
||||
run: uv sync --extra dev --frozen
|
||||
|
||||
- name: Run pytest with coverage
|
||||
if: matrix.component == 'backend'
|
||||
working-directory: backend
|
||||
env:
|
||||
IS_TEST: "True"
|
||||
run: |
|
||||
uv run pytest --cov=app --cov-report=xml --cov-report=term-missing --cov-fail-under=90
|
||||
|
||||
- name: Upload backend coverage report
|
||||
if: matrix.component == 'backend'
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: backend-coverage
|
||||
path: backend/coverage.xml
|
||||
retention-days: 7
|
||||
|
||||
# ----- Frontend Tests -----
|
||||
- name: Set up Node.js
|
||||
if: matrix.component == 'frontend'
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: ${{ env.NODE_VERSION }}
|
||||
|
||||
- name: Cache npm dependencies
|
||||
if: matrix.component == 'frontend'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.npm
|
||||
frontend/node_modules
|
||||
key: npm-${{ runner.os }}-${{ hashFiles('frontend/package-lock.json') }}
|
||||
restore-keys: |
|
||||
npm-${{ runner.os }}-
|
||||
|
||||
- name: Install frontend dependencies
|
||||
if: matrix.component == 'frontend'
|
||||
working-directory: frontend
|
||||
run: npm ci
|
||||
|
||||
- name: Run Jest unit tests
|
||||
if: matrix.component == 'frontend'
|
||||
working-directory: frontend
|
||||
run: npm test -- --coverage --passWithNoTests
|
||||
|
||||
- name: Upload frontend coverage report
|
||||
if: matrix.component == 'frontend'
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: frontend-coverage
|
||||
path: frontend/coverage/
|
||||
retention-days: 7
|
||||
|
||||
# ===========================================================================
|
||||
# BUILD JOB - Build Docker images
|
||||
# ===========================================================================
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ubuntu-latest
|
||||
needs: test
|
||||
strategy:
|
||||
matrix:
|
||||
component: [backend, frontend]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Cache Docker layers
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /tmp/.buildx-cache
|
||||
key: docker-${{ matrix.component }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
docker-${{ matrix.component }}-
|
||||
|
||||
- name: Build backend Docker image
|
||||
if: matrix.component == 'backend'
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
target: production
|
||||
push: false
|
||||
tags: syndarix-backend:${{ github.sha }}
|
||||
cache-from: type=local,src=/tmp/.buildx-cache
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
- name: Build frontend Docker image
|
||||
if: matrix.component == 'frontend'
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./frontend
|
||||
file: ./frontend/Dockerfile
|
||||
target: runner
|
||||
push: false
|
||||
tags: syndarix-frontend:${{ github.sha }}
|
||||
build-args: |
|
||||
NEXT_PUBLIC_API_URL=http://localhost:8000
|
||||
cache-from: type=local,src=/tmp/.buildx-cache
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||
|
||||
# Prevent cache from growing indefinitely
|
||||
- name: Move cache
|
||||
run: |
|
||||
rm -rf /tmp/.buildx-cache
|
||||
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
||||
|
||||
# ===========================================================================
|
||||
# DEPLOY JOB - Deploy to production (only on main branch)
|
||||
# ===========================================================================
|
||||
deploy:
|
||||
name: Deploy
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
if: github.ref == 'refs/heads/main' && github.event_name == 'push'
|
||||
environment: production
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Deploy notification
|
||||
run: |
|
||||
echo "Deployment to production would happen here"
|
||||
echo "Branch: ${{ github.ref }}"
|
||||
echo "Commit: ${{ github.sha }}"
|
||||
echo "Actor: ${{ github.actor }}"
|
||||
|
||||
# TODO: Add actual deployment steps when infrastructure is ready
|
||||
# Options:
|
||||
# - SSH to production server and run docker-compose pull && docker-compose up -d
|
||||
# - Use Kubernetes deployment
|
||||
# - Use cloud provider deployment (AWS ECS, GCP Cloud Run, etc.)
|
||||
# - Trigger webhook to deployment orchestrator
|
||||
|
||||
# ===========================================================================
|
||||
# SECURITY SCAN JOB - Run on main and dev branches
|
||||
# ===========================================================================
|
||||
security:
|
||||
name: Security Scan
|
||||
runs-on: ubuntu-latest
|
||||
needs: lint
|
||||
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev'
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
with:
|
||||
version: ${{ env.UV_VERSION }}
|
||||
|
||||
- name: Install backend dependencies
|
||||
working-directory: backend
|
||||
run: uv sync --extra dev --frozen
|
||||
|
||||
- name: Run Bandit security scan (via ruff)
|
||||
working-directory: backend
|
||||
run: |
|
||||
# Ruff includes flake8-bandit (S rules) for security scanning
|
||||
# Run with explicit security rules only
|
||||
uv run ruff check app --select=S --ignore=S101,S104,S105,S106,S603,S607
|
||||
|
||||
- name: Run pip-audit for dependency vulnerabilities
|
||||
working-directory: backend
|
||||
run: |
|
||||
# pip-audit checks for known vulnerabilities in Python dependencies
|
||||
uv run pip-audit --require-hashes --disable-pip -r <(uv pip compile pyproject.toml) || true
|
||||
# Note: Using || true temporarily while setting up proper remediation
|
||||
|
||||
- name: Check for secrets in code
|
||||
run: |
|
||||
# Basic check for common secret patterns
|
||||
# In production, use tools like gitleaks or trufflehog
|
||||
echo "Checking for potential hardcoded secrets..."
|
||||
! grep -rn --include="*.py" --include="*.ts" --include="*.tsx" --include="*.js" \
|
||||
-E "(api_key|apikey|secret_key|secretkey|password|passwd|token)\s*=\s*['\"][^'\"]{8,}['\"]" \
|
||||
backend/app frontend/src || echo "No obvious secrets found"
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: ${{ env.NODE_VERSION }}
|
||||
|
||||
- name: Install frontend dependencies
|
||||
working-directory: frontend
|
||||
run: npm ci
|
||||
|
||||
- name: Run npm audit
|
||||
working-directory: frontend
|
||||
run: |
|
||||
npm audit --audit-level=high || true
|
||||
# Note: Using || true to not fail on moderate vulnerabilities
|
||||
# In production, consider stricter settings
|
||||
|
||||
# ===========================================================================
|
||||
# E2E TEST JOB - Run end-to-end tests with Playwright
|
||||
# ===========================================================================
|
||||
e2e-tests:
|
||||
name: E2E Tests
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint, test]
|
||||
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev' || github.event_name == 'pull_request'
|
||||
services:
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg17
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: syndarix_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U postgres"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- 6379:6379
|
||||
options: >-
|
||||
--health-cmd "redis-cli ping"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
with:
|
||||
version: ${{ env.UV_VERSION }}
|
||||
|
||||
- name: Install backend dependencies
|
||||
working-directory: backend
|
||||
run: uv sync --extra dev --frozen
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: ${{ env.NODE_VERSION }}
|
||||
|
||||
- name: Install frontend dependencies
|
||||
working-directory: frontend
|
||||
run: npm ci
|
||||
|
||||
- name: Install Playwright browsers
|
||||
working-directory: frontend
|
||||
run: npx playwright install --with-deps chromium
|
||||
|
||||
- name: Start backend server
|
||||
working-directory: backend
|
||||
env:
|
||||
DATABASE_URL: postgresql://postgres:postgres@localhost:5432/syndarix_test
|
||||
REDIS_URL: redis://localhost:6379/0
|
||||
SECRET_KEY: test-secret-key-for-e2e-tests-only
|
||||
ENVIRONMENT: test
|
||||
IS_TEST: "True"
|
||||
run: |
|
||||
# Run migrations
|
||||
uv run python -c "from app.database import create_tables; import asyncio; asyncio.run(create_tables())" || true
|
||||
# Start backend in background
|
||||
uv run uvicorn app.main:app --host 0.0.0.0 --port 8000 &
|
||||
# Wait for backend to be ready
|
||||
sleep 10
|
||||
|
||||
- name: Run Playwright E2E tests
|
||||
working-directory: frontend
|
||||
env:
|
||||
NEXT_PUBLIC_API_URL: http://localhost:8000
|
||||
run: |
|
||||
npm run build
|
||||
npm run test:e2e -- --project=chromium
|
||||
|
||||
- name: Upload Playwright report
|
||||
uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
path: frontend/playwright-report/
|
||||
retention-days: 7
|
||||
61
.githooks/pre-commit
Executable file
61
.githooks/pre-commit
Executable file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
# Pre-commit hook to enforce validation before commits on protected branches
|
||||
# Install: git config core.hooksPath .githooks
|
||||
|
||||
set -e
|
||||
|
||||
# Get the current branch name
|
||||
BRANCH=$(git rev-parse --abbrev-ref HEAD)
|
||||
|
||||
# Protected branches that require validation
|
||||
PROTECTED_BRANCHES="main dev"
|
||||
|
||||
# Check if we're on a protected branch
|
||||
is_protected() {
|
||||
for branch in $PROTECTED_BRANCHES; do
|
||||
if [ "$BRANCH" = "$branch" ]; then
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
return 1
|
||||
}
|
||||
|
||||
if is_protected; then
|
||||
echo "🔒 Committing to protected branch '$BRANCH' - running validation..."
|
||||
|
||||
# Check if we have backend changes
|
||||
if git diff --cached --name-only | grep -q "^backend/"; then
|
||||
echo "📦 Backend changes detected - running make validate..."
|
||||
cd backend
|
||||
if ! make validate; then
|
||||
echo ""
|
||||
echo "❌ Backend validation failed!"
|
||||
echo " Please fix the issues and try again."
|
||||
echo " Run 'cd backend && make validate' to see errors."
|
||||
exit 1
|
||||
fi
|
||||
cd ..
|
||||
echo "✅ Backend validation passed!"
|
||||
fi
|
||||
|
||||
# Check if we have frontend changes
|
||||
if git diff --cached --name-only | grep -q "^frontend/"; then
|
||||
echo "🎨 Frontend changes detected - running npm run validate..."
|
||||
cd frontend
|
||||
if ! npm run validate 2>/dev/null; then
|
||||
echo ""
|
||||
echo "❌ Frontend validation failed!"
|
||||
echo " Please fix the issues and try again."
|
||||
echo " Run 'cd frontend && npm run validate' to see errors."
|
||||
exit 1
|
||||
fi
|
||||
cd ..
|
||||
echo "✅ Frontend validation passed!"
|
||||
fi
|
||||
|
||||
echo "🎉 All validations passed! Proceeding with commit..."
|
||||
else
|
||||
echo "📝 Committing to feature branch '$BRANCH' - skipping validation (run manually if needed)"
|
||||
fi
|
||||
|
||||
exit 0
|
||||
347
CLAUDE.md
347
CLAUDE.md
@@ -1,243 +1,204 @@
|
||||
# CLAUDE.md
|
||||
|
||||
Claude Code context for FastAPI + Next.js Full-Stack Template.
|
||||
Claude Code context for **Syndarix** - AI-Powered Software Consulting Agency.
|
||||
|
||||
**See [AGENTS.md](./AGENTS.md) for project context, architecture, and development commands.**
|
||||
**Built on PragmaStack.** See [AGENTS.md](./AGENTS.md) for base template context.
|
||||
|
||||
---
|
||||
|
||||
## Syndarix Project Context
|
||||
|
||||
### Vision
|
||||
|
||||
Syndarix is an autonomous platform that orchestrates specialized AI agents to deliver complete software solutions with minimal human intervention. It acts as a virtual consulting agency with AI agents playing roles like Product Owner, Architect, Engineers, QA, etc.
|
||||
|
||||
### Repository
|
||||
|
||||
- **URL:** https://gitea.pragmazest.com/cardosofelipe/syndarix
|
||||
- **Issue Tracker:** Gitea Issues (primary)
|
||||
- **CI/CD:** Gitea Actions
|
||||
|
||||
### Core Concepts
|
||||
|
||||
**Agent Types & Instances:**
|
||||
- Agent Type = Template (base model, failover, expertise, personality)
|
||||
- Agent Instance = Spawned from type, assigned to project
|
||||
- Multiple instances of same type can work together
|
||||
|
||||
**Project Workflow:**
|
||||
1. Requirements discovery with Product Owner agent
|
||||
2. Architecture spike (PO + BA + Architect brainstorm)
|
||||
3. Implementation planning and backlog creation
|
||||
4. Autonomous sprint execution with checkpoints
|
||||
5. Demo and client feedback
|
||||
|
||||
**Autonomy Levels:**
|
||||
- `FULL_CONTROL`: Approve every action
|
||||
- `MILESTONE`: Approve sprint boundaries
|
||||
- `AUTONOMOUS`: Only major decisions
|
||||
|
||||
**MCP-First Architecture:**
|
||||
All integrations via Model Context Protocol servers with explicit scoping:
|
||||
```python
|
||||
# All tools take project_id for scoping
|
||||
search_knowledge(project_id="proj-123", query="auth flow")
|
||||
create_issue(project_id="proj-123", title="Add login")
|
||||
```
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
docs/
|
||||
├── development/ # Workflow and coding standards
|
||||
├── requirements/ # Requirements documents
|
||||
├── architecture/ # Architecture documentation
|
||||
├── adrs/ # Architecture Decision Records
|
||||
└── spikes/ # Spike research documents
|
||||
```
|
||||
|
||||
### Current Phase
|
||||
|
||||
**Backlog Population** - Creating detailed issues for Phase 0-1 implementation.
|
||||
|
||||
---
|
||||
|
||||
## Development Standards
|
||||
|
||||
**CRITICAL: These rules are mandatory. See linked docs for full details.**
|
||||
|
||||
### Quick Reference
|
||||
|
||||
| Topic | Documentation |
|
||||
|-------|---------------|
|
||||
| **Workflow & Branching** | [docs/development/WORKFLOW.md](./docs/development/WORKFLOW.md) |
|
||||
| **Coding Standards** | [docs/development/CODING_STANDARDS.md](./docs/development/CODING_STANDARDS.md) |
|
||||
| **Design System** | [frontend/docs/design-system/](./frontend/docs/design-system/) |
|
||||
| **Backend E2E Testing** | [backend/docs/E2E_TESTING.md](./backend/docs/E2E_TESTING.md) |
|
||||
| **Demo Mode** | [frontend/docs/DEMO_MODE.md](./frontend/docs/DEMO_MODE.md) |
|
||||
|
||||
### Essential Rules Summary
|
||||
|
||||
1. **Issue-Driven Development**: Every piece of work MUST have an issue first
|
||||
2. **Branch per Feature**: `feature/<issue-number>-<description>`, single branch for design+implementation
|
||||
3. **Testing Required**: All code must be tested, aim for >90% coverage
|
||||
4. **Code Review**: Must pass multi-agent review before merge
|
||||
5. **No Direct Commits**: Never commit directly to `main` or `dev`
|
||||
6. **Stack Verification**: ALWAYS run the full stack before considering work done (see below)
|
||||
|
||||
### CRITICAL: Stack Verification Before Merge
|
||||
|
||||
**This is NON-NEGOTIABLE. A feature with 100% test coverage that crashes on startup is WORTHLESS.**
|
||||
|
||||
Before considering ANY issue complete:
|
||||
|
||||
```bash
|
||||
# 1. Start the dev stack
|
||||
make dev
|
||||
|
||||
# 2. Wait for backend to be healthy, check logs
|
||||
docker compose -f docker-compose.dev.yml logs backend --tail=100
|
||||
|
||||
# 3. Start frontend
|
||||
cd frontend && npm run dev
|
||||
|
||||
# 4. Verify both are running without errors
|
||||
```
|
||||
|
||||
**The issue is NOT done if:**
|
||||
- Backend crashes on startup (import errors, missing dependencies)
|
||||
- Frontend fails to compile or render
|
||||
- Health checks fail
|
||||
- Any error appears in logs
|
||||
|
||||
**Why this matters:**
|
||||
- Tests run in isolation and may pass despite broken imports
|
||||
- Docker builds cache layers and may hide dependency issues
|
||||
- A single `ModuleNotFoundError` renders all test coverage meaningless
|
||||
|
||||
### Common Commands
|
||||
|
||||
```bash
|
||||
# Backend
|
||||
IS_TEST=True uv run pytest # Run tests
|
||||
uv run ruff check src/ # Lint
|
||||
uv run mypy src/ # Type check
|
||||
python migrate.py auto "message" # Database migration
|
||||
|
||||
# Frontend
|
||||
npm test # Unit tests
|
||||
npm run lint # Lint
|
||||
npm run type-check # Type check
|
||||
npm run generate:api # Regenerate API client
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Claude Code-Specific Guidance
|
||||
|
||||
### Critical User Preferences
|
||||
|
||||
#### File Operations - NEVER Use Heredoc/Cat Append
|
||||
**ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF` commands.**
|
||||
**File Operations:**
|
||||
- ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF`
|
||||
- Never use heredoc - it triggers manual approval dialogs
|
||||
|
||||
This triggers manual approval dialogs and disrupts workflow.
|
||||
|
||||
```bash
|
||||
# WRONG ❌
|
||||
cat >> file.txt << EOF
|
||||
content
|
||||
EOF
|
||||
|
||||
# CORRECT ✅ - Use Read, then Write tools
|
||||
```
|
||||
|
||||
#### Work Style
|
||||
**Work Style:**
|
||||
- User prefers autonomous operation without frequent interruptions
|
||||
- Ask for batch permissions upfront for long work sessions
|
||||
- Work independently, document decisions clearly
|
||||
- Only use emojis if the user explicitly requests it
|
||||
|
||||
### When Working with This Stack
|
||||
|
||||
**Dependency Management:**
|
||||
- Backend uses **uv** (modern Python package manager), not pip
|
||||
- Always use `uv run` prefix: `IS_TEST=True uv run pytest`
|
||||
- Or use Makefile commands: `make test`, `make install-dev`
|
||||
- Add dependencies: `uv add <package>` or `uv add --dev <package>`
|
||||
|
||||
**Database Migrations:**
|
||||
- Use the `migrate.py` helper script, not Alembic directly
|
||||
- Generate + apply: `python migrate.py auto "message"`
|
||||
- Never commit migrations without testing them first
|
||||
- Check current state: `python migrate.py current`
|
||||
|
||||
**Frontend API Client Generation:**
|
||||
- Run `npm run generate:api` after backend schema changes
|
||||
- Client is auto-generated from OpenAPI spec
|
||||
- Located in `frontend/src/lib/api/generated/`
|
||||
- NEVER manually edit generated files
|
||||
|
||||
**Testing Commands:**
|
||||
- Backend unit/integration: `IS_TEST=True uv run pytest` (always prefix with `IS_TEST=True`)
|
||||
- Backend E2E (requires Docker): `make test-e2e`
|
||||
- Frontend unit: `npm test`
|
||||
- Frontend E2E: `npm run test:e2e`
|
||||
- Use `make test` or `make test-cov` in backend for convenience
|
||||
|
||||
**Backend E2E Testing (requires Docker):**
|
||||
- Install deps: `make install-e2e`
|
||||
- Run all E2E tests: `make test-e2e`
|
||||
- Run schema tests only: `make test-e2e-schema`
|
||||
- Run all tests: `make test-all` (unit + E2E)
|
||||
- Uses Testcontainers (real PostgreSQL) + Schemathesis (OpenAPI contract testing)
|
||||
- Markers: `@pytest.mark.e2e`, `@pytest.mark.postgres`, `@pytest.mark.schemathesis`
|
||||
- See: `backend/docs/E2E_TESTING.md` for complete guide
|
||||
|
||||
### 🔴 CRITICAL: Auth Store Dependency Injection Pattern
|
||||
### Critical Pattern: Auth Store DI
|
||||
|
||||
**ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly!**
|
||||
|
||||
```typescript
|
||||
// ❌ WRONG - Bypasses dependency injection
|
||||
// ❌ WRONG
|
||||
import { useAuthStore } from '@/lib/stores/authStore';
|
||||
const { user, isAuthenticated } = useAuthStore();
|
||||
|
||||
// ✅ CORRECT - Uses dependency injection
|
||||
// ✅ CORRECT
|
||||
import { useAuth } from '@/lib/auth/AuthContext';
|
||||
const { user, isAuthenticated } = useAuth();
|
||||
```
|
||||
|
||||
**Why This Matters:**
|
||||
- E2E tests inject mock stores via `window.__TEST_AUTH_STORE__`
|
||||
- Unit tests inject via `<AuthProvider store={mockStore}>`
|
||||
- Direct `useAuthStore` imports bypass this injection → **tests fail**
|
||||
- ESLint will catch violations (added Nov 2025)
|
||||
|
||||
**Exceptions:**
|
||||
1. `AuthContext.tsx` - DI boundary, legitimately needs real store
|
||||
2. `client.ts` - Non-React context, uses dynamic import + `__TEST_AUTH_STORE__` check
|
||||
|
||||
### E2E Test Best Practices
|
||||
|
||||
When writing or fixing Playwright tests:
|
||||
|
||||
**Navigation Pattern:**
|
||||
```typescript
|
||||
// ✅ CORRECT - Use Promise.all for Next.js Link clicks
|
||||
await Promise.all([
|
||||
page.waitForURL('/target', { timeout: 10000 }),
|
||||
link.click()
|
||||
]);
|
||||
```
|
||||
|
||||
**Selectors:**
|
||||
- Use ID-based selectors for validation errors: `#email-error`
|
||||
- Error IDs use dashes not underscores: `#new-password-error`
|
||||
- Target `.border-destructive[role="alert"]` to avoid Next.js route announcer conflicts
|
||||
- Avoid generic `[role="alert"]` which matches multiple elements
|
||||
|
||||
**URL Assertions:**
|
||||
```typescript
|
||||
// ✅ Use regex to handle query params
|
||||
await expect(page).toHaveURL(/\/auth\/login/);
|
||||
|
||||
// ❌ Don't use exact strings (fails with query params)
|
||||
await expect(page).toHaveURL('/auth/login');
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
- Uses 12 workers in non-CI mode (`playwright.config.ts`)
|
||||
- Reduces to 2 workers in CI for stability
|
||||
- Tests are designed to be non-flaky with proper waits
|
||||
|
||||
### Important Implementation Details
|
||||
|
||||
**Authentication Testing:**
|
||||
- Backend fixtures in `tests/conftest.py`:
|
||||
- `async_test_db`: Fresh SQLite per test
|
||||
- `async_test_user` / `async_test_superuser`: Pre-created users
|
||||
- `user_token` / `superuser_token`: Access tokens for API calls
|
||||
- Always use `@pytest.mark.asyncio` for async tests
|
||||
- Use `@pytest_asyncio.fixture` for async fixtures
|
||||
|
||||
**Database Testing:**
|
||||
```python
|
||||
# Mock database exceptions correctly
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Connection lost", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await crud_method(session, obj_in=data)
|
||||
mock_rollback.assert_called_once()
|
||||
```
|
||||
|
||||
**Frontend Component Development:**
|
||||
- Follow design system docs in `frontend/docs/design-system/`
|
||||
- Read `08-ai-guidelines.md` for AI code generation rules
|
||||
- Use parent-controlled spacing (see `04-spacing-philosophy.md`)
|
||||
- WCAG AA compliance required (see `07-accessibility.md`)
|
||||
|
||||
**Security Considerations:**
|
||||
- Backend has comprehensive security tests (JWT attacks, session hijacking)
|
||||
- Never skip security headers in production
|
||||
- Rate limiting is configured in route decorators: `@limiter.limit("10/minute")`
|
||||
- Session revocation is database-backed, not just JWT expiry
|
||||
|
||||
### Common Workflows Guidance
|
||||
|
||||
**When Adding a New Feature:**
|
||||
1. Start with backend schema and CRUD
|
||||
2. Implement API route with proper authorization
|
||||
3. Write backend tests (aim for >90% coverage)
|
||||
4. Generate frontend API client: `npm run generate:api`
|
||||
5. Implement frontend components
|
||||
6. Write frontend unit tests
|
||||
7. Add E2E tests for critical flows
|
||||
8. Update relevant documentation
|
||||
|
||||
**When Fixing Tests:**
|
||||
- Backend: Check test database isolation and async fixture usage
|
||||
- Frontend unit: Verify mocking of `useAuth()` not `useAuthStore`
|
||||
- E2E: Use `Promise.all()` pattern and regex URL assertions
|
||||
|
||||
**When Debugging:**
|
||||
- Backend: Check `IS_TEST=True` environment variable is set
|
||||
- Frontend: Run `npm run type-check` first
|
||||
- E2E: Use `npm run test:e2e:debug` for step-by-step debugging
|
||||
- Check logs: Backend has detailed error logging
|
||||
|
||||
**Demo Mode (Frontend-Only Showcase):**
|
||||
- Enable: `echo "NEXT_PUBLIC_DEMO_MODE=true" > frontend/.env.local`
|
||||
- Uses MSW (Mock Service Worker) to intercept API calls in browser
|
||||
- Zero backend required - perfect for Vercel deployments
|
||||
- **Fully Automated**: MSW handlers auto-generated from OpenAPI spec
|
||||
- Run `npm run generate:api` → updates both API client AND MSW handlers
|
||||
- No manual synchronization needed!
|
||||
- Demo credentials (any password ≥8 chars works):
|
||||
- User: `demo@example.com` / `DemoPass123`
|
||||
- Admin: `admin@example.com` / `AdminPass123`
|
||||
- **Safe**: MSW never runs during tests (Jest or Playwright)
|
||||
- **Coverage**: Mock files excluded from linting and coverage
|
||||
- **Documentation**: `frontend/docs/DEMO_MODE.md` for complete guide
|
||||
See [CODING_STANDARDS.md](./docs/development/CODING_STANDARDS.md#auth-store-dependency-injection) for details.
|
||||
|
||||
### Tool Usage Preferences
|
||||
|
||||
**Prefer specialized tools over bash:**
|
||||
- Use Read/Write/Edit tools for file operations
|
||||
- Never use `cat`, `echo >`, or heredoc for file manipulation
|
||||
- Use Task tool with `subagent_type=Explore` for codebase exploration
|
||||
- Use Grep tool for code search, not bash `grep`
|
||||
|
||||
**When to use parallel tool calls:**
|
||||
- Independent git commands: `git status`, `git diff`, `git log`
|
||||
**Parallel tool calls for:**
|
||||
- Independent git commands
|
||||
- Reading multiple unrelated files
|
||||
- Running multiple test suites simultaneously
|
||||
- Running multiple test suites
|
||||
- Independent validation steps
|
||||
|
||||
## Custom Skills
|
||||
---
|
||||
|
||||
No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill.
|
||||
## Key Extensions (from PragmaStack base)
|
||||
|
||||
**Potential skill ideas for this project:**
|
||||
- API endpoint generator workflow (schema → CRUD → route → tests → frontend client)
|
||||
- Component generator with design system compliance
|
||||
- Database migration troubleshooting helper
|
||||
- Test coverage analyzer and improvement suggester
|
||||
- E2E test generator for new features
|
||||
- Celery + Redis for agent job queue
|
||||
- WebSocket/SSE for real-time updates
|
||||
- pgvector for RAG knowledge base
|
||||
- MCP server integration layer
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
**Comprehensive Documentation:**
|
||||
**Documentation:**
|
||||
- [AGENTS.md](./AGENTS.md) - Framework-agnostic AI assistant context
|
||||
- [README.md](./README.md) - User-facing project overview
|
||||
- `backend/docs/` - Backend architecture, coding standards, common pitfalls
|
||||
- `frontend/docs/design-system/` - Complete design system guide
|
||||
- [docs/development/](./docs/development/) - Development workflow and standards
|
||||
- [backend/docs/](./backend/docs/) - Backend architecture and guides
|
||||
- [frontend/docs/design-system/](./frontend/docs/design-system/) - Complete design system
|
||||
|
||||
**API Documentation (when running):**
|
||||
- Swagger UI: http://localhost:8000/docs
|
||||
- ReDoc: http://localhost:8000/redoc
|
||||
- OpenAPI JSON: http://localhost:8000/api/v1/openapi.json
|
||||
|
||||
**Testing Documentation:**
|
||||
- Backend tests: `backend/tests/` (97% coverage)
|
||||
- Frontend E2E: `frontend/e2e/README.md`
|
||||
- Design system: `frontend/docs/design-system/08-ai-guidelines.md`
|
||||
|
||||
---
|
||||
|
||||
**For project architecture, development commands, and general context, see [AGENTS.md](./AGENTS.md).**
|
||||
|
||||
92
Makefile
92
Makefile
@@ -1,18 +1,31 @@
|
||||
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
|
||||
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all
|
||||
|
||||
VERSION ?= latest
|
||||
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "FastAPI + Next.js Full-Stack Template"
|
||||
@echo "Syndarix - AI-Powered Software Consulting Agency"
|
||||
@echo ""
|
||||
@echo "Development:"
|
||||
@echo " make dev - Start backend + db (frontend runs separately)"
|
||||
@echo " make dev - Start backend + db + MCP servers (frontend runs separately)"
|
||||
@echo " make dev-full - Start all services including frontend"
|
||||
@echo " make down - Stop all services"
|
||||
@echo " make logs-dev - Follow dev container logs"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@echo " make test - Run all tests (backend + MCP servers)"
|
||||
@echo " make test-backend - Run backend tests only"
|
||||
@echo " make test-mcp - Run MCP server tests only"
|
||||
@echo " make test-frontend - Run frontend tests only"
|
||||
@echo " make test-cov - Run all tests with coverage reports"
|
||||
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||
@echo ""
|
||||
@echo "Validation:"
|
||||
@echo " make validate - Validate backend + MCP servers (lint, type-check, test)"
|
||||
@echo " make validate-all - Validate everything including frontend"
|
||||
@echo ""
|
||||
@echo "Database:"
|
||||
@echo " make drop-db - Drop and recreate empty database"
|
||||
@echo " make reset-db - Drop database and apply all migrations"
|
||||
@@ -28,8 +41,10 @@ help:
|
||||
@echo " make clean-slate - Stop containers AND delete volumes (DATA LOSS!)"
|
||||
@echo ""
|
||||
@echo "Subdirectory commands:"
|
||||
@echo " cd backend && make help - Backend-specific commands"
|
||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||
@echo " cd backend && make help - Backend-specific commands"
|
||||
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
|
||||
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
|
||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||
|
||||
# ============================================================================
|
||||
# Development
|
||||
@@ -99,3 +114,72 @@ clean:
|
||||
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
|
||||
clean-slate:
|
||||
docker compose -f docker-compose.dev.yml down -v --remove-orphans
|
||||
|
||||
# ============================================================================
|
||||
# Testing
|
||||
# ============================================================================
|
||||
|
||||
test: test-backend test-mcp
|
||||
@echo ""
|
||||
@echo "All tests passed!"
|
||||
|
||||
test-backend:
|
||||
@echo "Running backend tests..."
|
||||
@cd backend && IS_TEST=True uv run pytest tests/ -v
|
||||
|
||||
test-mcp:
|
||||
@echo "Running MCP server tests..."
|
||||
@echo ""
|
||||
@echo "=== LLM Gateway ==="
|
||||
@cd mcp-servers/llm-gateway && uv run pytest tests/ -v
|
||||
@echo ""
|
||||
@echo "=== Knowledge Base ==="
|
||||
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v
|
||||
|
||||
test-frontend:
|
||||
@echo "Running frontend tests..."
|
||||
@cd frontend && npm test
|
||||
|
||||
test-all: test test-frontend
|
||||
@echo ""
|
||||
@echo "All tests (backend + MCP + frontend) passed!"
|
||||
|
||||
test-cov:
|
||||
@echo "Running all tests with coverage..."
|
||||
@echo ""
|
||||
@echo "=== Backend Coverage ==="
|
||||
@cd backend && IS_TEST=True uv run pytest tests/ -v --cov=app --cov-report=term-missing
|
||||
@echo ""
|
||||
@echo "=== LLM Gateway Coverage ==="
|
||||
@cd mcp-servers/llm-gateway && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||
@echo ""
|
||||
@echo "=== Knowledge Base Coverage ==="
|
||||
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||
|
||||
test-integration:
|
||||
@echo "Running MCP integration tests..."
|
||||
@echo "Note: Requires running stack (make dev first)"
|
||||
@cd backend && RUN_INTEGRATION_TESTS=true IS_TEST=True uv run pytest tests/integration/ -v
|
||||
|
||||
# ============================================================================
|
||||
# Validation (lint + type-check + test)
|
||||
# ============================================================================
|
||||
|
||||
validate:
|
||||
@echo "Validating backend..."
|
||||
@cd backend && make validate
|
||||
@echo ""
|
||||
@echo "Validating LLM Gateway..."
|
||||
@cd mcp-servers/llm-gateway && make validate
|
||||
@echo ""
|
||||
@echo "Validating Knowledge Base..."
|
||||
@cd mcp-servers/knowledge-base && make validate
|
||||
@echo ""
|
||||
@echo "All validations passed!"
|
||||
|
||||
validate-all: validate
|
||||
@echo ""
|
||||
@echo "Validating frontend..."
|
||||
@cd frontend && npm run validate
|
||||
@echo ""
|
||||
@echo "Full validation passed!"
|
||||
|
||||
724
README.md
724
README.md
@@ -1,659 +1,175 @@
|
||||
# <img src="frontend/public/logo.svg" alt="PragmaStack" width="32" height="32" style="vertical-align: middle" /> PragmaStack
|
||||
# Syndarix
|
||||
|
||||
> **The Pragmatic Full-Stack Template. Production-ready, security-first, and opinionated.**
|
||||
> **Your AI-Powered Software Consulting Agency**
|
||||
>
|
||||
> An autonomous platform that orchestrates specialized AI agents to deliver complete software solutions with minimal human intervention.
|
||||
|
||||
[](./backend/tests)
|
||||
[](./frontend/tests)
|
||||
[](./frontend/e2e)
|
||||
[](https://gitea.pragmazest.com/cardosofelipe/fast-next-template)
|
||||
[](./LICENSE)
|
||||
[](./CONTRIBUTING.md)
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
## Why PragmaStack?
|
||||
## Vision
|
||||
|
||||
Building a modern full-stack application often leads to "analysis paralysis" or "boilerplate fatigue". You spend weeks setting up authentication, testing, and linting before writing a single line of business logic.
|
||||
Syndarix transforms the software development lifecycle by providing a **virtual consulting team** of AI agents that collaboratively plan, design, implement, test, and deliver complete software solutions.
|
||||
|
||||
**PragmaStack cuts through the noise.**
|
||||
**The Problem:** Even with AI coding assistants, developers spend as much time managing AI as doing the work themselves. Context switching, babysitting, and knowledge fragmentation limit productivity.
|
||||
|
||||
We provide a **pragmatic**, opinionated foundation that prioritizes:
|
||||
- **Speed**: Ship features, not config files.
|
||||
- **Robustness**: Security and testing are not optional.
|
||||
- **Clarity**: Code that is easy to read and maintain.
|
||||
|
||||
Whether you're building a SaaS, an internal tool, or a side project, PragmaStack gives you a solid starting point without the bloat.
|
||||
**The Solution:** A structured, autonomous agency where specialized AI agents handle different roles (Product Owner, Architect, Engineers, QA, etc.) with proper workflows, reviews, and quality gates.
|
||||
|
||||
---
|
||||
|
||||
## ✨ Features
|
||||
## Key Features
|
||||
|
||||
### 🔐 **Authentication & Security**
|
||||
- JWT-based authentication with access + refresh tokens
|
||||
- **OAuth/Social Login** (Google, GitHub) with PKCE support
|
||||
- **OAuth 2.0 Authorization Server** (MCP-ready) for third-party integrations
|
||||
- Session management with device tracking and revocation
|
||||
- Password reset flow (email integration ready)
|
||||
- Secure password hashing (bcrypt)
|
||||
- CSRF protection, rate limiting, and security headers
|
||||
- Comprehensive security tests (JWT algorithm attacks, session hijacking, privilege escalation)
|
||||
### Multi-Agent Orchestration
|
||||
- Configurable agent **types** with base model, failover, expertise, and personality
|
||||
- Spawn multiple **instances** from the same type (e.g., Dave, Ellis, Kate as Software Developers)
|
||||
- Agent-to-agent communication and collaboration
|
||||
- Per-instance customization with domain-specific knowledge
|
||||
|
||||
### 🔌 **OAuth Provider Mode (MCP Integration)**
|
||||
Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-party clients:
|
||||
- **RFC 7636**: Authorization Code Flow with PKCE (S256 only)
|
||||
- **RFC 8414**: Server metadata discovery at `/.well-known/oauth-authorization-server`
|
||||
- **RFC 7662**: Token introspection endpoint
|
||||
- **RFC 7009**: Token revocation endpoint
|
||||
- **JWT access tokens**: Self-contained, configurable lifetime
|
||||
- **Opaque refresh tokens**: Secure rotation, database-backed revocation
|
||||
- **Consent management**: Users can review and revoke app permissions
|
||||
- **Client management**: Admin endpoints for registering OAuth clients
|
||||
- **Scopes**: `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
|
||||
### Complete SDLC Support
|
||||
- **Requirements Discovery** → **Architecture Spike** → **Implementation Planning**
|
||||
- **Sprint Management** with automated ceremonies
|
||||
- **Issue Tracking** with Epic/Story/Task hierarchy
|
||||
- **Git Integration** with proper branch/PR workflows
|
||||
- **CI/CD Pipelines** with automated testing
|
||||
|
||||
### 👥 **Multi-Tenancy & Organizations**
|
||||
- Full organization system with role-based access control (Owner, Admin, Member)
|
||||
- Invite/remove members, manage permissions
|
||||
- Organization-scoped data access
|
||||
- User can belong to multiple organizations
|
||||
### Configurable Autonomy
|
||||
- From `FULL_CONTROL` (approve everything) to `AUTONOMOUS` (only major milestones)
|
||||
- Client can intervene at any point
|
||||
- Transparent progress visibility
|
||||
|
||||
### 🛠️ **Admin Panel**
|
||||
- Complete user management (CRUD, activate/deactivate, bulk operations)
|
||||
- Organization management (create, edit, delete, member management)
|
||||
- Session monitoring across all users
|
||||
- Real-time statistics dashboard
|
||||
- Admin-only routes with proper authorization
|
||||
### MCP-First Architecture
|
||||
- All integrations via **Model Context Protocol (MCP)** servers
|
||||
- Unified Knowledge Base with project/agent scoping
|
||||
- Git providers (Gitea, GitHub, GitLab) via MCP
|
||||
- Extensible through custom MCP tools
|
||||
|
||||
### 🎨 **Modern Frontend**
|
||||
- Next.js 16 with App Router and React 19
|
||||
- **PragmaStack Design System** built on shadcn/ui + TailwindCSS
|
||||
- Pre-configured theme with dark mode support (coming soon)
|
||||
- Responsive, accessible components (WCAG AA compliant)
|
||||
- Rich marketing landing page with animated components
|
||||
- Live component showcase and documentation at `/dev`
|
||||
|
||||
### 🌍 **Internationalization (i18n)**
|
||||
- Built-in multi-language support with next-intl v4
|
||||
- Locale-based routing (`/en/*`, `/it/*`)
|
||||
- Seamless language switching with LocaleSwitcher component
|
||||
- SEO-friendly URLs and metadata per locale
|
||||
- Translation files for English and Italian (easily extensible)
|
||||
- Type-safe translations throughout the app
|
||||
|
||||
### 🎯 **Content & UX Features**
|
||||
- **Toast notifications** with Sonner for elegant user feedback
|
||||
- **Smooth animations** powered by Framer Motion
|
||||
- **Markdown rendering** with syntax highlighting (GitHub Flavored Markdown)
|
||||
- **Charts and visualizations** ready with Recharts
|
||||
- **SEO optimization** with dynamic sitemap and robots.txt generation
|
||||
- **Session tracking UI** with device information and revocation controls
|
||||
|
||||
### 🧪 **Comprehensive Testing**
|
||||
- **Backend Testing**: ~97% unit test coverage
|
||||
- Unit, integration, and security tests
|
||||
- Async database testing with SQLAlchemy
|
||||
- API endpoint testing with fixtures
|
||||
- Security vulnerability tests (JWT attacks, session hijacking, privilege escalation)
|
||||
- **Frontend Unit Tests**: ~97% coverage with Jest
|
||||
- Component testing
|
||||
- Hook testing
|
||||
- Utility function testing
|
||||
- **End-to-End Tests**: Playwright with zero flaky tests
|
||||
- Complete user flows (auth, navigation, settings)
|
||||
- Parallel execution for speed
|
||||
- Visual regression testing ready
|
||||
|
||||
### 📚 **Developer Experience**
|
||||
- Auto-generated TypeScript API client from OpenAPI spec
|
||||
- Interactive API documentation (Swagger + ReDoc)
|
||||
- Database migrations with Alembic helper script
|
||||
- Hot reload in development for both frontend and backend
|
||||
- Comprehensive code documentation and design system docs
|
||||
- Live component playground at `/dev` with code examples
|
||||
- Docker support for easy deployment
|
||||
- VSCode workspace settings included
|
||||
|
||||
### 📊 **Ready for Production**
|
||||
- Docker + docker-compose setup
|
||||
- Environment-based configuration
|
||||
- Database connection pooling
|
||||
- Error handling and logging
|
||||
- Health check endpoints
|
||||
- Production security headers
|
||||
- Rate limiting on sensitive endpoints
|
||||
- SEO optimization with dynamic sitemaps and robots.txt
|
||||
- Multi-language SEO with locale-specific metadata
|
||||
- Performance monitoring and bundle analysis
|
||||
### Project Complexity Wizard
|
||||
- **Script** → Minimal process, no repo needed
|
||||
- **Simple** → Single sprint, basic backlog
|
||||
- **Medium/Complex** → Full AGILE workflow with multiple sprints
|
||||
|
||||
---
|
||||
|
||||
## 📸 Screenshots
|
||||
## Technology Stack
|
||||
|
||||
<details>
|
||||
<summary>Click to view screenshots</summary>
|
||||
Built on [PragmaStack](https://gitea.pragmazest.com/cardosofelipe/fast-next-template):
|
||||
|
||||
### Landing Page
|
||||

|
||||
| Component | Technology |
|
||||
|-----------|------------|
|
||||
| Backend | FastAPI 0.115+ (Python 3.11+) |
|
||||
| Frontend | Next.js 16 (React 19) |
|
||||
| Database | PostgreSQL 15+ with pgvector |
|
||||
| ORM | SQLAlchemy 2.0 |
|
||||
| State Management | Zustand + TanStack Query |
|
||||
| UI | shadcn/ui + Tailwind 4 |
|
||||
| Auth | JWT dual-token + OAuth 2.0 |
|
||||
| Testing | pytest + Jest + Playwright |
|
||||
|
||||
|
||||
|
||||
### Authentication
|
||||

|
||||
|
||||
|
||||
|
||||
### Admin Dashboard
|
||||

|
||||
|
||||
|
||||
|
||||
### Design System
|
||||

|
||||
|
||||
</details>
|
||||
### Syndarix Extensions
|
||||
| Component | Technology |
|
||||
|-----------|------------|
|
||||
| Task Queue | Celery + Redis |
|
||||
| Real-time | FastAPI WebSocket / SSE |
|
||||
| Vector DB | pgvector (PostgreSQL extension) |
|
||||
| MCP SDK | Anthropic MCP SDK |
|
||||
|
||||
---
|
||||
|
||||
## 🎭 Demo Mode
|
||||
## Project Status
|
||||
|
||||
**Try the frontend without a backend!** Perfect for:
|
||||
- **Free deployment** on Vercel (no backend costs)
|
||||
- **Portfolio showcasing** with live demos
|
||||
- **Client presentations** without infrastructure setup
|
||||
**Phase:** Architecture & Planning
|
||||
|
||||
See [docs/requirements/](./docs/requirements/) for the comprehensive requirements document.
|
||||
|
||||
### Current Milestones
|
||||
- [x] Fork PragmaStack as foundation
|
||||
- [x] Create requirements document
|
||||
- [ ] Execute architecture spikes
|
||||
- [ ] Create ADRs for key decisions
|
||||
- [ ] Begin MVP implementation
|
||||
|
||||
---
|
||||
|
||||
## Documentation
|
||||
|
||||
- [Requirements Document](./docs/requirements/SYNDARIX_REQUIREMENTS.md)
|
||||
- [Architecture Decisions](./docs/adrs/) (coming soon)
|
||||
- [Spike Research](./docs/spikes/) (coming soon)
|
||||
- [Architecture Overview](./docs/architecture/) (coming soon)
|
||||
|
||||
---
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
- Docker & Docker Compose
|
||||
- Node.js 20+
|
||||
- Python 3.11+
|
||||
- PostgreSQL 15+ (or use Docker)
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
echo "NEXT_PUBLIC_DEMO_MODE=true" > .env.local
|
||||
npm run dev
|
||||
```
|
||||
|
||||
**Demo Credentials:**
|
||||
- Regular user: `demo@example.com` / `DemoPass123`
|
||||
- Admin user: `admin@example.com` / `AdminPass123`
|
||||
|
||||
Demo mode uses [Mock Service Worker (MSW)](https://mswjs.io/) to intercept API calls in the browser. Your code remains unchanged - the same components work with both real and mocked backends.
|
||||
|
||||
**Key Features:**
|
||||
- ✅ Zero backend required
|
||||
- ✅ All features functional (auth, admin, stats)
|
||||
- ✅ Realistic network delays and errors
|
||||
- ✅ Does NOT interfere with tests (97%+ coverage maintained)
|
||||
- ✅ One-line toggle: `NEXT_PUBLIC_DEMO_MODE=true`
|
||||
|
||||
📖 **[Complete Demo Mode Documentation](./frontend/docs/DEMO_MODE.md)**
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Tech Stack
|
||||
|
||||
### Backend
|
||||
- **[FastAPI](https://fastapi.tiangolo.com/)** - Modern async Python web framework
|
||||
- **[SQLAlchemy 2.0](https://www.sqlalchemy.org/)** - Powerful ORM with async support
|
||||
- **[PostgreSQL](https://www.postgresql.org/)** - Robust relational database
|
||||
- **[Alembic](https://alembic.sqlalchemy.org/)** - Database migrations
|
||||
- **[Pydantic v2](https://docs.pydantic.dev/)** - Data validation with type hints
|
||||
- **[pytest](https://pytest.org/)** - Testing framework with async support
|
||||
|
||||
### Frontend
|
||||
- **[Next.js 16](https://nextjs.org/)** - React framework with App Router
|
||||
- **[React 19](https://react.dev/)** - UI library
|
||||
- **[TypeScript](https://www.typescriptlang.org/)** - Type-safe JavaScript
|
||||
- **[TailwindCSS](https://tailwindcss.com/)** - Utility-first CSS framework
|
||||
- **[shadcn/ui](https://ui.shadcn.com/)** - Beautiful, accessible component library
|
||||
- **[next-intl](https://next-intl.dev/)** - Internationalization (i18n) with type safety
|
||||
- **[TanStack Query](https://tanstack.com/query)** - Powerful data fetching/caching
|
||||
- **[Zustand](https://zustand-demo.pmnd.rs/)** - Lightweight state management
|
||||
- **[Framer Motion](https://www.framer.com/motion/)** - Production-ready animation library
|
||||
- **[Sonner](https://sonner.emilkowal.ski/)** - Beautiful toast notifications
|
||||
- **[Recharts](https://recharts.org/)** - Composable charting library
|
||||
- **[React Markdown](https://github.com/remarkjs/react-markdown)** - Markdown rendering with GFM support
|
||||
- **[Playwright](https://playwright.dev/)** - End-to-end testing
|
||||
|
||||
### DevOps
|
||||
- **[Docker](https://www.docker.com/)** - Containerization
|
||||
- **[docker-compose](https://docs.docker.com/compose/)** - Multi-container orchestration
|
||||
- **GitHub Actions** (coming soon) - CI/CD pipelines
|
||||
|
||||
---
|
||||
|
||||
## 📋 Prerequisites
|
||||
|
||||
- **Docker & Docker Compose** (recommended) - [Install Docker](https://docs.docker.com/get-docker/)
|
||||
- **OR manually:**
|
||||
- Python 3.12+
|
||||
- Node.js 18+ (Node 20+ recommended)
|
||||
- PostgreSQL 15+
|
||||
|
||||
---
|
||||
|
||||
## 🏃 Quick Start (Docker)
|
||||
|
||||
The fastest way to get started is with Docker:
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/cardosofelipe/pragma-stack.git
|
||||
cd fast-next-template
|
||||
git clone https://gitea.pragmazest.com/cardosofelipe/syndarix.git
|
||||
cd syndarix
|
||||
|
||||
# Copy environment file
|
||||
# Copy environment template
|
||||
cp .env.template .env
|
||||
|
||||
# Start all services (backend, frontend, database)
|
||||
docker-compose up
|
||||
# Start development environment
|
||||
docker-compose -f docker-compose.dev.yml up -d
|
||||
|
||||
# In another terminal, run database migrations
|
||||
docker-compose exec backend alembic upgrade head
|
||||
# Run database migrations
|
||||
make migrate
|
||||
|
||||
# Create first superuser (optional)
|
||||
docker-compose exec backend python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
|
||||
```
|
||||
|
||||
**That's it! 🎉**
|
||||
|
||||
- Frontend: http://localhost:3000
|
||||
- Backend API: http://localhost:8000
|
||||
- API Docs: http://localhost:8000/docs
|
||||
|
||||
Default superuser credentials:
|
||||
- Email: `admin@example.com`
|
||||
- Password: `admin123`
|
||||
|
||||
**⚠️ Change these immediately in production!**
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ Manual Setup (Development)
|
||||
|
||||
### Backend Setup
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# Create virtual environment
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Setup environment
|
||||
cp .env.example .env
|
||||
# Edit .env with your database credentials
|
||||
|
||||
# Run migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Initialize database with first superuser
|
||||
python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
|
||||
|
||||
# Start development server
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### Frontend Setup
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
|
||||
# Install dependencies
|
||||
npm install
|
||||
|
||||
# Setup environment
|
||||
cp .env.local.example .env.local
|
||||
# Edit .env.local with your backend URL
|
||||
|
||||
# Generate API client
|
||||
npm run generate:api
|
||||
|
||||
# Start development server
|
||||
npm run dev
|
||||
```
|
||||
|
||||
Visit http://localhost:3000 to see your app!
|
||||
|
||||
---
|
||||
|
||||
## 📂 Project Structure
|
||||
|
||||
```
|
||||
├── backend/ # FastAPI backend
|
||||
│ ├── app/
|
||||
│ │ ├── api/ # API routes and dependencies
|
||||
│ │ ├── core/ # Core functionality (auth, config, database)
|
||||
│ │ ├── crud/ # Database operations
|
||||
│ │ ├── models/ # SQLAlchemy models
|
||||
│ │ ├── schemas/ # Pydantic schemas
|
||||
│ │ ├── services/ # Business logic
|
||||
│ │ └── utils/ # Utilities
|
||||
│ ├── tests/ # Backend tests (97% coverage)
|
||||
│ ├── alembic/ # Database migrations
|
||||
│ └── docs/ # Backend documentation
|
||||
│
|
||||
├── frontend/ # Next.js frontend
|
||||
│ ├── src/
|
||||
│ │ ├── app/ # Next.js App Router pages
|
||||
│ │ ├── components/ # React components
|
||||
│ │ ├── lib/ # Libraries and utilities
|
||||
│ │ │ ├── api/ # API client (auto-generated)
|
||||
│ │ │ └── stores/ # Zustand stores
|
||||
│ │ └── hooks/ # Custom React hooks
|
||||
│ ├── e2e/ # Playwright E2E tests
|
||||
│ ├── tests/ # Unit tests (Jest)
|
||||
│ └── docs/ # Frontend documentation
|
||||
│ └── design-system/ # Comprehensive design system docs
|
||||
│
|
||||
├── docker-compose.yml # Docker orchestration
|
||||
├── docker-compose.dev.yml # Development with hot reload
|
||||
└── README.md # You are here!
|
||||
# Start the development servers
|
||||
make dev
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🧪 Testing
|
||||
## Architecture Overview
|
||||
|
||||
This template takes testing seriously with comprehensive coverage across all layers:
|
||||
|
||||
### Backend Unit & Integration Tests
|
||||
|
||||
**High coverage (~97%)** across all critical paths including security-focused tests.
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# Run all tests
|
||||
IS_TEST=True pytest
|
||||
|
||||
# Run with coverage report
|
||||
IS_TEST=True pytest --cov=app --cov-report=term-missing
|
||||
|
||||
# Run specific test file
|
||||
IS_TEST=True pytest tests/api/test_auth.py -v
|
||||
|
||||
# Generate HTML coverage report
|
||||
IS_TEST=True pytest --cov=app --cov-report=html
|
||||
open htmlcov/index.html
|
||||
```
|
||||
|
||||
**Test types:**
|
||||
- **Unit tests**: CRUD operations, utilities, business logic
|
||||
- **Integration tests**: API endpoints with database
|
||||
- **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation
|
||||
- **Error handling tests**: Database failures, validation errors
|
||||
|
||||
### Frontend Unit Tests
|
||||
|
||||
**High coverage (~97%)** with Jest and React Testing Library.
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
|
||||
# Run unit tests
|
||||
npm test
|
||||
|
||||
# Run with coverage
|
||||
npm run test:coverage
|
||||
|
||||
# Watch mode
|
||||
npm run test:watch
|
||||
```
|
||||
|
||||
**Test types:**
|
||||
- Component rendering and interactions
|
||||
- Custom hooks behavior
|
||||
- State management
|
||||
- Utility functions
|
||||
- API integration mocks
|
||||
|
||||
### End-to-End Tests
|
||||
|
||||
**Zero flaky tests** with Playwright covering complete user journeys.
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
|
||||
# Run E2E tests
|
||||
npm run test:e2e
|
||||
|
||||
# Run E2E tests in UI mode (recommended for development)
|
||||
npm run test:e2e:ui
|
||||
|
||||
# Run specific test file
|
||||
npx playwright test auth-login.spec.ts
|
||||
|
||||
# Generate test report
|
||||
npx playwright show-report
|
||||
```
|
||||
|
||||
**Test coverage:**
|
||||
- Complete authentication flows
|
||||
- Navigation and routing
|
||||
- Form submissions and validation
|
||||
- Settings and profile management
|
||||
- Session management
|
||||
- Admin panel workflows (in progress)
|
||||
|
||||
---
|
||||
|
||||
## 🤖 AI-Friendly Documentation
|
||||
|
||||
This project includes comprehensive documentation designed for AI coding assistants:
|
||||
|
||||
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI assistant context for PragmaStack
|
||||
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance
|
||||
|
||||
These files provide AI assistants with the **PragmaStack** architecture, patterns, and best practices.
|
||||
|
||||
---
|
||||
|
||||
## 🗄️ Database Migrations
|
||||
|
||||
The template uses Alembic for database migrations:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# Generate migration from model changes
|
||||
python migrate.py generate "description of changes"
|
||||
|
||||
# Apply migrations
|
||||
python migrate.py apply
|
||||
|
||||
# Or do both in one command
|
||||
python migrate.py auto "description"
|
||||
|
||||
# View migration history
|
||||
python migrate.py list
|
||||
|
||||
# Check current revision
|
||||
python migrate.py current
|
||||
+====================================================================+
|
||||
| SYNDARIX CORE |
|
||||
+====================================================================+
|
||||
| +------------------+ +------------------+ +------------------+ |
|
||||
| | Agent Orchestrator| | Project Manager | | Workflow Engine | |
|
||||
| +------------------+ +------------------+ +------------------+ |
|
||||
+====================================================================+
|
||||
|
|
||||
v
|
||||
+====================================================================+
|
||||
| MCP ORCHESTRATION LAYER |
|
||||
| All integrations via unified MCP servers with project scoping |
|
||||
+====================================================================+
|
||||
|
|
||||
+------------------------+------------------------+
|
||||
| | |
|
||||
+----v----+ +----v----+ +----v----+ +----v----+ +----v----+
|
||||
| LLM | | Git | |Knowledge| | File | | Code |
|
||||
| Providers| | MCP | |Base MCP | |Sys. MCP | |Analysis |
|
||||
+---------+ +---------+ +---------+ +---------+ +---------+
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📖 Documentation
|
||||
## Contributing
|
||||
|
||||
### AI Assistant Documentation
|
||||
|
||||
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI coding assistant context
|
||||
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance and preferences
|
||||
|
||||
### Backend Documentation
|
||||
|
||||
- **[ARCHITECTURE.md](./backend/docs/ARCHITECTURE.md)** - System architecture and design patterns
|
||||
- **[CODING_STANDARDS.md](./backend/docs/CODING_STANDARDS.md)** - Code quality standards
|
||||
- **[COMMON_PITFALLS.md](./backend/docs/COMMON_PITFALLS.md)** - Common mistakes to avoid
|
||||
- **[FEATURE_EXAMPLE.md](./backend/docs/FEATURE_EXAMPLE.md)** - Step-by-step feature guide
|
||||
|
||||
### Frontend Documentation
|
||||
|
||||
- **[PragmaStack Design System](./frontend/docs/design-system/)** - Complete design system guide
|
||||
- Quick start, foundations (colors, typography, spacing)
|
||||
- Component library guide
|
||||
- Layout patterns, spacing philosophy
|
||||
- Forms, accessibility, AI guidelines
|
||||
- **[E2E Testing Guide](./frontend/e2e/README.md)** - E2E testing setup and best practices
|
||||
|
||||
### API Documentation
|
||||
|
||||
When the backend is running:
|
||||
- **Swagger UI**: http://localhost:8000/docs
|
||||
- **ReDoc**: http://localhost:8000/redoc
|
||||
- **OpenAPI JSON**: http://localhost:8000/api/v1/openapi.json
|
||||
See [CONTRIBUTING.md](./CONTRIBUTING.md) for guidelines.
|
||||
|
||||
---
|
||||
|
||||
## 🚢 Deployment
|
||||
## License
|
||||
|
||||
### Docker Production Deployment
|
||||
|
||||
```bash
|
||||
# Build and start all services
|
||||
docker-compose up -d
|
||||
|
||||
# Run migrations
|
||||
docker-compose exec backend alembic upgrade head
|
||||
|
||||
# View logs
|
||||
docker-compose logs -f
|
||||
|
||||
# Stop services
|
||||
docker-compose down
|
||||
```
|
||||
|
||||
### Production Checklist
|
||||
|
||||
- [ ] Change default superuser credentials
|
||||
- [ ] Set strong `SECRET_KEY` in backend `.env`
|
||||
- [ ] Configure production database (PostgreSQL)
|
||||
- [ ] Set `ENVIRONMENT=production` in backend
|
||||
- [ ] Configure CORS origins for your domain
|
||||
- [ ] Setup SSL/TLS certificates
|
||||
- [ ] Configure email service for password resets
|
||||
- [ ] Setup monitoring and logging
|
||||
- [ ] Configure backup strategy
|
||||
- [ ] Review and adjust rate limits
|
||||
- [ ] Test security headers
|
||||
MIT License - see [LICENSE](./LICENSE) for details.
|
||||
|
||||
---
|
||||
|
||||
## 🛣️ Roadmap & Status
|
||||
## Acknowledgments
|
||||
|
||||
### ✅ Completed
|
||||
- [x] Authentication system (JWT, refresh tokens, session management, OAuth)
|
||||
- [x] User management (CRUD, profile, password change)
|
||||
- [x] Organization system with RBAC (Owner, Admin, Member)
|
||||
- [x] Admin panel (users, organizations, sessions, statistics)
|
||||
- [x] **Internationalization (i18n)** with next-intl (English + Italian)
|
||||
- [x] Backend testing infrastructure (~97% coverage)
|
||||
- [x] Frontend unit testing infrastructure (~97% coverage)
|
||||
- [x] Frontend E2E testing (Playwright, zero flaky tests)
|
||||
- [x] Design system documentation
|
||||
- [x] **Marketing landing page** with animated components
|
||||
- [x] **`/dev` documentation portal** with live component examples
|
||||
- [x] **Toast notifications** system (Sonner)
|
||||
- [x] **Charts and visualizations** (Recharts)
|
||||
- [x] **Animation system** (Framer Motion)
|
||||
- [x] **Markdown rendering** with syntax highlighting
|
||||
- [x] **SEO optimization** (sitemap, robots.txt, locale-aware metadata)
|
||||
- [x] Database migrations with helper script
|
||||
- [x] Docker deployment
|
||||
- [x] API documentation (OpenAPI/Swagger)
|
||||
|
||||
### 🚧 In Progress
|
||||
- [ ] Email integration (templates ready, SMTP pending)
|
||||
|
||||
### 🔮 Planned
|
||||
- [ ] GitHub Actions CI/CD pipelines
|
||||
- [ ] Dynamic test coverage badges from CI
|
||||
- [ ] E2E test coverage reporting
|
||||
- [ ] OAuth token encryption at rest (security hardening)
|
||||
- [ ] Additional languages (Spanish, French, German, etc.)
|
||||
- [ ] SSO/SAML authentication
|
||||
- [ ] Real-time notifications with WebSockets
|
||||
- [ ] Webhook system
|
||||
- [ ] File upload/storage (S3-compatible)
|
||||
- [ ] Audit logging system
|
||||
- [ ] API versioning example
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
Contributions are welcome! Whether you're fixing bugs, improving documentation, or proposing new features, we'd love your help.
|
||||
|
||||
### How to Contribute
|
||||
|
||||
1. **Fork the repository**
|
||||
2. **Create a feature branch** (`git checkout -b feature/amazing-feature`)
|
||||
3. **Make your changes**
|
||||
- Follow existing code style
|
||||
- Add tests for new features
|
||||
- Update documentation as needed
|
||||
4. **Run tests** to ensure everything works
|
||||
5. **Commit your changes** (`git commit -m 'Add amazing feature'`)
|
||||
6. **Push to your branch** (`git push origin feature/amazing-feature`)
|
||||
7. **Open a Pull Request**
|
||||
|
||||
### Development Guidelines
|
||||
|
||||
- Write tests for new features (aim for >90% coverage)
|
||||
- Follow the existing architecture patterns
|
||||
- Update documentation when adding features
|
||||
- Keep commits atomic and well-described
|
||||
- Be respectful and constructive in discussions
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Found a bug? Have a suggestion? [Open an issue](https://github.com/cardosofelipe/pragma-stack/issues)!
|
||||
|
||||
Please include:
|
||||
- Clear description of the issue/suggestion
|
||||
- Steps to reproduce (for bugs)
|
||||
- Expected vs. actual behavior
|
||||
- Environment details (OS, Python/Node version, etc.)
|
||||
|
||||
---
|
||||
|
||||
## 📄 License
|
||||
|
||||
This project is licensed under the **MIT License** - see the [LICENSE](./LICENSE) file for details.
|
||||
|
||||
**TL;DR**: You can use this template for any purpose, commercial or non-commercial. Attribution is appreciated but not required!
|
||||
|
||||
---
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
This template is built on the shoulders of giants:
|
||||
|
||||
- [FastAPI](https://fastapi.tiangolo.com/) by Sebastián Ramírez
|
||||
- [Next.js](https://nextjs.org/) by Vercel
|
||||
- [shadcn/ui](https://ui.shadcn.com/) by shadcn
|
||||
- [TanStack Query](https://tanstack.com/query) by Tanner Linsley
|
||||
- [Playwright](https://playwright.dev/) by Microsoft
|
||||
- And countless other open-source projects that make modern development possible
|
||||
|
||||
---
|
||||
|
||||
## 💬 Questions?
|
||||
|
||||
- **Documentation**: Check the `/docs` folders in backend and frontend
|
||||
- **Issues**: [GitHub Issues](https://github.com/cardosofelipe/pragma-stack/issues)
|
||||
- **Discussions**: [GitHub Discussions](https://github.com/cardosofelipe/pragma-stack/discussions)
|
||||
|
||||
---
|
||||
|
||||
## ⭐ Star This Repo
|
||||
|
||||
If this template saves you time, consider giving it a star! It helps others discover the project and motivates continued development.
|
||||
|
||||
**Happy coding! 🚀**
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
Made with ❤️ by a developer who got tired of rebuilding the same boilerplate
|
||||
</div>
|
||||
- Built on [PragmaStack](https://gitea.pragmazest.com/cardosofelipe/fast-next-template)
|
||||
- Powered by Claude and the Anthropic API
|
||||
|
||||
@@ -7,7 +7,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONPATH=/app \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
UV_NO_CACHE=1 \
|
||||
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||
VIRTUAL_ENV=/opt/venv \
|
||||
PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Install system dependencies and uv
|
||||
RUN apt-get update && \
|
||||
@@ -20,7 +23,7 @@ RUN apt-get update && \
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
# Install dependencies using uv (development mode with dev dependencies)
|
||||
# Install dependencies using uv into /opt/venv (outside /app to survive bind mounts)
|
||||
RUN uv sync --extra dev --frozen
|
||||
|
||||
# Copy application code
|
||||
@@ -45,7 +48,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONPATH=/app \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
UV_NO_CACHE=1 \
|
||||
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||
VIRTUAL_ENV=/opt/venv \
|
||||
PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Install system dependencies and uv
|
||||
RUN apt-get update && \
|
||||
@@ -58,7 +64,7 @@ RUN apt-get update && \
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
# Install only production dependencies using uv (no dev dependencies)
|
||||
# Install only production dependencies using uv into /opt/venv
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
# Copy application code
|
||||
@@ -67,7 +73,7 @@ COPY entrypoint.sh /usr/local/bin/
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
# Set ownership to non-root user
|
||||
RUN chown -R appuser:appuser /app
|
||||
RUN chown -R appuser:appuser /app /opt/venv
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
@@ -77,4 +83,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all
|
||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all test-integration
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@@ -22,6 +22,7 @@ help:
|
||||
@echo " make test-cov - Run pytest with coverage report"
|
||||
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
|
||||
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
|
||||
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||
@echo " make test-all - Run all tests (unit + E2E)"
|
||||
@echo " make check-docker - Check if Docker is available"
|
||||
@echo ""
|
||||
@@ -82,6 +83,15 @@ test-cov:
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
|
||||
@echo "📊 Coverage report generated in htmlcov/index.html"
|
||||
|
||||
# ============================================================================
|
||||
# Integration Testing (requires running stack: make dev)
|
||||
# ============================================================================
|
||||
|
||||
test-integration:
|
||||
@echo "🧪 Running MCP integration tests..."
|
||||
@echo "Note: Requires running stack (make dev from project root)"
|
||||
@RUN_INTEGRATION_TESTS=true IS_TEST=True PYTHONPATH=. uv run pytest tests/integration/ -v
|
||||
|
||||
# ============================================================================
|
||||
# E2E Testing (requires Docker)
|
||||
# ============================================================================
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# PragmaStack Backend API
|
||||
# Syndarix Backend API
|
||||
|
||||
> The pragmatic, production-ready FastAPI backend for PragmaStack.
|
||||
> The pragmatic, production-ready FastAPI backend for Syndarix.
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ def include_object(object, name, type_, reflected, compare_to):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
|
||||
@@ -1,262 +1,446 @@
|
||||
"""initial models
|
||||
|
||||
Revision ID: 0001
|
||||
Revises:
|
||||
Revises:
|
||||
Create Date: 2025-11-27 09:08:09.464506
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0001'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
revision: str = "0001"
|
||||
down_revision: str | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('oauth_states',
|
||||
sa.Column('state', sa.String(length=255), nullable=False),
|
||||
sa.Column('code_verifier', sa.String(length=128), nullable=True),
|
||||
sa.Column('nonce', sa.String(length=255), nullable=True),
|
||||
sa.Column('provider', sa.String(length=50), nullable=False),
|
||||
sa.Column('redirect_uri', sa.String(length=500), nullable=True),
|
||||
sa.Column('user_id', sa.UUID(), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"oauth_states",
|
||||
sa.Column("state", sa.String(length=255), nullable=False),
|
||||
sa.Column("code_verifier", sa.String(length=128), nullable=True),
|
||||
sa.Column("nonce", sa.String(length=255), nullable=True),
|
||||
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||
sa.Column("redirect_uri", sa.String(length=500), nullable=True),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f('ix_oauth_states_state'), 'oauth_states', ['state'], unique=True)
|
||||
op.create_table('organizations',
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('slug', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('settings', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_index(
|
||||
op.f("ix_oauth_states_state"), "oauth_states", ["state"], unique=True
|
||||
)
|
||||
op.create_index(op.f('ix_organizations_is_active'), 'organizations', ['is_active'], unique=False)
|
||||
op.create_index(op.f('ix_organizations_name'), 'organizations', ['name'], unique=False)
|
||||
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'], unique=False)
|
||||
op.create_index(op.f('ix_organizations_slug'), 'organizations', ['slug'], unique=True)
|
||||
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'], unique=False)
|
||||
op.create_table('users',
|
||||
sa.Column('email', sa.String(length=255), nullable=False),
|
||||
sa.Column('password_hash', sa.String(length=255), nullable=True),
|
||||
sa.Column('first_name', sa.String(length=100), nullable=False),
|
||||
sa.Column('last_name', sa.String(length=100), nullable=True),
|
||||
sa.Column('phone_number', sa.String(length=20), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
||||
sa.Column('preferences', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('locale', sa.String(length=10), nullable=True),
|
||||
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"organizations",
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("slug", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("settings", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f('ix_users_deleted_at'), 'users', ['deleted_at'], unique=False)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
|
||||
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
|
||||
op.create_index(op.f('ix_users_locale'), 'users', ['locale'], unique=False)
|
||||
op.create_table('oauth_accounts',
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('provider', sa.String(length=50), nullable=False),
|
||||
sa.Column('provider_user_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('provider_email', sa.String(length=255), nullable=True),
|
||||
sa.Column('access_token_encrypted', sa.String(length=2048), nullable=True),
|
||||
sa.Column('refresh_token_encrypted', sa.String(length=2048), nullable=True),
|
||||
sa.Column('token_expires_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('provider', 'provider_user_id', name='uq_oauth_provider_user')
|
||||
op.create_index(
|
||||
op.f("ix_organizations_is_active"), "organizations", ["is_active"], unique=False
|
||||
)
|
||||
op.create_index(op.f('ix_oauth_accounts_provider'), 'oauth_accounts', ['provider'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_accounts_provider_email'), 'oauth_accounts', ['provider_email'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_accounts_user_id'), 'oauth_accounts', ['user_id'], unique=False)
|
||||
op.create_index('ix_oauth_accounts_user_provider', 'oauth_accounts', ['user_id', 'provider'], unique=False)
|
||||
op.create_table('oauth_clients',
|
||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
||||
sa.Column('client_secret_hash', sa.String(length=255), nullable=True),
|
||||
sa.Column('client_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('client_description', sa.String(length=1000), nullable=True),
|
||||
sa.Column('client_type', sa.String(length=20), nullable=False),
|
||||
sa.Column('redirect_uris', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column('allowed_scopes', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column('access_token_lifetime', sa.String(length=10), nullable=False),
|
||||
sa.Column('refresh_token_lifetime', sa.String(length=10), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('owner_user_id', sa.UUID(), nullable=True),
|
||||
sa.Column('mcp_server_url', sa.String(length=2048), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['owner_user_id'], ['users.id'], ondelete='SET NULL'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_index(
|
||||
op.f("ix_organizations_name"), "organizations", ["name"], unique=False
|
||||
)
|
||||
op.create_index(op.f('ix_oauth_clients_client_id'), 'oauth_clients', ['client_id'], unique=True)
|
||||
op.create_index(op.f('ix_oauth_clients_is_active'), 'oauth_clients', ['is_active'], unique=False)
|
||||
op.create_table('user_organizations',
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('organization_id', sa.UUID(), nullable=False),
|
||||
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('custom_permissions', sa.String(length=500), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('user_id', 'organization_id')
|
||||
op.create_index(
|
||||
"ix_organizations_name_active",
|
||||
"organizations",
|
||||
["name", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'], unique=False)
|
||||
op.create_index('ix_user_org_role', 'user_organizations', ['role'], unique=False)
|
||||
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'], unique=False)
|
||||
op.create_index(op.f('ix_user_organizations_is_active'), 'user_organizations', ['is_active'], unique=False)
|
||||
op.create_table('user_sessions',
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
|
||||
sa.Column('device_name', sa.String(length=255), nullable=True),
|
||||
sa.Column('device_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.String(length=500), nullable=True),
|
||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('location_city', sa.String(length=100), nullable=True),
|
||||
sa.Column('location_country', sa.String(length=100), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_index(
|
||||
op.f("ix_organizations_slug"), "organizations", ["slug"], unique=True
|
||||
)
|
||||
op.create_index(op.f('ix_user_sessions_is_active'), 'user_sessions', ['is_active'], unique=False)
|
||||
op.create_index('ix_user_sessions_jti_active', 'user_sessions', ['refresh_token_jti', 'is_active'], unique=False)
|
||||
op.create_index(op.f('ix_user_sessions_refresh_token_jti'), 'user_sessions', ['refresh_token_jti'], unique=True)
|
||||
op.create_index('ix_user_sessions_user_active', 'user_sessions', ['user_id', 'is_active'], unique=False)
|
||||
op.create_index(op.f('ix_user_sessions_user_id'), 'user_sessions', ['user_id'], unique=False)
|
||||
op.create_table('oauth_authorization_codes',
|
||||
sa.Column('code', sa.String(length=128), nullable=False),
|
||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('redirect_uri', sa.String(length=2048), nullable=False),
|
||||
sa.Column('scope', sa.String(length=1000), nullable=False),
|
||||
sa.Column('code_challenge', sa.String(length=128), nullable=True),
|
||||
sa.Column('code_challenge_method', sa.String(length=10), nullable=True),
|
||||
sa.Column('state', sa.String(length=256), nullable=True),
|
||||
sa.Column('nonce', sa.String(length=256), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('used', sa.Boolean(), nullable=False),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_index(
|
||||
"ix_organizations_slug_active",
|
||||
"organizations",
|
||||
["slug", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index('ix_oauth_authorization_codes_client_user', 'oauth_authorization_codes', ['client_id', 'user_id'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_authorization_codes_code'), 'oauth_authorization_codes', ['code'], unique=True)
|
||||
op.create_index('ix_oauth_authorization_codes_expires_at', 'oauth_authorization_codes', ['expires_at'], unique=False)
|
||||
op.create_table('oauth_consents',
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
||||
sa.Column('granted_scopes', sa.String(length=1000), nullable=False),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("email", sa.String(length=255), nullable=False),
|
||||
sa.Column("password_hash", sa.String(length=255), nullable=True),
|
||||
sa.Column("first_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("last_name", sa.String(length=100), nullable=True),
|
||||
sa.Column("phone_number", sa.String(length=20), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"preferences", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||
),
|
||||
sa.Column("locale", sa.String(length=10), nullable=True),
|
||||
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index('ix_oauth_consents_user_client', 'oauth_consents', ['user_id', 'client_id'], unique=True)
|
||||
op.create_table('oauth_provider_refresh_tokens',
|
||||
sa.Column('token_hash', sa.String(length=64), nullable=False),
|
||||
sa.Column('jti', sa.String(length=64), nullable=False),
|
||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('scope', sa.String(length=1000), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('revoked', sa.Boolean(), nullable=False),
|
||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('device_info', sa.String(length=500), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_index(op.f("ix_users_deleted_at"), "users", ["deleted_at"], unique=False)
|
||||
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
|
||||
op.create_index(
|
||||
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
|
||||
)
|
||||
op.create_index(op.f("ix_users_locale"), "users", ["locale"], unique=False)
|
||||
op.create_table(
|
||||
"oauth_accounts",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||
sa.Column("provider_user_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("provider_email", sa.String(length=255), nullable=True),
|
||||
sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
|
||||
sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
|
||||
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"provider", "provider_user_id", name="uq_oauth_provider_user"
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_accounts_provider"), "oauth_accounts", ["provider"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_accounts_provider_email"),
|
||||
"oauth_accounts",
|
||||
["provider_email"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_accounts_user_id"), "oauth_accounts", ["user_id"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_accounts_user_provider",
|
||||
"oauth_accounts",
|
||||
["user_id", "provider"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_clients",
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
|
||||
sa.Column("client_name", sa.String(length=255), nullable=False),
|
||||
sa.Column("client_description", sa.String(length=1000), nullable=True),
|
||||
sa.Column("client_type", sa.String(length=20), nullable=False),
|
||||
sa.Column(
|
||||
"redirect_uris", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"allowed_scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
|
||||
sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("owner_user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["owner_user_id"], ["users.id"], ondelete="SET NULL"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=True
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_clients_is_active"), "oauth_clients", ["is_active"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"user_organizations",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("organization_id", sa.UUID(), nullable=False),
|
||||
sa.Column(
|
||||
"role",
|
||||
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("user_id", "organization_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_org_org_active",
|
||||
"user_organizations",
|
||||
["organization_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index("ix_user_org_role", "user_organizations", ["role"], unique=False)
|
||||
op.create_index(
|
||||
"ix_user_org_user_active",
|
||||
"user_organizations",
|
||||
["user_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_organizations_is_active"),
|
||||
"user_organizations",
|
||||
["is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"user_sessions",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
|
||||
sa.Column("device_name", sa.String(length=255), nullable=True),
|
||||
sa.Column("device_id", sa.String(length=255), nullable=True),
|
||||
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||
sa.Column("user_agent", sa.String(length=500), nullable=True),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("location_city", sa.String(length=100), nullable=True),
|
||||
sa.Column("location_country", sa.String(length=100), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_sessions_is_active"), "user_sessions", ["is_active"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_sessions_jti_active",
|
||||
"user_sessions",
|
||||
["refresh_token_jti", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_sessions_refresh_token_jti"),
|
||||
"user_sessions",
|
||||
["refresh_token_jti"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_sessions_user_active",
|
||||
"user_sessions",
|
||||
["user_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_sessions_user_id"), "user_sessions", ["user_id"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_authorization_codes",
|
||||
sa.Column("code", sa.String(length=128), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("redirect_uri", sa.String(length=2048), nullable=False),
|
||||
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||
sa.Column("code_challenge", sa.String(length=128), nullable=True),
|
||||
sa.Column("code_challenge_method", sa.String(length=10), nullable=True),
|
||||
sa.Column("state", sa.String(length=256), nullable=True),
|
||||
sa.Column("nonce", sa.String(length=256), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("used", sa.Boolean(), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_authorization_codes_client_user",
|
||||
"oauth_authorization_codes",
|
||||
["client_id", "user_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_authorization_codes_code"),
|
||||
"oauth_authorization_codes",
|
||||
["code"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_authorization_codes_expires_at",
|
||||
"oauth_authorization_codes",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_consents",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("granted_scopes", sa.String(length=1000), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_consents_user_client",
|
||||
"oauth_consents",
|
||||
["user_id", "client_id"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_provider_refresh_tokens",
|
||||
sa.Column("token_hash", sa.String(length=64), nullable=False),
|
||||
sa.Column("jti", sa.String(length=64), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("revoked", sa.Boolean(), nullable=False),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("device_info", sa.String(length=500), nullable=True),
|
||||
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_client_user",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["client_id", "user_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["jti"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["revoked"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["token_hash"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["user_id", "revoked"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index('ix_oauth_provider_refresh_tokens_client_user', 'oauth_provider_refresh_tokens', ['client_id', 'user_id'], unique=False)
|
||||
op.create_index('ix_oauth_provider_refresh_tokens_expires_at', 'oauth_provider_refresh_tokens', ['expires_at'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_jti'), 'oauth_provider_refresh_tokens', ['jti'], unique=True)
|
||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), 'oauth_provider_refresh_tokens', ['revoked'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), 'oauth_provider_refresh_tokens', ['token_hash'], unique=True)
|
||||
op.create_index('ix_oauth_provider_refresh_tokens_user_revoked', 'oauth_provider_refresh_tokens', ['user_id', 'revoked'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index('ix_oauth_provider_refresh_tokens_user_revoked', table_name='oauth_provider_refresh_tokens')
|
||||
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), table_name='oauth_provider_refresh_tokens')
|
||||
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), table_name='oauth_provider_refresh_tokens')
|
||||
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_jti'), table_name='oauth_provider_refresh_tokens')
|
||||
op.drop_index('ix_oauth_provider_refresh_tokens_expires_at', table_name='oauth_provider_refresh_tokens')
|
||||
op.drop_index('ix_oauth_provider_refresh_tokens_client_user', table_name='oauth_provider_refresh_tokens')
|
||||
op.drop_table('oauth_provider_refresh_tokens')
|
||||
op.drop_index('ix_oauth_consents_user_client', table_name='oauth_consents')
|
||||
op.drop_table('oauth_consents')
|
||||
op.drop_index('ix_oauth_authorization_codes_expires_at', table_name='oauth_authorization_codes')
|
||||
op.drop_index(op.f('ix_oauth_authorization_codes_code'), table_name='oauth_authorization_codes')
|
||||
op.drop_index('ix_oauth_authorization_codes_client_user', table_name='oauth_authorization_codes')
|
||||
op.drop_table('oauth_authorization_codes')
|
||||
op.drop_index(op.f('ix_user_sessions_user_id'), table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions')
|
||||
op.drop_index(op.f('ix_user_sessions_refresh_token_jti'), table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions')
|
||||
op.drop_index(op.f('ix_user_sessions_is_active'), table_name='user_sessions')
|
||||
op.drop_table('user_sessions')
|
||||
op.drop_index(op.f('ix_user_organizations_is_active'), table_name='user_organizations')
|
||||
op.drop_index('ix_user_org_user_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_org_role', table_name='user_organizations')
|
||||
op.drop_index('ix_user_org_org_active', table_name='user_organizations')
|
||||
op.drop_table('user_organizations')
|
||||
op.drop_index(op.f('ix_oauth_clients_is_active'), table_name='oauth_clients')
|
||||
op.drop_index(op.f('ix_oauth_clients_client_id'), table_name='oauth_clients')
|
||||
op.drop_table('oauth_clients')
|
||||
op.drop_index('ix_oauth_accounts_user_provider', table_name='oauth_accounts')
|
||||
op.drop_index(op.f('ix_oauth_accounts_user_id'), table_name='oauth_accounts')
|
||||
op.drop_index(op.f('ix_oauth_accounts_provider_email'), table_name='oauth_accounts')
|
||||
op.drop_index(op.f('ix_oauth_accounts_provider'), table_name='oauth_accounts')
|
||||
op.drop_table('oauth_accounts')
|
||||
op.drop_index(op.f('ix_users_locale'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_is_active'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_deleted_at'), table_name='users')
|
||||
op.drop_table('users')
|
||||
op.drop_index('ix_organizations_slug_active', table_name='organizations')
|
||||
op.drop_index(op.f('ix_organizations_slug'), table_name='organizations')
|
||||
op.drop_index('ix_organizations_name_active', table_name='organizations')
|
||||
op.drop_index(op.f('ix_organizations_name'), table_name='organizations')
|
||||
op.drop_index(op.f('ix_organizations_is_active'), table_name='organizations')
|
||||
op.drop_table('organizations')
|
||||
op.drop_index(op.f('ix_oauth_states_state'), table_name='oauth_states')
|
||||
op.drop_table('oauth_states')
|
||||
op.drop_index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_oauth_provider_refresh_tokens_client_user",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_table("oauth_provider_refresh_tokens")
|
||||
op.drop_index("ix_oauth_consents_user_client", table_name="oauth_consents")
|
||||
op.drop_table("oauth_consents")
|
||||
op.drop_index(
|
||||
"ix_oauth_authorization_codes_expires_at",
|
||||
table_name="oauth_authorization_codes",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_authorization_codes_code"),
|
||||
table_name="oauth_authorization_codes",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_oauth_authorization_codes_client_user",
|
||||
table_name="oauth_authorization_codes",
|
||||
)
|
||||
op.drop_table("oauth_authorization_codes")
|
||||
op.drop_index(op.f("ix_user_sessions_user_id"), table_name="user_sessions")
|
||||
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
|
||||
op.drop_index(
|
||||
op.f("ix_user_sessions_refresh_token_jti"), table_name="user_sessions"
|
||||
)
|
||||
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
|
||||
op.drop_index(op.f("ix_user_sessions_is_active"), table_name="user_sessions")
|
||||
op.drop_table("user_sessions")
|
||||
op.drop_index(
|
||||
op.f("ix_user_organizations_is_active"), table_name="user_organizations"
|
||||
)
|
||||
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
|
||||
op.drop_index("ix_user_org_role", table_name="user_organizations")
|
||||
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
|
||||
op.drop_table("user_organizations")
|
||||
op.drop_index(op.f("ix_oauth_clients_is_active"), table_name="oauth_clients")
|
||||
op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients")
|
||||
op.drop_table("oauth_clients")
|
||||
op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
|
||||
op.drop_index(op.f("ix_oauth_accounts_user_id"), table_name="oauth_accounts")
|
||||
op.drop_index(op.f("ix_oauth_accounts_provider_email"), table_name="oauth_accounts")
|
||||
op.drop_index(op.f("ix_oauth_accounts_provider"), table_name="oauth_accounts")
|
||||
op.drop_table("oauth_accounts")
|
||||
op.drop_index(op.f("ix_users_locale"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_active"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_deleted_at"), table_name="users")
|
||||
op.drop_table("users")
|
||||
op.drop_index("ix_organizations_slug_active", table_name="organizations")
|
||||
op.drop_index(op.f("ix_organizations_slug"), table_name="organizations")
|
||||
op.drop_index("ix_organizations_name_active", table_name="organizations")
|
||||
op.drop_index(op.f("ix_organizations_name"), table_name="organizations")
|
||||
op.drop_index(op.f("ix_organizations_is_active"), table_name="organizations")
|
||||
op.drop_table("organizations")
|
||||
op.drop_index(op.f("ix_oauth_states_state"), table_name="oauth_states")
|
||||
op.drop_table("oauth_states")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -114,8 +114,13 @@ def upgrade() -> None:
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes in reverse order
|
||||
op.drop_index("ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes")
|
||||
op.drop_index("ix_perf_oauth_refresh_tokens_expires", table_name="oauth_provider_refresh_tokens")
|
||||
op.drop_index(
|
||||
"ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes"
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_perf_oauth_refresh_tokens_expires",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
|
||||
op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
|
||||
op.drop_index("ix_perf_users_active", table_name="users")
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Enable pgvector extension
|
||||
|
||||
Revision ID: 0003
|
||||
Revises: 0002
|
||||
Create Date: 2025-12-30
|
||||
|
||||
This migration enables the pgvector extension for PostgreSQL, which provides
|
||||
vector similarity search capabilities required for the RAG (Retrieval-Augmented
|
||||
Generation) knowledge base system.
|
||||
|
||||
Vector Dimension Reference (per ADR-008 and SPIKE-006):
|
||||
---------------------------------------------------------
|
||||
The dimension size depends on the embedding model used:
|
||||
|
||||
| Model | Dimensions | Use Case |
|
||||
|----------------------------|------------|------------------------------|
|
||||
| text-embedding-3-small | 1536 | General docs, conversations |
|
||||
| text-embedding-3-large | 256-3072 | High accuracy (configurable) |
|
||||
| voyage-code-3 | 1024 | Code files (Python, JS, etc) |
|
||||
| voyage-3-large | 1024 | High quality general purpose |
|
||||
| nomic-embed-text (Ollama) | 768 | Local/fallback embedding |
|
||||
|
||||
Recommended defaults for Syndarix:
|
||||
- Documentation/conversations: 1536 (text-embedding-3-small)
|
||||
- Code files: 1024 (voyage-code-3)
|
||||
|
||||
Prerequisites:
|
||||
--------------
|
||||
This migration requires PostgreSQL with the pgvector extension installed.
|
||||
The Docker Compose configuration uses `pgvector/pgvector:pg17` which includes
|
||||
the extension pre-installed.
|
||||
|
||||
References:
|
||||
-----------
|
||||
- ADR-008: Knowledge Base and RAG Architecture
|
||||
- SPIKE-006: Knowledge Base with pgvector for RAG System
|
||||
- https://github.com/pgvector/pgvector
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0003"
|
||||
down_revision: str | None = "0002"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Enable the pgvector extension.
|
||||
|
||||
The CREATE EXTENSION IF NOT EXISTS statement is idempotent - it will
|
||||
succeed whether the extension already exists or not.
|
||||
"""
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop the pgvector extension.
|
||||
|
||||
Note: This will fail if any tables with vector columns exist.
|
||||
Future migrations that create vector columns should be downgraded first.
|
||||
"""
|
||||
op.execute("DROP EXTENSION IF EXISTS vector")
|
||||
507
backend/app/alembic/versions/0004_add_syndarix_models.py
Normal file
507
backend/app/alembic/versions/0004_add_syndarix_models.py
Normal file
@@ -0,0 +1,507 @@
|
||||
"""Add Syndarix models
|
||||
|
||||
Revision ID: 0004
|
||||
Revises: 0003
|
||||
Create Date: 2025-12-31
|
||||
|
||||
This migration creates the core Syndarix domain tables:
|
||||
- projects: Client engagement projects
|
||||
- agent_types: Agent template configurations
|
||||
- agent_instances: Spawned agent instances assigned to projects
|
||||
- sprints: Sprint containers for issues
|
||||
- issues: Work items (epics, stories, tasks, bugs)
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0004"
|
||||
down_revision: str | None = "0003"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create Syndarix domain tables."""
|
||||
|
||||
# =========================================================================
|
||||
# Create projects table
|
||||
# Note: ENUM types are created automatically by sa.Enum() during table creation
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"projects",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("slug", sa.String(255), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"autonomy_level",
|
||||
sa.Enum(
|
||||
"full_control",
|
||||
"milestone",
|
||||
"autonomous",
|
||||
name="autonomy_level",
|
||||
),
|
||||
nullable=False,
|
||||
server_default="milestone",
|
||||
),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"active",
|
||||
"paused",
|
||||
"completed",
|
||||
"archived",
|
||||
name="project_status",
|
||||
),
|
||||
nullable=False,
|
||||
server_default="active",
|
||||
),
|
||||
sa.Column(
|
||||
"complexity",
|
||||
sa.Enum(
|
||||
"script",
|
||||
"simple",
|
||||
"medium",
|
||||
"complex",
|
||||
name="project_complexity",
|
||||
),
|
||||
nullable=False,
|
||||
server_default="medium",
|
||||
),
|
||||
sa.Column(
|
||||
"client_mode",
|
||||
sa.Enum("technical", "auto", name="client_mode"),
|
||||
nullable=False,
|
||||
server_default="auto",
|
||||
),
|
||||
sa.Column(
|
||||
"settings",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("owner_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["owner_id"], ["users.id"], ondelete="SET NULL"),
|
||||
sa.UniqueConstraint("slug"),
|
||||
)
|
||||
# Single column indexes
|
||||
op.create_index("ix_projects_name", "projects", ["name"])
|
||||
op.create_index("ix_projects_slug", "projects", ["slug"])
|
||||
op.create_index("ix_projects_status", "projects", ["status"])
|
||||
op.create_index("ix_projects_autonomy_level", "projects", ["autonomy_level"])
|
||||
op.create_index("ix_projects_complexity", "projects", ["complexity"])
|
||||
op.create_index("ix_projects_client_mode", "projects", ["client_mode"])
|
||||
op.create_index("ix_projects_owner_id", "projects", ["owner_id"])
|
||||
# Composite indexes
|
||||
op.create_index("ix_projects_slug_status", "projects", ["slug", "status"])
|
||||
op.create_index("ix_projects_owner_status", "projects", ["owner_id", "status"])
|
||||
op.create_index(
|
||||
"ix_projects_autonomy_status", "projects", ["autonomy_level", "status"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_projects_complexity_status", "projects", ["complexity", "status"]
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create agent_types table
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"agent_types",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("slug", sa.String(255), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
# Areas of expertise (e.g., ["python", "fastapi", "databases"])
|
||||
sa.Column(
|
||||
"expertise",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
# System prompt defining personality and behavior (required)
|
||||
sa.Column("personality_prompt", sa.Text(), nullable=False),
|
||||
# LLM model configuration
|
||||
sa.Column("primary_model", sa.String(100), nullable=False),
|
||||
sa.Column(
|
||||
"fallback_models",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
# Model parameters (temperature, max_tokens, etc.)
|
||||
sa.Column(
|
||||
"model_params",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
# MCP servers this agent can connect to
|
||||
sa.Column(
|
||||
"mcp_servers",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
# Tool permissions configuration
|
||||
sa.Column(
|
||||
"tool_permissions",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("slug"),
|
||||
)
|
||||
# Single column indexes
|
||||
op.create_index("ix_agent_types_name", "agent_types", ["name"])
|
||||
op.create_index("ix_agent_types_slug", "agent_types", ["slug"])
|
||||
op.create_index("ix_agent_types_is_active", "agent_types", ["is_active"])
|
||||
# Composite indexes
|
||||
op.create_index("ix_agent_types_slug_active", "agent_types", ["slug", "is_active"])
|
||||
op.create_index("ix_agent_types_name_active", "agent_types", ["name", "is_active"])
|
||||
|
||||
# =========================================================================
|
||||
# Create agent_instances table
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"agent_instances",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("name", sa.String(100), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"idle",
|
||||
"working",
|
||||
"waiting",
|
||||
"paused",
|
||||
"terminated",
|
||||
name="agent_status",
|
||||
),
|
||||
nullable=False,
|
||||
server_default="idle",
|
||||
),
|
||||
sa.Column("current_task", sa.Text(), nullable=True),
|
||||
# Short-term memory (conversation context, recent decisions)
|
||||
sa.Column(
|
||||
"short_term_memory",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
# Reference to long-term memory in vector store
|
||||
sa.Column("long_term_memory_ref", sa.String(500), nullable=True),
|
||||
# Session ID for active MCP connections
|
||||
sa.Column("session_id", sa.String(255), nullable=True),
|
||||
# Activity tracking
|
||||
sa.Column("last_activity_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("terminated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
# Usage metrics
|
||||
sa.Column("tasks_completed", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("tokens_used", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"cost_incurred",
|
||||
sa.Numeric(precision=10, scale=4),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_type_id"], ["agent_types.id"], ondelete="RESTRICT"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
|
||||
)
|
||||
# Single column indexes
|
||||
op.create_index("ix_agent_instances_name", "agent_instances", ["name"])
|
||||
op.create_index("ix_agent_instances_status", "agent_instances", ["status"])
|
||||
op.create_index(
|
||||
"ix_agent_instances_agent_type_id", "agent_instances", ["agent_type_id"]
|
||||
)
|
||||
op.create_index("ix_agent_instances_project_id", "agent_instances", ["project_id"])
|
||||
op.create_index("ix_agent_instances_session_id", "agent_instances", ["session_id"])
|
||||
op.create_index(
|
||||
"ix_agent_instances_last_activity_at", "agent_instances", ["last_activity_at"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_agent_instances_terminated_at", "agent_instances", ["terminated_at"]
|
||||
)
|
||||
# Composite indexes
|
||||
op.create_index(
|
||||
"ix_agent_instances_project_status",
|
||||
"agent_instances",
|
||||
["project_id", "status"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_agent_instances_type_status",
|
||||
"agent_instances",
|
||||
["agent_type_id", "status"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_agent_instances_project_type",
|
||||
"agent_instances",
|
||||
["project_id", "agent_type_id"],
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create sprints table (before issues for FK reference)
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"sprints",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("number", sa.Integer(), nullable=False),
|
||||
sa.Column("goal", sa.Text(), nullable=True),
|
||||
sa.Column("start_date", sa.Date(), nullable=False),
|
||||
sa.Column("end_date", sa.Date(), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"planned",
|
||||
"active",
|
||||
"in_review",
|
||||
"completed",
|
||||
"cancelled",
|
||||
name="sprint_status",
|
||||
),
|
||||
nullable=False,
|
||||
server_default="planned",
|
||||
),
|
||||
sa.Column("planned_points", sa.Integer(), nullable=True),
|
||||
sa.Column("velocity", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("project_id", "number", name="uq_sprint_project_number"),
|
||||
)
|
||||
# Single column indexes
|
||||
op.create_index("ix_sprints_project_id", "sprints", ["project_id"])
|
||||
op.create_index("ix_sprints_status", "sprints", ["status"])
|
||||
op.create_index("ix_sprints_start_date", "sprints", ["start_date"])
|
||||
op.create_index("ix_sprints_end_date", "sprints", ["end_date"])
|
||||
# Composite indexes
|
||||
op.create_index("ix_sprints_project_status", "sprints", ["project_id", "status"])
|
||||
op.create_index("ix_sprints_project_number", "sprints", ["project_id", "number"])
|
||||
op.create_index("ix_sprints_date_range", "sprints", ["start_date", "end_date"])
|
||||
|
||||
# =========================================================================
|
||||
# Create issues table
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"issues",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
# Parent issue for hierarchy (Epic -> Story -> Task)
|
||||
sa.Column("parent_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
# Issue type (epic, story, task, bug)
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.Enum(
|
||||
"epic",
|
||||
"story",
|
||||
"task",
|
||||
"bug",
|
||||
name="issue_type",
|
||||
),
|
||||
nullable=False,
|
||||
server_default="task",
|
||||
),
|
||||
# Reporter (who created this issue)
|
||||
sa.Column("reporter_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
# Issue content
|
||||
sa.Column("title", sa.String(500), nullable=False),
|
||||
sa.Column("body", sa.Text(), nullable=False, server_default=""),
|
||||
# Status and priority
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"open",
|
||||
"in_progress",
|
||||
"in_review",
|
||||
"blocked",
|
||||
"closed",
|
||||
name="issue_status",
|
||||
),
|
||||
nullable=False,
|
||||
server_default="open",
|
||||
),
|
||||
sa.Column(
|
||||
"priority",
|
||||
sa.Enum(
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
"critical",
|
||||
name="issue_priority",
|
||||
),
|
||||
nullable=False,
|
||||
server_default="medium",
|
||||
),
|
||||
# Labels for categorization
|
||||
sa.Column(
|
||||
"labels",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
# Assignment - agent or human (mutually exclusive)
|
||||
sa.Column("assigned_agent_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("human_assignee", sa.String(255), nullable=True),
|
||||
# Sprint association
|
||||
sa.Column("sprint_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
# Estimation
|
||||
sa.Column("story_points", sa.Integer(), nullable=True),
|
||||
sa.Column("due_date", sa.Date(), nullable=True),
|
||||
# External tracker integration (String for flexibility)
|
||||
sa.Column("external_tracker_type", sa.String(50), nullable=True),
|
||||
sa.Column("external_issue_id", sa.String(255), nullable=True),
|
||||
sa.Column("remote_url", sa.String(1000), nullable=True),
|
||||
sa.Column("external_issue_number", sa.Integer(), nullable=True),
|
||||
# Sync status
|
||||
sa.Column(
|
||||
"sync_status",
|
||||
sa.Enum(
|
||||
"synced",
|
||||
"pending",
|
||||
"conflict",
|
||||
"error",
|
||||
name="sync_status",
|
||||
),
|
||||
nullable=False,
|
||||
server_default="synced",
|
||||
),
|
||||
sa.Column("last_synced_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("external_updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
# Lifecycle
|
||||
sa.Column("closed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["parent_id"], ["issues.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["sprint_id"], ["sprints.id"], ondelete="SET NULL"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["assigned_agent_id"], ["agent_instances.id"], ondelete="SET NULL"
|
||||
),
|
||||
)
|
||||
# Single column indexes
|
||||
op.create_index("ix_issues_project_id", "issues", ["project_id"])
|
||||
op.create_index("ix_issues_parent_id", "issues", ["parent_id"])
|
||||
op.create_index("ix_issues_type", "issues", ["type"])
|
||||
op.create_index("ix_issues_reporter_id", "issues", ["reporter_id"])
|
||||
op.create_index("ix_issues_status", "issues", ["status"])
|
||||
op.create_index("ix_issues_priority", "issues", ["priority"])
|
||||
op.create_index("ix_issues_assigned_agent_id", "issues", ["assigned_agent_id"])
|
||||
op.create_index("ix_issues_human_assignee", "issues", ["human_assignee"])
|
||||
op.create_index("ix_issues_sprint_id", "issues", ["sprint_id"])
|
||||
op.create_index("ix_issues_due_date", "issues", ["due_date"])
|
||||
op.create_index(
|
||||
"ix_issues_external_tracker_type", "issues", ["external_tracker_type"]
|
||||
)
|
||||
op.create_index("ix_issues_sync_status", "issues", ["sync_status"])
|
||||
op.create_index("ix_issues_closed_at", "issues", ["closed_at"])
|
||||
# Composite indexes
|
||||
op.create_index("ix_issues_project_status", "issues", ["project_id", "status"])
|
||||
op.create_index("ix_issues_project_priority", "issues", ["project_id", "priority"])
|
||||
op.create_index("ix_issues_project_sprint", "issues", ["project_id", "sprint_id"])
|
||||
op.create_index("ix_issues_project_type", "issues", ["project_id", "type"])
|
||||
op.create_index(
|
||||
"ix_issues_project_agent", "issues", ["project_id", "assigned_agent_id"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_issues_project_status_priority",
|
||||
"issues",
|
||||
["project_id", "status", "priority"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_issues_external_tracker_id",
|
||||
"issues",
|
||||
["external_tracker_type", "external_issue_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop Syndarix domain tables."""
|
||||
# Drop tables in reverse order (respecting FK constraints)
|
||||
op.drop_table("issues")
|
||||
op.drop_table("sprints")
|
||||
op.drop_table("agent_instances")
|
||||
op.drop_table("agent_types")
|
||||
op.drop_table("projects")
|
||||
|
||||
# Drop ENUM types
|
||||
op.execute("DROP TYPE IF EXISTS sprint_status")
|
||||
op.execute("DROP TYPE IF EXISTS sync_status")
|
||||
op.execute("DROP TYPE IF EXISTS issue_priority")
|
||||
op.execute("DROP TYPE IF EXISTS issue_status")
|
||||
op.execute("DROP TYPE IF EXISTS issue_type")
|
||||
op.execute("DROP TYPE IF EXISTS agent_status")
|
||||
op.execute("DROP TYPE IF EXISTS client_mode")
|
||||
op.execute("DROP TYPE IF EXISTS project_complexity")
|
||||
op.execute("DROP TYPE IF EXISTS project_status")
|
||||
op.execute("DROP TYPE IF EXISTS autonomy_level")
|
||||
@@ -151,3 +151,83 @@ async def get_optional_current_user(
|
||||
return user
|
||||
except (TokenExpiredError, TokenInvalidError):
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_user_sse(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
authorization: str | None = Header(None),
|
||||
token: str | None = None, # Query parameter - passed directly from route
|
||||
) -> User:
|
||||
"""
|
||||
Get the current authenticated user for SSE endpoints.
|
||||
|
||||
SSE (Server-Sent Events) via EventSource API doesn't support custom headers,
|
||||
so this dependency accepts tokens from either:
|
||||
1. Authorization header (preferred, for non-EventSource clients)
|
||||
2. Query parameter 'token' (fallback for EventSource compatibility)
|
||||
|
||||
Security note: Query parameter tokens appear in server logs and browser history.
|
||||
Consider implementing short-lived SSE-specific tokens for production if this
|
||||
is a concern. The current approach is acceptable for internal/trusted networks.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
authorization: Authorization header (Bearer token)
|
||||
token: Query parameter token (fallback for EventSource)
|
||||
|
||||
Returns:
|
||||
User: The authenticated user
|
||||
|
||||
Raises:
|
||||
HTTPException: If authentication fails
|
||||
"""
|
||||
# Try Authorization header first (preferred)
|
||||
auth_token = None
|
||||
if authorization:
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if scheme.lower() == "bearer" and param:
|
||||
auth_token = param
|
||||
|
||||
# Fall back to query parameter if no header token
|
||||
if not auth_token and token:
|
||||
auth_token = token
|
||||
|
||||
if not auth_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
# Decode token and get user ID
|
||||
token_data = get_token_data(auth_token)
|
||||
|
||||
# Get user from database
|
||||
result = await db.execute(select(User).where(User.id == token_data.user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
except TokenExpiredError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token expired",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except TokenInvalidError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
36
backend/app/api/dependencies/event_bus.py
Normal file
36
backend/app/api/dependencies/event_bus.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Event bus dependency for FastAPI routes.
|
||||
|
||||
This module provides the FastAPI dependency for injecting the EventBus
|
||||
into route handlers. The event bus is a singleton that maintains
|
||||
Redis pub/sub connections for real-time event streaming.
|
||||
"""
|
||||
|
||||
from app.services.event_bus import (
|
||||
EventBus,
|
||||
get_connected_event_bus as _get_connected_event_bus,
|
||||
)
|
||||
|
||||
|
||||
async def get_event_bus() -> EventBus:
|
||||
"""
|
||||
FastAPI dependency that provides a connected EventBus instance.
|
||||
|
||||
The EventBus is a singleton that maintains Redis pub/sub connections.
|
||||
It's lazily initialized and connected on first access, and should be
|
||||
closed during application shutdown via close_event_bus().
|
||||
|
||||
Usage:
|
||||
@router.get("/events/stream")
|
||||
async def stream_events(
|
||||
event_bus: EventBus = Depends(get_event_bus)
|
||||
):
|
||||
...
|
||||
|
||||
Returns:
|
||||
EventBus: The global connected event bus instance
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If connection to Redis fails
|
||||
"""
|
||||
return await _get_connected_event_bus()
|
||||
@@ -2,11 +2,19 @@ from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import (
|
||||
admin,
|
||||
agent_types,
|
||||
agents,
|
||||
auth,
|
||||
context,
|
||||
events,
|
||||
issues,
|
||||
mcp,
|
||||
oauth,
|
||||
oauth_provider,
|
||||
organizations,
|
||||
projects,
|
||||
sessions,
|
||||
sprints,
|
||||
users,
|
||||
)
|
||||
|
||||
@@ -22,3 +30,25 @@ api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
|
||||
api_router.include_router(
|
||||
organizations.router, prefix="/organizations", tags=["Organizations"]
|
||||
)
|
||||
# SSE events router - no prefix, routes define full paths
|
||||
api_router.include_router(events.router, tags=["Events"])
|
||||
|
||||
# MCP (Model Context Protocol) router
|
||||
api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"])
|
||||
|
||||
# Context Management Engine router
|
||||
api_router.include_router(context.router, prefix="/context", tags=["Context"])
|
||||
|
||||
# Syndarix domain routers
|
||||
api_router.include_router(projects.router, prefix="/projects", tags=["Projects"])
|
||||
api_router.include_router(
|
||||
agent_types.router, prefix="/agent-types", tags=["Agent Types"]
|
||||
)
|
||||
# Issues router - routes include /projects/{project_id}/issues paths
|
||||
api_router.include_router(issues.router, tags=["Issues"])
|
||||
# Agents router - routes include /projects/{project_id}/agents paths
|
||||
api_router.include_router(agents.router, tags=["Agents"])
|
||||
# Sprints router - routes need prefix as they use /projects/{project_id}/sprints paths
|
||||
api_router.include_router(
|
||||
sprints.router, prefix="/projects/{project_id}/sprints", tags=["Sprints"]
|
||||
)
|
||||
|
||||
462
backend/app/api/routes/agent_types.py
Normal file
462
backend/app/api/routes/agent_types.py
Normal file
@@ -0,0 +1,462 @@
|
||||
# app/api/routes/agent_types.py
|
||||
"""
|
||||
AgentType configuration API endpoints.
|
||||
|
||||
Provides CRUD operations for managing AI agent type templates.
|
||||
Agent types define the base configuration (model, personality, expertise)
|
||||
from which agent instances are spawned for projects.
|
||||
|
||||
Authorization:
|
||||
- Read endpoints: Any authenticated user
|
||||
- Write endpoints (create, update, delete): Superusers only
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
DuplicateError,
|
||||
ErrorCode,
|
||||
NotFoundError,
|
||||
)
|
||||
from app.crud.syndarix.agent_type import agent_type as agent_type_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
MessageResponse,
|
||||
PaginatedResponse,
|
||||
PaginationParams,
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.syndarix import (
|
||||
AgentTypeCreate,
|
||||
AgentTypeResponse,
|
||||
AgentTypeUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize limiter for this router
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Use higher rate limits in test environment
|
||||
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||
|
||||
|
||||
def _build_agent_type_response(
|
||||
agent_type: Any,
|
||||
instance_count: int = 0,
|
||||
) -> AgentTypeResponse:
|
||||
"""
|
||||
Build an AgentTypeResponse from a database model.
|
||||
|
||||
Args:
|
||||
agent_type: AgentType model instance
|
||||
instance_count: Number of agent instances for this type
|
||||
|
||||
Returns:
|
||||
AgentTypeResponse schema
|
||||
"""
|
||||
return AgentTypeResponse(
|
||||
id=agent_type.id,
|
||||
name=agent_type.name,
|
||||
slug=agent_type.slug,
|
||||
description=agent_type.description,
|
||||
expertise=agent_type.expertise,
|
||||
personality_prompt=agent_type.personality_prompt,
|
||||
primary_model=agent_type.primary_model,
|
||||
fallback_models=agent_type.fallback_models,
|
||||
model_params=agent_type.model_params,
|
||||
mcp_servers=agent_type.mcp_servers,
|
||||
tool_permissions=agent_type.tool_permissions,
|
||||
is_active=agent_type.is_active,
|
||||
created_at=agent_type.created_at,
|
||||
updated_at=agent_type.updated_at,
|
||||
instance_count=instance_count,
|
||||
)
|
||||
|
||||
|
||||
# ===== Write Endpoints (Admin Only) =====
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=AgentTypeResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create Agent Type",
|
||||
description="Create a new agent type configuration (admin only)",
|
||||
operation_id="create_agent_type",
|
||||
)
|
||||
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
|
||||
async def create_agent_type(
|
||||
request: Request,
|
||||
agent_type_in: AgentTypeCreate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Create a new agent type configuration.
|
||||
|
||||
Agent types define templates for AI agents including:
|
||||
- Model configuration (primary model, fallback models, parameters)
|
||||
- Personality and expertise areas
|
||||
- MCP server integrations and tool permissions
|
||||
|
||||
Requires superuser privileges.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
agent_type_in: Agent type creation data
|
||||
admin: Authenticated superuser
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
The created agent type configuration
|
||||
|
||||
Raises:
|
||||
DuplicateError: If slug already exists
|
||||
"""
|
||||
try:
|
||||
agent_type = await agent_type_crud.create(db, obj_in=agent_type_in)
|
||||
logger.info(
|
||||
f"Admin {admin.email} created agent type: {agent_type.name} "
|
||||
f"(slug: {agent_type.slug})"
|
||||
)
|
||||
return _build_agent_type_response(agent_type, instance_count=0)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to create agent type: {e!s}")
|
||||
raise DuplicateError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.ALREADY_EXISTS,
|
||||
field="slug",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating agent type: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{agent_type_id}",
|
||||
response_model=AgentTypeResponse,
|
||||
summary="Update Agent Type",
|
||||
description="Update an existing agent type configuration (admin only)",
|
||||
operation_id="update_agent_type",
|
||||
)
|
||||
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
|
||||
async def update_agent_type(
|
||||
request: Request,
|
||||
agent_type_id: UUID,
|
||||
agent_type_in: AgentTypeUpdate,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update an existing agent type configuration.
|
||||
|
||||
Partial updates are supported - only provided fields will be updated.
|
||||
|
||||
Requires superuser privileges.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
agent_type_id: UUID of the agent type to update
|
||||
agent_type_in: Agent type update data
|
||||
admin: Authenticated superuser
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
The updated agent type configuration
|
||||
|
||||
Raises:
|
||||
NotFoundError: If agent type not found
|
||||
DuplicateError: If new slug already exists
|
||||
"""
|
||||
try:
|
||||
# Verify agent type exists
|
||||
result = await agent_type_crud.get_with_instance_count(
|
||||
db, agent_type_id=agent_type_id
|
||||
)
|
||||
if not result:
|
||||
raise NotFoundError(
|
||||
message=f"Agent type {agent_type_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
existing_type = result["agent_type"]
|
||||
instance_count = result["instance_count"]
|
||||
|
||||
# Perform update
|
||||
updated_type = await agent_type_crud.update(
|
||||
db, db_obj=existing_type, obj_in=agent_type_in
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Admin {admin.email} updated agent type: {updated_type.name} "
|
||||
f"(id: {agent_type_id})"
|
||||
)
|
||||
|
||||
return _build_agent_type_response(updated_type, instance_count=instance_count)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to update agent type {agent_type_id}: {e!s}")
|
||||
raise DuplicateError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.ALREADY_EXISTS,
|
||||
field="slug",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating agent type {agent_type_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{agent_type_id}",
|
||||
response_model=MessageResponse,
|
||||
summary="Deactivate Agent Type",
|
||||
description="Deactivate an agent type (soft delete, admin only)",
|
||||
operation_id="deactivate_agent_type",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def deactivate_agent_type(
|
||||
request: Request,
|
||||
agent_type_id: UUID,
|
||||
admin: User = Depends(require_superuser),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Deactivate an agent type (soft delete).
|
||||
|
||||
This sets is_active=False rather than deleting the record,
|
||||
preserving referential integrity with existing agent instances.
|
||||
|
||||
Requires superuser privileges.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
agent_type_id: UUID of the agent type to deactivate
|
||||
admin: Authenticated superuser
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
|
||||
Raises:
|
||||
NotFoundError: If agent type not found
|
||||
"""
|
||||
try:
|
||||
deactivated = await agent_type_crud.deactivate(db, agent_type_id=agent_type_id)
|
||||
|
||||
if not deactivated:
|
||||
raise NotFoundError(
|
||||
message=f"Agent type {agent_type_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Admin {admin.email} deactivated agent type: {deactivated.name} "
|
||||
f"(id: {agent_type_id})"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Agent type '{deactivated.name}' has been deactivated",
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error deactivating agent type {agent_type_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# ===== Read Endpoints (Authenticated Users) =====
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=PaginatedResponse[AgentTypeResponse],
|
||||
summary="List Agent Types",
|
||||
description="Get paginated list of active agent types",
|
||||
operation_id="list_agent_types",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def list_agent_types(
|
||||
request: Request,
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: bool = Query(True, description="Filter by active status"),
|
||||
search: str | None = Query(None, description="Search by name, slug, description"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List all agent types with pagination and filtering.
|
||||
|
||||
By default, returns only active agent types. Set is_active=false
|
||||
to include deactivated types (useful for admin views).
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
pagination: Pagination parameters (page, limit)
|
||||
is_active: Filter by active status (default: True)
|
||||
search: Optional search term for name, slug, description
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Paginated list of agent types with instance counts
|
||||
"""
|
||||
try:
|
||||
# Get agent types with instance counts
|
||||
results, total = await agent_type_crud.get_multi_with_instance_counts(
|
||||
db,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
is_active=is_active,
|
||||
search=search,
|
||||
)
|
||||
|
||||
# Build response objects
|
||||
agent_types_response = [
|
||||
_build_agent_type_response(
|
||||
item["agent_type"],
|
||||
instance_count=item["instance_count"],
|
||||
)
|
||||
for item in results
|
||||
]
|
||||
|
||||
pagination_meta = create_pagination_meta(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(agent_types_response),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=agent_types_response, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing agent types: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{agent_type_id}",
|
||||
response_model=AgentTypeResponse,
|
||||
summary="Get Agent Type",
|
||||
description="Get agent type details by ID",
|
||||
operation_id="get_agent_type",
|
||||
)
|
||||
@limiter.limit(f"{100 * RATE_MULTIPLIER}/minute")
|
||||
async def get_agent_type(
|
||||
request: Request,
|
||||
agent_type_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get detailed information about a specific agent type.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
agent_type_id: UUID of the agent type
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Agent type details with instance count
|
||||
|
||||
Raises:
|
||||
NotFoundError: If agent type not found
|
||||
"""
|
||||
try:
|
||||
result = await agent_type_crud.get_with_instance_count(
|
||||
db, agent_type_id=agent_type_id
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise NotFoundError(
|
||||
message=f"Agent type {agent_type_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
return _build_agent_type_response(
|
||||
result["agent_type"],
|
||||
instance_count=result["instance_count"],
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent type {agent_type_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/slug/{slug}",
|
||||
response_model=AgentTypeResponse,
|
||||
summary="Get Agent Type by Slug",
|
||||
description="Get agent type details by slug",
|
||||
operation_id="get_agent_type_by_slug",
|
||||
)
|
||||
@limiter.limit(f"{100 * RATE_MULTIPLIER}/minute")
|
||||
async def get_agent_type_by_slug(
|
||||
request: Request,
|
||||
slug: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get detailed information about an agent type by its slug.
|
||||
|
||||
Slugs are human-readable identifiers like "product-owner" or "backend-engineer".
|
||||
Useful for referencing agent types in configuration files or APIs.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
slug: Slug identifier of the agent type
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Agent type details with instance count
|
||||
|
||||
Raises:
|
||||
NotFoundError: If agent type not found
|
||||
"""
|
||||
try:
|
||||
agent_type = await agent_type_crud.get_by_slug(db, slug=slug)
|
||||
|
||||
if not agent_type:
|
||||
raise NotFoundError(
|
||||
message=f"Agent type with slug '{slug}' not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Get instance count separately
|
||||
result = await agent_type_crud.get_with_instance_count(
|
||||
db, agent_type_id=agent_type.id
|
||||
)
|
||||
instance_count = result["instance_count"] if result else 0
|
||||
|
||||
return _build_agent_type_response(agent_type, instance_count=instance_count)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent type by slug '{slug}': {e!s}", exc_info=True)
|
||||
raise
|
||||
984
backend/app/api/routes/agents.py
Normal file
984
backend/app/api/routes/agents.py
Normal file
@@ -0,0 +1,984 @@
|
||||
# app/api/routes/agents.py
|
||||
"""
|
||||
Agent Instance management endpoints for Syndarix projects.
|
||||
|
||||
These endpoints allow project owners and superusers to manage AI agent instances
|
||||
within their projects, including spawning, pausing, resuming, and terminating agents.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
AuthorizationError,
|
||||
NotFoundError,
|
||||
ValidationException,
|
||||
)
|
||||
from app.crud.syndarix.agent_instance import agent_instance as agent_instance_crud
|
||||
from app.crud.syndarix.agent_type import agent_type as agent_type_crud
|
||||
from app.crud.syndarix.project import project as project_crud
|
||||
from app.models.syndarix import AgentInstance, Project
|
||||
from app.models.syndarix.enums import AgentStatus
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
MessageResponse,
|
||||
PaginatedResponse,
|
||||
PaginationParams,
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.errors import ErrorCode
|
||||
from app.schemas.syndarix.agent_instance import (
|
||||
AgentInstanceCreate,
|
||||
AgentInstanceMetrics,
|
||||
AgentInstanceResponse,
|
||||
AgentInstanceUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize limiter for this router
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Use higher rate limits in test environment
|
||||
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||
|
||||
|
||||
# Valid status transitions for agent lifecycle management
|
||||
VALID_STATUS_TRANSITIONS: dict[AgentStatus, set[AgentStatus]] = {
|
||||
AgentStatus.IDLE: {AgentStatus.WORKING, AgentStatus.PAUSED, AgentStatus.TERMINATED},
|
||||
AgentStatus.WORKING: {
|
||||
AgentStatus.IDLE,
|
||||
AgentStatus.WAITING,
|
||||
AgentStatus.PAUSED,
|
||||
AgentStatus.TERMINATED,
|
||||
},
|
||||
AgentStatus.WAITING: {
|
||||
AgentStatus.IDLE,
|
||||
AgentStatus.WORKING,
|
||||
AgentStatus.PAUSED,
|
||||
AgentStatus.TERMINATED,
|
||||
},
|
||||
AgentStatus.PAUSED: {AgentStatus.IDLE, AgentStatus.TERMINATED},
|
||||
AgentStatus.TERMINATED: set(), # Terminal state, no transitions allowed
|
||||
}
|
||||
|
||||
|
||||
async def verify_project_access(
|
||||
db: AsyncSession,
|
||||
project_id: UUID,
|
||||
user: User,
|
||||
) -> Project:
|
||||
"""
|
||||
Verify user has access to a project.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
project_id: UUID of the project to verify
|
||||
user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
Project: The project if access is granted
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the project does not exist
|
||||
AuthorizationError: If the user does not have access to the project
|
||||
"""
|
||||
project = await project_crud.get(db, id=project_id)
|
||||
if not project:
|
||||
raise NotFoundError(
|
||||
message=f"Project {project_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
if not user.is_superuser and project.owner_id != user.id:
|
||||
raise AuthorizationError(
|
||||
message="You do not have access to this project",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
return project
|
||||
|
||||
|
||||
def validate_status_transition(
|
||||
current_status: AgentStatus,
|
||||
target_status: AgentStatus,
|
||||
) -> None:
|
||||
"""
|
||||
Validate that a status transition is allowed.
|
||||
|
||||
Args:
|
||||
current_status: The agent's current status
|
||||
target_status: The desired target status
|
||||
|
||||
Raises:
|
||||
ValidationException: If the transition is not allowed
|
||||
"""
|
||||
valid_targets = VALID_STATUS_TRANSITIONS.get(current_status, set())
|
||||
if target_status not in valid_targets:
|
||||
raise ValidationException(
|
||||
message=f"Cannot transition from {current_status.value} to {target_status.value}",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="status",
|
||||
)
|
||||
|
||||
|
||||
def build_agent_response(
|
||||
agent: AgentInstance,
|
||||
agent_type_name: str | None = None,
|
||||
agent_type_slug: str | None = None,
|
||||
project_name: str | None = None,
|
||||
project_slug: str | None = None,
|
||||
assigned_issues_count: int = 0,
|
||||
) -> AgentInstanceResponse:
|
||||
"""
|
||||
Build an AgentInstanceResponse from an AgentInstance model.
|
||||
|
||||
Args:
|
||||
agent: The agent instance model
|
||||
agent_type_name: Name of the agent type
|
||||
agent_type_slug: Slug of the agent type
|
||||
project_name: Name of the project
|
||||
project_slug: Slug of the project
|
||||
assigned_issues_count: Number of issues assigned to this agent
|
||||
|
||||
Returns:
|
||||
AgentInstanceResponse: The response schema
|
||||
"""
|
||||
return AgentInstanceResponse(
|
||||
id=agent.id,
|
||||
agent_type_id=agent.agent_type_id,
|
||||
project_id=agent.project_id,
|
||||
name=agent.name,
|
||||
status=agent.status,
|
||||
current_task=agent.current_task,
|
||||
short_term_memory=agent.short_term_memory or {},
|
||||
long_term_memory_ref=agent.long_term_memory_ref,
|
||||
session_id=agent.session_id,
|
||||
last_activity_at=agent.last_activity_at,
|
||||
terminated_at=agent.terminated_at,
|
||||
tasks_completed=agent.tasks_completed,
|
||||
tokens_used=agent.tokens_used,
|
||||
cost_incurred=agent.cost_incurred,
|
||||
created_at=agent.created_at,
|
||||
updated_at=agent.updated_at,
|
||||
agent_type_name=agent_type_name,
|
||||
agent_type_slug=agent_type_slug,
|
||||
project_name=project_name,
|
||||
project_slug=project_slug,
|
||||
assigned_issues_count=assigned_issues_count,
|
||||
)
|
||||
|
||||
|
||||
# ===== Agent Instance Management Endpoints =====
|
||||
|
||||
|
||||
@router.post(
|
||||
"/projects/{project_id}/agents",
|
||||
response_model=AgentInstanceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Spawn Agent Instance",
|
||||
description="Spawn a new agent instance in a project. Requires project ownership or superuser.",
|
||||
operation_id="spawn_agent",
|
||||
)
|
||||
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
|
||||
async def spawn_agent(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
agent_in: AgentInstanceCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Spawn a new agent instance in a project.
|
||||
|
||||
Creates a new agent instance from an agent type template and assigns it
|
||||
to the specified project. The agent starts in IDLE status by default.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object (for rate limiting)
|
||||
project_id: UUID of the project to spawn the agent in
|
||||
agent_in: Agent instance creation data
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
AgentInstanceResponse: The newly created agent instance
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the project is not found
|
||||
AuthorizationError: If the user lacks access to the project
|
||||
ValidationException: If the agent creation data is invalid
|
||||
"""
|
||||
try:
|
||||
# Verify project access
|
||||
project = await verify_project_access(db, project_id, current_user)
|
||||
|
||||
# Ensure the agent is being created for the correct project
|
||||
if agent_in.project_id != project_id:
|
||||
raise ValidationException(
|
||||
message="Agent project_id must match the URL project_id",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="project_id",
|
||||
)
|
||||
|
||||
# Validate that the agent type exists and is active
|
||||
agent_type = await agent_type_crud.get(db, id=agent_in.agent_type_id)
|
||||
if not agent_type:
|
||||
raise NotFoundError(
|
||||
message=f"Agent type {agent_in.agent_type_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
if not agent_type.is_active:
|
||||
raise ValidationException(
|
||||
message=f"Agent type '{agent_type.name}' is inactive and cannot be used",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="agent_type_id",
|
||||
)
|
||||
|
||||
# Create the agent instance
|
||||
agent = await agent_instance_crud.create(db, obj_in=agent_in)
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.email} spawned agent '{agent.name}' "
|
||||
f"(id={agent.id}) in project {project.slug}"
|
||||
)
|
||||
|
||||
# Get agent details for response
|
||||
details = await agent_instance_crud.get_with_details(db, instance_id=agent.id)
|
||||
if details:
|
||||
return build_agent_response(
|
||||
agent=details["instance"],
|
||||
agent_type_name=details.get("agent_type_name"),
|
||||
agent_type_slug=details.get("agent_type_slug"),
|
||||
project_name=details.get("project_name"),
|
||||
project_slug=details.get("project_slug"),
|
||||
assigned_issues_count=details.get("assigned_issues_count", 0),
|
||||
)
|
||||
|
||||
return build_agent_response(agent)
|
||||
|
||||
except (NotFoundError, AuthorizationError, ValidationException):
|
||||
raise
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to spawn agent: {e!s}")
|
||||
raise ValidationException(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error spawning agent: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/agents",
|
||||
response_model=PaginatedResponse[AgentInstanceResponse],
|
||||
summary="List Project Agents",
|
||||
description="List all agent instances in a project with optional filtering.",
|
||||
operation_id="list_project_agents",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def list_project_agents(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
pagination: PaginationParams = Depends(),
|
||||
status_filter: AgentStatus | None = Query(
|
||||
None, alias="status", description="Filter by agent status"
|
||||
),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List all agent instances in a project.
|
||||
|
||||
Returns a paginated list of agents with optional status filtering.
|
||||
Results are ordered by creation date (newest first).
|
||||
|
||||
Args:
|
||||
request: FastAPI request object (for rate limiting)
|
||||
project_id: UUID of the project
|
||||
pagination: Pagination parameters
|
||||
status_filter: Optional filter by agent status
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
PaginatedResponse[AgentInstanceResponse]: Paginated list of agents
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the project is not found
|
||||
AuthorizationError: If the user lacks access to the project
|
||||
"""
|
||||
try:
|
||||
# Verify project access
|
||||
project = await verify_project_access(db, project_id, current_user)
|
||||
|
||||
# Get agents for the project
|
||||
agents, total = await agent_instance_crud.get_by_project(
|
||||
db,
|
||||
project_id=project_id,
|
||||
status=status_filter,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
)
|
||||
|
||||
# Build response objects
|
||||
agent_responses = []
|
||||
for agent in agents:
|
||||
# Get details for each agent (could be optimized with bulk query)
|
||||
details = await agent_instance_crud.get_with_details(
|
||||
db, instance_id=agent.id
|
||||
)
|
||||
if details:
|
||||
agent_responses.append(
|
||||
build_agent_response(
|
||||
agent=details["instance"],
|
||||
agent_type_name=details.get("agent_type_name"),
|
||||
agent_type_slug=details.get("agent_type_slug"),
|
||||
project_name=details.get("project_name"),
|
||||
project_slug=details.get("project_slug"),
|
||||
assigned_issues_count=details.get("assigned_issues_count", 0),
|
||||
)
|
||||
)
|
||||
else:
|
||||
agent_responses.append(build_agent_response(agent))
|
||||
|
||||
pagination_meta = create_pagination_meta(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(agent_responses),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"User {current_user.email} listed {len(agent_responses)} agents "
|
||||
f"in project {project.slug}"
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=agent_responses, pagination=pagination_meta)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing project agents: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# ===== Project Agent Metrics Endpoint =====
|
||||
# NOTE: This endpoint MUST be defined before /{agent_id} routes
|
||||
# to prevent FastAPI from trying to parse "metrics" as a UUID
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/agents/metrics",
|
||||
response_model=AgentInstanceMetrics,
|
||||
summary="Get Project Agent Metrics",
|
||||
description="Get aggregated usage metrics for all agents in a project.",
|
||||
operation_id="get_project_agent_metrics",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def get_project_agent_metrics(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get aggregated usage metrics for all agents in a project.
|
||||
|
||||
Returns aggregated metrics across all agents including total
|
||||
tasks completed, tokens used, and cost incurred.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object (for rate limiting)
|
||||
project_id: UUID of the project
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
AgentInstanceMetrics: Aggregated project agent metrics
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the project is not found
|
||||
AuthorizationError: If the user lacks access to the project
|
||||
"""
|
||||
try:
|
||||
# Verify project access
|
||||
project = await verify_project_access(db, project_id, current_user)
|
||||
|
||||
# Get aggregated metrics for the project
|
||||
metrics = await agent_instance_crud.get_project_metrics(
|
||||
db, project_id=project_id
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"User {current_user.email} retrieved project metrics for {project.slug}"
|
||||
)
|
||||
|
||||
return AgentInstanceMetrics(
|
||||
total_instances=metrics["total_instances"],
|
||||
active_instances=metrics["active_instances"],
|
||||
idle_instances=metrics["idle_instances"],
|
||||
total_tasks_completed=metrics["total_tasks_completed"],
|
||||
total_tokens_used=metrics["total_tokens_used"],
|
||||
total_cost_incurred=metrics["total_cost_incurred"],
|
||||
)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting project agent metrics: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/agents/{agent_id}",
|
||||
response_model=AgentInstanceResponse,
|
||||
summary="Get Agent Details",
|
||||
description="Get detailed information about a specific agent instance.",
|
||||
operation_id="get_agent",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def get_agent(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
agent_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get detailed information about a specific agent instance.
|
||||
|
||||
Returns full agent details including related entity information
|
||||
(agent type name, project name) and assigned issues count.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object (for rate limiting)
|
||||
project_id: UUID of the project
|
||||
agent_id: UUID of the agent instance
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
AgentInstanceResponse: The agent instance details
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the project or agent is not found
|
||||
AuthorizationError: If the user lacks access to the project
|
||||
"""
|
||||
try:
|
||||
# Verify project access
|
||||
await verify_project_access(db, project_id, current_user)
|
||||
|
||||
# Get agent with full details
|
||||
details = await agent_instance_crud.get_with_details(db, instance_id=agent_id)
|
||||
|
||||
if not details:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
agent = details["instance"]
|
||||
|
||||
# Verify agent belongs to the specified project
|
||||
if agent.project_id != project_id:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found in project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"User {current_user.email} retrieved agent {agent.name} (id={agent_id})"
|
||||
)
|
||||
|
||||
return build_agent_response(
|
||||
agent=agent,
|
||||
agent_type_name=details.get("agent_type_name"),
|
||||
agent_type_slug=details.get("agent_type_slug"),
|
||||
project_name=details.get("project_name"),
|
||||
project_slug=details.get("project_slug"),
|
||||
assigned_issues_count=details.get("assigned_issues_count", 0),
|
||||
)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent details: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/projects/{project_id}/agents/{agent_id}",
|
||||
response_model=AgentInstanceResponse,
|
||||
summary="Update Agent",
|
||||
description="Update an agent instance's configuration and state.",
|
||||
operation_id="update_agent",
|
||||
)
|
||||
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
|
||||
async def update_agent(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
agent_id: UUID,
|
||||
agent_in: AgentInstanceUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update an agent instance's configuration and state.
|
||||
|
||||
Allows updating agent status, current task, memory, and other
|
||||
configurable fields. Status transitions are validated according
|
||||
to the agent lifecycle state machine.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object (for rate limiting)
|
||||
project_id: UUID of the project
|
||||
agent_id: UUID of the agent instance
|
||||
agent_in: Agent update data
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
AgentInstanceResponse: The updated agent instance
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the project or agent is not found
|
||||
AuthorizationError: If the user lacks access to the project
|
||||
ValidationException: If the status transition is invalid
|
||||
"""
|
||||
try:
|
||||
# Verify project access
|
||||
await verify_project_access(db, project_id, current_user)
|
||||
|
||||
# Get current agent
|
||||
agent = await agent_instance_crud.get(db, id=agent_id)
|
||||
if not agent:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify agent belongs to the specified project
|
||||
if agent.project_id != project_id:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found in project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Validate status transition if status is being changed
|
||||
if agent_in.status is not None and agent_in.status != agent.status:
|
||||
validate_status_transition(agent.status, agent_in.status)
|
||||
|
||||
# Update the agent
|
||||
updated_agent = await agent_instance_crud.update(
|
||||
db, db_obj=agent, obj_in=agent_in
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.email} updated agent {updated_agent.name} "
|
||||
f"(id={agent_id})"
|
||||
)
|
||||
|
||||
# Get updated details
|
||||
details = await agent_instance_crud.get_with_details(
|
||||
db, instance_id=updated_agent.id
|
||||
)
|
||||
if details:
|
||||
return build_agent_response(
|
||||
agent=details["instance"],
|
||||
agent_type_name=details.get("agent_type_name"),
|
||||
agent_type_slug=details.get("agent_type_slug"),
|
||||
project_name=details.get("project_name"),
|
||||
project_slug=details.get("project_slug"),
|
||||
assigned_issues_count=details.get("assigned_issues_count", 0),
|
||||
)
|
||||
|
||||
return build_agent_response(updated_agent)
|
||||
|
||||
except (NotFoundError, AuthorizationError, ValidationException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating agent: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.post(
|
||||
"/projects/{project_id}/agents/{agent_id}/pause",
|
||||
response_model=AgentInstanceResponse,
|
||||
summary="Pause Agent",
|
||||
description="Pause an agent instance, temporarily stopping its work.",
|
||||
operation_id="pause_agent",
|
||||
)
|
||||
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
|
||||
async def pause_agent(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
agent_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Pause an agent instance.
|
||||
|
||||
Transitions the agent to PAUSED status, temporarily stopping
|
||||
its work. The agent can be resumed later with the resume endpoint.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object (for rate limiting)
|
||||
project_id: UUID of the project
|
||||
agent_id: UUID of the agent instance
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
AgentInstanceResponse: The paused agent instance
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the project or agent is not found
|
||||
AuthorizationError: If the user lacks access to the project
|
||||
ValidationException: If the agent cannot be paused from its current state
|
||||
"""
|
||||
try:
|
||||
# Verify project access
|
||||
await verify_project_access(db, project_id, current_user)
|
||||
|
||||
# Get current agent
|
||||
agent = await agent_instance_crud.get(db, id=agent_id)
|
||||
if not agent:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify agent belongs to the specified project
|
||||
if agent.project_id != project_id:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found in project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Validate the transition to PAUSED
|
||||
validate_status_transition(agent.status, AgentStatus.PAUSED)
|
||||
|
||||
# Update status to PAUSED
|
||||
paused_agent = await agent_instance_crud.update_status(
|
||||
db,
|
||||
instance_id=agent_id,
|
||||
status=AgentStatus.PAUSED,
|
||||
)
|
||||
|
||||
if not paused_agent:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.email} paused agent {paused_agent.name} "
|
||||
f"(id={agent_id})"
|
||||
)
|
||||
|
||||
# Get updated details
|
||||
details = await agent_instance_crud.get_with_details(
|
||||
db, instance_id=paused_agent.id
|
||||
)
|
||||
if details:
|
||||
return build_agent_response(
|
||||
agent=details["instance"],
|
||||
agent_type_name=details.get("agent_type_name"),
|
||||
agent_type_slug=details.get("agent_type_slug"),
|
||||
project_name=details.get("project_name"),
|
||||
project_slug=details.get("project_slug"),
|
||||
assigned_issues_count=details.get("assigned_issues_count", 0),
|
||||
)
|
||||
|
||||
return build_agent_response(paused_agent)
|
||||
|
||||
except (NotFoundError, AuthorizationError, ValidationException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error pausing agent: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.post(
|
||||
"/projects/{project_id}/agents/{agent_id}/resume",
|
||||
response_model=AgentInstanceResponse,
|
||||
summary="Resume Agent",
|
||||
description="Resume a paused agent instance.",
|
||||
operation_id="resume_agent",
|
||||
)
|
||||
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
|
||||
async def resume_agent(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
agent_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Resume a paused agent instance.
|
||||
|
||||
Transitions the agent from PAUSED back to IDLE status,
|
||||
allowing it to accept new work.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object (for rate limiting)
|
||||
project_id: UUID of the project
|
||||
agent_id: UUID of the agent instance
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
AgentInstanceResponse: The resumed agent instance
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the project or agent is not found
|
||||
AuthorizationError: If the user lacks access to the project
|
||||
ValidationException: If the agent cannot be resumed from its current state
|
||||
"""
|
||||
try:
|
||||
# Verify project access
|
||||
await verify_project_access(db, project_id, current_user)
|
||||
|
||||
# Get current agent
|
||||
agent = await agent_instance_crud.get(db, id=agent_id)
|
||||
if not agent:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify agent belongs to the specified project
|
||||
if agent.project_id != project_id:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found in project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Validate the transition to IDLE (resume)
|
||||
validate_status_transition(agent.status, AgentStatus.IDLE)
|
||||
|
||||
# Update status to IDLE
|
||||
resumed_agent = await agent_instance_crud.update_status(
|
||||
db,
|
||||
instance_id=agent_id,
|
||||
status=AgentStatus.IDLE,
|
||||
)
|
||||
|
||||
if not resumed_agent:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.email} resumed agent {resumed_agent.name} "
|
||||
f"(id={agent_id})"
|
||||
)
|
||||
|
||||
# Get updated details
|
||||
details = await agent_instance_crud.get_with_details(
|
||||
db, instance_id=resumed_agent.id
|
||||
)
|
||||
if details:
|
||||
return build_agent_response(
|
||||
agent=details["instance"],
|
||||
agent_type_name=details.get("agent_type_name"),
|
||||
agent_type_slug=details.get("agent_type_slug"),
|
||||
project_name=details.get("project_name"),
|
||||
project_slug=details.get("project_slug"),
|
||||
assigned_issues_count=details.get("assigned_issues_count", 0),
|
||||
)
|
||||
|
||||
return build_agent_response(resumed_agent)
|
||||
|
||||
except (NotFoundError, AuthorizationError, ValidationException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error resuming agent: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/projects/{project_id}/agents/{agent_id}",
|
||||
response_model=MessageResponse,
|
||||
summary="Terminate Agent",
|
||||
description="Terminate an agent instance, permanently stopping it.",
|
||||
operation_id="terminate_agent",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def terminate_agent(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
agent_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Terminate an agent instance.
|
||||
|
||||
Permanently terminates the agent, setting its status to TERMINATED.
|
||||
This action cannot be undone - a new agent must be spawned if needed.
|
||||
The agent's session and current task are cleared.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object (for rate limiting)
|
||||
project_id: UUID of the project
|
||||
agent_id: UUID of the agent instance
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
MessageResponse: Confirmation message
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the project or agent is not found
|
||||
AuthorizationError: If the user lacks access to the project
|
||||
ValidationException: If the agent is already terminated
|
||||
"""
|
||||
try:
|
||||
# Verify project access
|
||||
await verify_project_access(db, project_id, current_user)
|
||||
|
||||
# Get current agent
|
||||
agent = await agent_instance_crud.get(db, id=agent_id)
|
||||
if not agent:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify agent belongs to the specified project
|
||||
if agent.project_id != project_id:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found in project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Check if already terminated
|
||||
if agent.status == AgentStatus.TERMINATED:
|
||||
raise ValidationException(
|
||||
message="Agent is already terminated",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="status",
|
||||
)
|
||||
|
||||
# Validate the transition to TERMINATED
|
||||
validate_status_transition(agent.status, AgentStatus.TERMINATED)
|
||||
|
||||
agent_name = agent.name
|
||||
|
||||
# Terminate the agent
|
||||
terminated_agent = await agent_instance_crud.terminate(db, instance_id=agent_id)
|
||||
|
||||
if not terminated_agent:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.email} terminated agent {agent_name} (id={agent_id})"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Agent '{agent_name}' has been terminated",
|
||||
)
|
||||
|
||||
except (NotFoundError, AuthorizationError, ValidationException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error terminating agent: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/agents/{agent_id}/metrics",
|
||||
response_model=AgentInstanceMetrics,
|
||||
summary="Get Agent Metrics",
|
||||
description="Get usage metrics for a specific agent instance.",
|
||||
operation_id="get_agent_metrics",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def get_agent_metrics(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
agent_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get usage metrics for a specific agent instance.
|
||||
|
||||
Returns metrics including tasks completed, tokens used,
|
||||
and cost incurred for the specified agent.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object (for rate limiting)
|
||||
project_id: UUID of the project
|
||||
agent_id: UUID of the agent instance
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
AgentInstanceMetrics: Agent usage metrics
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the project or agent is not found
|
||||
AuthorizationError: If the user lacks access to the project
|
||||
"""
|
||||
try:
|
||||
# Verify project access
|
||||
await verify_project_access(db, project_id, current_user)
|
||||
|
||||
# Get agent
|
||||
agent = await agent_instance_crud.get(db, id=agent_id)
|
||||
if not agent:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify agent belongs to the specified project
|
||||
if agent.project_id != project_id:
|
||||
raise NotFoundError(
|
||||
message=f"Agent {agent_id} not found in project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Calculate metrics for this single agent
|
||||
# For a single agent, we report its individual metrics
|
||||
is_active = agent.status == AgentStatus.WORKING
|
||||
is_idle = agent.status == AgentStatus.IDLE
|
||||
|
||||
logger.debug(
|
||||
f"User {current_user.email} retrieved metrics for agent {agent.name} "
|
||||
f"(id={agent_id})"
|
||||
)
|
||||
|
||||
return AgentInstanceMetrics(
|
||||
total_instances=1,
|
||||
active_instances=1 if is_active else 0,
|
||||
idle_instances=1 if is_idle else 0,
|
||||
total_tasks_completed=agent.tasks_completed,
|
||||
total_tokens_used=agent.tokens_used,
|
||||
total_cost_incurred=agent.cost_incurred,
|
||||
)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent metrics: {e!s}", exc_info=True)
|
||||
raise
|
||||
411
backend/app/api/routes/context.py
Normal file
411
backend/app/api/routes/context.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""
|
||||
Context Management API Endpoints.
|
||||
|
||||
Provides REST endpoints for context assembly and optimization
|
||||
for LLM requests using the ContextEngine.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.models.user import User
|
||||
from app.services.context import (
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
ContextEngine,
|
||||
ContextSettings,
|
||||
create_context_engine,
|
||||
get_context_settings,
|
||||
)
|
||||
from app.services.mcp import MCPClientManager, get_mcp_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Singleton Engine Management
|
||||
# ============================================================================
|
||||
|
||||
_context_engine: ContextEngine | None = None
|
||||
|
||||
|
||||
def _get_or_create_engine(
|
||||
mcp: MCPClientManager,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> ContextEngine:
|
||||
"""Get or create the singleton ContextEngine."""
|
||||
global _context_engine
|
||||
if _context_engine is None:
|
||||
_context_engine = create_context_engine(
|
||||
mcp_manager=mcp,
|
||||
redis=None, # Optional: add Redis caching later
|
||||
settings=settings or get_context_settings(),
|
||||
)
|
||||
logger.info("ContextEngine initialized")
|
||||
else:
|
||||
# Ensure MCP manager is up to date
|
||||
_context_engine.set_mcp_manager(mcp)
|
||||
return _context_engine
|
||||
|
||||
|
||||
async def get_context_engine(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ContextEngine:
|
||||
"""FastAPI dependency to get the ContextEngine."""
|
||||
return _get_or_create_engine(mcp)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ConversationTurn(BaseModel):
|
||||
"""A single conversation turn."""
|
||||
|
||||
role: str = Field(..., description="Role: 'user' or 'assistant'")
|
||||
content: str = Field(..., description="Message content")
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""A tool execution result."""
|
||||
|
||||
tool_name: str = Field(..., description="Name of the tool")
|
||||
content: str | dict[str, Any] = Field(..., description="Tool result content")
|
||||
status: str = Field(default="success", description="Execution status")
|
||||
|
||||
|
||||
class AssembleContextRequest(BaseModel):
|
||||
"""Request to assemble context for an LLM request."""
|
||||
|
||||
project_id: str = Field(..., description="Project identifier")
|
||||
agent_id: str = Field(..., description="Agent identifier")
|
||||
query: str = Field(..., description="User's query or current request")
|
||||
model: str = Field(
|
||||
default="claude-3-sonnet",
|
||||
description="Target model name",
|
||||
)
|
||||
max_tokens: int | None = Field(
|
||||
None,
|
||||
description="Maximum context tokens (uses model default if None)",
|
||||
)
|
||||
system_prompt: str | None = Field(
|
||||
None,
|
||||
description="System prompt/instructions",
|
||||
)
|
||||
task_description: str | None = Field(
|
||||
None,
|
||||
description="Current task description",
|
||||
)
|
||||
knowledge_query: str | None = Field(
|
||||
None,
|
||||
description="Query for knowledge base search",
|
||||
)
|
||||
knowledge_limit: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=50,
|
||||
description="Max number of knowledge results",
|
||||
)
|
||||
conversation_history: list[ConversationTurn] | None = Field(
|
||||
None,
|
||||
description="Previous conversation turns",
|
||||
)
|
||||
tool_results: list[ToolResult] | None = Field(
|
||||
None,
|
||||
description="Tool execution results to include",
|
||||
)
|
||||
compress: bool = Field(
|
||||
default=True,
|
||||
description="Whether to apply compression",
|
||||
)
|
||||
use_cache: bool = Field(
|
||||
default=True,
|
||||
description="Whether to use caching",
|
||||
)
|
||||
|
||||
|
||||
class AssembledContextResponse(BaseModel):
|
||||
"""Response containing assembled context."""
|
||||
|
||||
content: str = Field(..., description="Assembled context content")
|
||||
total_tokens: int = Field(..., description="Total token count")
|
||||
context_count: int = Field(..., description="Number of context items included")
|
||||
compressed: bool = Field(..., description="Whether compression was applied")
|
||||
budget_used_percent: float = Field(
|
||||
...,
|
||||
description="Percentage of token budget used",
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional metadata",
|
||||
)
|
||||
|
||||
|
||||
class TokenCountRequest(BaseModel):
|
||||
"""Request to count tokens in content."""
|
||||
|
||||
content: str = Field(..., description="Content to count tokens in")
|
||||
model: str | None = Field(
|
||||
None,
|
||||
description="Model for model-specific tokenization",
|
||||
)
|
||||
|
||||
|
||||
class TokenCountResponse(BaseModel):
|
||||
"""Response containing token count."""
|
||||
|
||||
token_count: int = Field(..., description="Number of tokens")
|
||||
model: str | None = Field(None, description="Model used for counting")
|
||||
|
||||
|
||||
class BudgetInfoResponse(BaseModel):
|
||||
"""Response containing budget information for a model."""
|
||||
|
||||
model: str = Field(..., description="Model name")
|
||||
total_tokens: int = Field(..., description="Total token budget")
|
||||
system_tokens: int = Field(..., description="Tokens reserved for system")
|
||||
knowledge_tokens: int = Field(..., description="Tokens for knowledge")
|
||||
conversation_tokens: int = Field(..., description="Tokens for conversation")
|
||||
tool_tokens: int = Field(..., description="Tokens for tool results")
|
||||
response_reserve: int = Field(..., description="Tokens reserved for response")
|
||||
|
||||
|
||||
class ContextEngineStatsResponse(BaseModel):
|
||||
"""Response containing engine statistics."""
|
||||
|
||||
cache: dict[str, Any] = Field(..., description="Cache statistics")
|
||||
settings: dict[str, Any] = Field(..., description="Current settings")
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Health check response."""
|
||||
|
||||
status: str = Field(..., description="Health status")
|
||||
mcp_connected: bool = Field(..., description="Whether MCP is connected")
|
||||
cache_enabled: bool = Field(..., description="Whether caching is enabled")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/health",
|
||||
response_model=HealthResponse,
|
||||
summary="Context Engine Health",
|
||||
description="Check health status of the context engine.",
|
||||
)
|
||||
async def health_check(
|
||||
engine: ContextEngine = Depends(get_context_engine),
|
||||
) -> HealthResponse:
|
||||
"""Check context engine health."""
|
||||
stats = await engine.get_stats()
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
mcp_connected=engine._mcp is not None,
|
||||
cache_enabled=stats.get("settings", {}).get("cache_enabled", False),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/assemble",
|
||||
response_model=AssembledContextResponse,
|
||||
summary="Assemble Context",
|
||||
description="Assemble optimized context for an LLM request.",
|
||||
)
|
||||
async def assemble_context(
|
||||
request: AssembleContextRequest,
|
||||
current_user: User = Depends(require_superuser),
|
||||
engine: ContextEngine = Depends(get_context_engine),
|
||||
) -> AssembledContextResponse:
|
||||
"""
|
||||
Assemble optimized context for an LLM request.
|
||||
|
||||
This endpoint gathers context from various sources, scores and ranks them,
|
||||
compresses if needed, and formats for the target model.
|
||||
"""
|
||||
logger.info(
|
||||
"Context assembly for project=%s agent=%s by user=%s",
|
||||
request.project_id,
|
||||
request.agent_id,
|
||||
current_user.id,
|
||||
)
|
||||
|
||||
# Convert conversation history to dict format
|
||||
conversation_history = None
|
||||
if request.conversation_history:
|
||||
conversation_history = [
|
||||
{"role": turn.role, "content": turn.content}
|
||||
for turn in request.conversation_history
|
||||
]
|
||||
|
||||
# Convert tool results to dict format
|
||||
tool_results = None
|
||||
if request.tool_results:
|
||||
tool_results = [
|
||||
{
|
||||
"tool_name": tr.tool_name,
|
||||
"content": tr.content,
|
||||
"status": tr.status,
|
||||
}
|
||||
for tr in request.tool_results
|
||||
]
|
||||
|
||||
try:
|
||||
result = await engine.assemble_context(
|
||||
project_id=request.project_id,
|
||||
agent_id=request.agent_id,
|
||||
query=request.query,
|
||||
model=request.model,
|
||||
max_tokens=request.max_tokens,
|
||||
system_prompt=request.system_prompt,
|
||||
task_description=request.task_description,
|
||||
knowledge_query=request.knowledge_query,
|
||||
knowledge_limit=request.knowledge_limit,
|
||||
conversation_history=conversation_history,
|
||||
tool_results=tool_results,
|
||||
compress=request.compress,
|
||||
use_cache=request.use_cache,
|
||||
)
|
||||
|
||||
# Calculate budget usage percentage
|
||||
budget = await engine.get_budget_for_model(request.model, request.max_tokens)
|
||||
budget_used_percent = (result.total_tokens / budget.total) * 100
|
||||
|
||||
# Check if compression was applied (from metadata if available)
|
||||
was_compressed = result.metadata.get("compressed_contexts", 0) > 0
|
||||
|
||||
return AssembledContextResponse(
|
||||
content=result.content,
|
||||
total_tokens=result.total_tokens,
|
||||
context_count=result.context_count,
|
||||
compressed=was_compressed,
|
||||
budget_used_percent=round(budget_used_percent, 2),
|
||||
metadata={
|
||||
"model": request.model,
|
||||
"query": request.query,
|
||||
"knowledge_included": bool(request.knowledge_query),
|
||||
"conversation_turns": len(request.conversation_history or []),
|
||||
"excluded_count": result.excluded_count,
|
||||
"assembly_time_ms": result.assembly_time_ms,
|
||||
},
|
||||
)
|
||||
|
||||
except AssemblyTimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail=f"Context assembly timed out: {e}",
|
||||
) from e
|
||||
except BudgetExceededError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||
detail=f"Token budget exceeded: {e}",
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.exception("Context assembly failed")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Context assembly failed: {e}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/count-tokens",
|
||||
response_model=TokenCountResponse,
|
||||
summary="Count Tokens",
|
||||
description="Count tokens in content using the LLM Gateway.",
|
||||
)
|
||||
async def count_tokens(
|
||||
request: TokenCountRequest,
|
||||
engine: ContextEngine = Depends(get_context_engine),
|
||||
) -> TokenCountResponse:
|
||||
"""Count tokens in content."""
|
||||
try:
|
||||
count = await engine.count_tokens(
|
||||
content=request.content,
|
||||
model=request.model,
|
||||
)
|
||||
return TokenCountResponse(
|
||||
token_count=count,
|
||||
model=request.model,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Token counting failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Token counting failed: {e}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/budget/{model}",
|
||||
response_model=BudgetInfoResponse,
|
||||
summary="Get Token Budget",
|
||||
description="Get token budget allocation for a specific model.",
|
||||
)
|
||||
async def get_budget(
|
||||
model: str,
|
||||
max_tokens: Annotated[int | None, Query(description="Custom max tokens")] = None,
|
||||
engine: ContextEngine = Depends(get_context_engine),
|
||||
) -> BudgetInfoResponse:
|
||||
"""Get token budget information for a model."""
|
||||
budget = await engine.get_budget_for_model(model, max_tokens)
|
||||
return BudgetInfoResponse(
|
||||
model=model,
|
||||
total_tokens=budget.total,
|
||||
system_tokens=budget.system,
|
||||
knowledge_tokens=budget.knowledge,
|
||||
conversation_tokens=budget.conversation,
|
||||
tool_tokens=budget.tools,
|
||||
response_reserve=budget.response_reserve,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/stats",
|
||||
response_model=ContextEngineStatsResponse,
|
||||
summary="Engine Statistics",
|
||||
description="Get context engine statistics and configuration.",
|
||||
)
|
||||
async def get_stats(
|
||||
current_user: User = Depends(require_superuser),
|
||||
engine: ContextEngine = Depends(get_context_engine),
|
||||
) -> ContextEngineStatsResponse:
|
||||
"""Get engine statistics."""
|
||||
stats = await engine.get_stats()
|
||||
return ContextEngineStatsResponse(
|
||||
cache=stats.get("cache", {}),
|
||||
settings=stats.get("settings", {}),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/cache/invalidate",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Invalidate Cache (Admin Only)",
|
||||
description="Invalidate context cache entries.",
|
||||
)
|
||||
async def invalidate_cache(
|
||||
project_id: Annotated[
|
||||
str | None, Query(description="Project to invalidate")
|
||||
] = None,
|
||||
pattern: Annotated[str | None, Query(description="Pattern to match")] = None,
|
||||
current_user: User = Depends(require_superuser),
|
||||
engine: ContextEngine = Depends(get_context_engine),
|
||||
) -> None:
|
||||
"""Invalidate cache entries."""
|
||||
logger.info(
|
||||
"Cache invalidation by user %s: project=%s pattern=%s",
|
||||
current_user.id,
|
||||
project_id,
|
||||
pattern,
|
||||
)
|
||||
await engine.invalidate_cache(project_id=project_id, pattern=pattern)
|
||||
316
backend/app/api/routes/events.py
Normal file
316
backend/app/api/routes/events.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
SSE endpoint for real-time project event streaming.
|
||||
|
||||
This module provides Server-Sent Events (SSE) endpoints for streaming
|
||||
project events to connected clients. Events are scoped to projects,
|
||||
with authorization checks to ensure clients only receive events
|
||||
for projects they have access to.
|
||||
|
||||
Features:
|
||||
- Real-time event streaming via SSE
|
||||
- Project-scoped authorization
|
||||
- Automatic reconnection support (Last-Event-ID)
|
||||
- Keepalive messages every 30 seconds
|
||||
- Graceful connection cleanup
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, Request
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from app.api.dependencies.auth import get_current_user, get_current_user_sse
|
||||
from app.api.dependencies.event_bus import get_event_bus
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import AuthorizationError
|
||||
from app.models.user import User
|
||||
from app.schemas.errors import ErrorCode
|
||||
from app.schemas.events import EventType
|
||||
from app.services.event_bus import EventBus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Keepalive interval in seconds
|
||||
KEEPALIVE_INTERVAL = 30
|
||||
|
||||
|
||||
async def check_project_access(
|
||||
project_id: UUID,
|
||||
user: User,
|
||||
db: "AsyncSession",
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a user has access to a project's events.
|
||||
|
||||
Authorization rules:
|
||||
- Superusers can access all projects
|
||||
- Project owners can access their own projects
|
||||
|
||||
Args:
|
||||
project_id: The project to check access for
|
||||
user: The authenticated user
|
||||
db: Database session for project lookup
|
||||
|
||||
Returns:
|
||||
bool: True if user has access, False otherwise
|
||||
"""
|
||||
# Superusers can access all projects
|
||||
if user.is_superuser:
|
||||
logger.debug(
|
||||
f"Project access granted for superuser {user.id} on project {project_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
# Check if user owns the project
|
||||
from app.crud.syndarix import project as project_crud
|
||||
|
||||
project = await project_crud.get(db, id=project_id)
|
||||
if not project:
|
||||
logger.debug(f"Project {project_id} not found for access check")
|
||||
return False
|
||||
|
||||
has_access = bool(project.owner_id == user.id)
|
||||
logger.debug(
|
||||
f"Project access {'granted' if has_access else 'denied'} "
|
||||
f"for user {user.id} on project {project_id} (owner: {project.owner_id})"
|
||||
)
|
||||
return has_access
|
||||
|
||||
|
||||
async def event_generator(
|
||||
project_id: UUID,
|
||||
event_bus: EventBus,
|
||||
last_event_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Generate SSE events for a project.
|
||||
|
||||
This async generator yields SSE-formatted events from the event bus,
|
||||
including keepalive comments to maintain the connection.
|
||||
|
||||
Args:
|
||||
project_id: The project to stream events for
|
||||
event_bus: The EventBus instance
|
||||
last_event_id: Optional last received event ID for reconnection
|
||||
|
||||
Yields:
|
||||
dict: SSE event data with 'event', 'data', and optional 'id' fields
|
||||
"""
|
||||
try:
|
||||
async for event_data in event_bus.subscribe_sse(
|
||||
project_id=project_id,
|
||||
last_event_id=last_event_id,
|
||||
keepalive_interval=KEEPALIVE_INTERVAL,
|
||||
):
|
||||
if event_data == "":
|
||||
# Keepalive - yield SSE comment
|
||||
yield {"comment": "keepalive"}
|
||||
else:
|
||||
# Parse event to extract type and id
|
||||
try:
|
||||
event_dict = json.loads(event_data)
|
||||
event_type = event_dict.get("type", "message")
|
||||
event_id = event_dict.get("id")
|
||||
|
||||
yield {
|
||||
"event": event_type,
|
||||
"data": event_data,
|
||||
"id": event_id,
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# If we can't parse, send as generic message
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": event_data,
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Event stream cancelled for project {project_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in event stream for project {project_id}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/events/stream",
|
||||
summary="Stream Project Events",
|
||||
description="""
|
||||
Stream real-time events for a project via Server-Sent Events (SSE).
|
||||
|
||||
**Authentication**: Required (Bearer token OR query parameter)
|
||||
**Authorization**: Must have access to the project
|
||||
|
||||
**Authentication Methods**:
|
||||
- Bearer token in Authorization header (preferred)
|
||||
- Query parameter `token` (for EventSource compatibility)
|
||||
|
||||
Note: EventSource API doesn't support custom headers, so the query parameter
|
||||
option is provided for browser-based SSE clients.
|
||||
|
||||
**SSE Event Format**:
|
||||
```
|
||||
event: agent.status_changed
|
||||
id: 550e8400-e29b-41d4-a716-446655440000
|
||||
data: {"id": "...", "type": "agent.status_changed", "project_id": "...", ...}
|
||||
|
||||
: keepalive
|
||||
|
||||
event: issue.created
|
||||
id: 550e8400-e29b-41d4-a716-446655440001
|
||||
data: {...}
|
||||
```
|
||||
|
||||
**Reconnection**: Include the `Last-Event-ID` header with the last received
|
||||
event ID to resume from where you left off.
|
||||
|
||||
**Keepalive**: The server sends a comment (`: keepalive`) every 30 seconds
|
||||
to keep the connection alive.
|
||||
|
||||
**Rate Limit**: 10 connections/minute per IP
|
||||
""",
|
||||
response_class=EventSourceResponse,
|
||||
responses={
|
||||
200: {
|
||||
"description": "SSE stream established",
|
||||
"content": {"text/event-stream": {}},
|
||||
},
|
||||
401: {"description": "Not authenticated"},
|
||||
403: {"description": "Not authorized to access this project"},
|
||||
404: {"description": "Project not found"},
|
||||
},
|
||||
operation_id="stream_project_events",
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
async def stream_project_events(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
db: "AsyncSession" = Depends(get_db),
|
||||
event_bus: EventBus = Depends(get_event_bus),
|
||||
token: str | None = Query(
|
||||
None, description="Auth token (for EventSource compatibility)"
|
||||
),
|
||||
authorization: str | None = Header(None, alias="Authorization"),
|
||||
last_event_id: str | None = Header(None, alias="Last-Event-ID"),
|
||||
):
|
||||
"""
|
||||
Stream real-time events for a project via SSE.
|
||||
|
||||
This endpoint establishes a persistent SSE connection that streams
|
||||
project events to the client in real-time. The connection includes:
|
||||
|
||||
- Event streaming: All project events (agent updates, issues, etc.)
|
||||
- Keepalive: Comment every 30 seconds to maintain connection
|
||||
- Reconnection: Use Last-Event-ID header to resume after disconnect
|
||||
|
||||
The connection is automatically cleaned up when the client disconnects.
|
||||
"""
|
||||
# Authenticate user (supports both header and query param tokens)
|
||||
current_user = await get_current_user_sse(
|
||||
db=db, authorization=authorization, token=token
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"SSE connection request for project {project_id} "
|
||||
f"by user {current_user.id} "
|
||||
f"(last_event_id={last_event_id})"
|
||||
)
|
||||
|
||||
# Check project access
|
||||
has_access = await check_project_access(project_id, current_user, db)
|
||||
if not has_access:
|
||||
raise AuthorizationError(
|
||||
message=f"You don't have access to project {project_id}",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Return SSE response
|
||||
return EventSourceResponse(
|
||||
event_generator(
|
||||
project_id=project_id,
|
||||
event_bus=event_bus,
|
||||
last_event_id=last_event_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/projects/{project_id}/events/test",
|
||||
summary="Send Test Event (Development Only)",
|
||||
description="""
|
||||
Send a test event to a project's event stream. This endpoint is
|
||||
intended for development and testing purposes.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**: Must have access to the project
|
||||
|
||||
**Note**: This endpoint should be disabled or restricted in production.
|
||||
""",
|
||||
response_model=dict,
|
||||
responses={
|
||||
200: {"description": "Test event sent"},
|
||||
401: {"description": "Not authenticated"},
|
||||
403: {"description": "Not authorized to access this project"},
|
||||
},
|
||||
operation_id="send_test_event",
|
||||
)
|
||||
async def send_test_event(
|
||||
project_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
event_bus: EventBus = Depends(get_event_bus),
|
||||
db: "AsyncSession" = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Send a test event to the project's event stream.
|
||||
|
||||
This is useful for testing SSE connections during development.
|
||||
"""
|
||||
# Check project access
|
||||
has_access = await check_project_access(project_id, current_user, db)
|
||||
if not has_access:
|
||||
raise AuthorizationError(
|
||||
message=f"You don't have access to project {project_id}",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Create and publish test event using the Event schema
|
||||
event = EventBus.create_event(
|
||||
event_type=EventType.AGENT_MESSAGE,
|
||||
project_id=project_id,
|
||||
actor_type="user",
|
||||
actor_id=current_user.id,
|
||||
payload={
|
||||
"message": "Test event from SSE endpoint",
|
||||
"message_type": "info",
|
||||
},
|
||||
)
|
||||
|
||||
channel = event_bus.get_project_channel(project_id)
|
||||
await event_bus.publish(channel, event)
|
||||
|
||||
logger.info(f"Test event sent to project {project_id}: {event.id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"event_id": event.id,
|
||||
"event_type": event.type.value,
|
||||
"message": "Test event sent successfully",
|
||||
}
|
||||
968
backend/app/api/routes/issues.py
Normal file
968
backend/app/api/routes/issues.py
Normal file
@@ -0,0 +1,968 @@
|
||||
# app/api/routes/issues.py
|
||||
"""
|
||||
Issue CRUD API endpoints for Syndarix projects.
|
||||
|
||||
Provides endpoints for managing issues within projects, including:
|
||||
- Create, read, update, delete operations
|
||||
- Filtering by status, priority, labels, sprint, assigned agent
|
||||
- Search across title and body
|
||||
- Assignment to agents
|
||||
- External issue tracker sync triggers
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
AuthorizationError,
|
||||
NotFoundError,
|
||||
ValidationException,
|
||||
)
|
||||
from app.crud.syndarix.agent_instance import agent_instance as agent_instance_crud
|
||||
from app.crud.syndarix.issue import issue as issue_crud
|
||||
from app.crud.syndarix.project import project as project_crud
|
||||
from app.crud.syndarix.sprint import sprint as sprint_crud
|
||||
from app.models.syndarix.enums import (
|
||||
AgentStatus,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
MessageResponse,
|
||||
PaginatedResponse,
|
||||
PaginationParams,
|
||||
SortOrder,
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.errors import ErrorCode
|
||||
from app.schemas.syndarix.issue import (
|
||||
IssueAssign,
|
||||
IssueCreate,
|
||||
IssueResponse,
|
||||
IssueStats,
|
||||
IssueUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize limiter for this router
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Use higher rate limits in test environment
|
||||
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||
|
||||
|
||||
async def verify_project_ownership(
|
||||
db: AsyncSession,
|
||||
project_id: UUID,
|
||||
user: User,
|
||||
) -> None:
|
||||
"""
|
||||
Verify that the user owns the project or is a superuser.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
project_id: Project UUID to verify
|
||||
user: Current authenticated user
|
||||
|
||||
Raises:
|
||||
NotFoundError: If project does not exist
|
||||
AuthorizationError: If user does not own the project
|
||||
"""
|
||||
project = await project_crud.get(db, id=project_id)
|
||||
if not project:
|
||||
raise NotFoundError(
|
||||
message=f"Project {project_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
if not user.is_superuser and project.owner_id != user.id:
|
||||
raise AuthorizationError(
|
||||
message="You do not have access to this project",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
|
||||
def _build_issue_response(
|
||||
issue: Any,
|
||||
project_name: str | None = None,
|
||||
project_slug: str | None = None,
|
||||
sprint_name: str | None = None,
|
||||
assigned_agent_type_name: str | None = None,
|
||||
) -> IssueResponse:
|
||||
"""
|
||||
Build an IssueResponse from an Issue model instance.
|
||||
|
||||
Args:
|
||||
issue: Issue model instance
|
||||
project_name: Optional project name from relationship
|
||||
project_slug: Optional project slug from relationship
|
||||
sprint_name: Optional sprint name from relationship
|
||||
assigned_agent_type_name: Optional agent type name from relationship
|
||||
|
||||
Returns:
|
||||
IssueResponse schema instance
|
||||
"""
|
||||
return IssueResponse(
|
||||
id=issue.id,
|
||||
project_id=issue.project_id,
|
||||
title=issue.title,
|
||||
body=issue.body,
|
||||
status=issue.status,
|
||||
priority=issue.priority,
|
||||
labels=issue.labels or [],
|
||||
assigned_agent_id=issue.assigned_agent_id,
|
||||
human_assignee=issue.human_assignee,
|
||||
sprint_id=issue.sprint_id,
|
||||
story_points=issue.story_points,
|
||||
external_tracker_type=issue.external_tracker_type,
|
||||
external_issue_id=issue.external_issue_id,
|
||||
remote_url=issue.remote_url,
|
||||
external_issue_number=issue.external_issue_number,
|
||||
sync_status=issue.sync_status,
|
||||
last_synced_at=issue.last_synced_at,
|
||||
external_updated_at=issue.external_updated_at,
|
||||
closed_at=issue.closed_at,
|
||||
created_at=issue.created_at,
|
||||
updated_at=issue.updated_at,
|
||||
project_name=project_name,
|
||||
project_slug=project_slug,
|
||||
sprint_name=sprint_name,
|
||||
assigned_agent_type_name=assigned_agent_type_name,
|
||||
)
|
||||
|
||||
|
||||
# ===== Issue CRUD Endpoints =====
|
||||
|
||||
|
||||
@router.post(
|
||||
"/projects/{project_id}/issues",
|
||||
response_model=IssueResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create Issue",
|
||||
description="Create a new issue in a project",
|
||||
operation_id="create_issue",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def create_issue(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
issue_in: IssueCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Create a new issue within a project.
|
||||
|
||||
The user must own the project or be a superuser.
|
||||
The project_id in the path takes precedence over any project_id in the body.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object (for rate limiting)
|
||||
project_id: UUID of the project to create the issue in
|
||||
issue_in: Issue creation data
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Created issue with full details
|
||||
|
||||
Raises:
|
||||
NotFoundError: If project not found
|
||||
AuthorizationError: If user lacks access
|
||||
ValidationException: If assigned agent not in project
|
||||
"""
|
||||
# Verify project access
|
||||
await verify_project_ownership(db, project_id, current_user)
|
||||
|
||||
# Override project_id from path
|
||||
issue_in.project_id = project_id
|
||||
|
||||
# Validate assigned agent if provided
|
||||
if issue_in.assigned_agent_id:
|
||||
agent = await agent_instance_crud.get(db, id=issue_in.assigned_agent_id)
|
||||
if not agent:
|
||||
raise NotFoundError(
|
||||
message=f"Agent instance {issue_in.assigned_agent_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
if agent.project_id != project_id:
|
||||
raise ValidationException(
|
||||
message="Agent instance does not belong to this project",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="assigned_agent_id",
|
||||
)
|
||||
if agent.status == AgentStatus.TERMINATED:
|
||||
raise ValidationException(
|
||||
message="Cannot assign issue to a terminated agent",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="assigned_agent_id",
|
||||
)
|
||||
|
||||
# Validate sprint if provided (IDOR prevention)
|
||||
if issue_in.sprint_id:
|
||||
sprint = await sprint_crud.get(db, id=issue_in.sprint_id)
|
||||
if not sprint:
|
||||
raise NotFoundError(
|
||||
message=f"Sprint {issue_in.sprint_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
if sprint.project_id != project_id:
|
||||
raise ValidationException(
|
||||
message="Sprint does not belong to this project",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="sprint_id",
|
||||
)
|
||||
|
||||
try:
|
||||
issue = await issue_crud.create(db, obj_in=issue_in)
|
||||
logger.info(
|
||||
f"User {current_user.email} created issue '{issue.title}' "
|
||||
f"in project {project_id}"
|
||||
)
|
||||
|
||||
# Get project details for response
|
||||
project = await project_crud.get(db, id=project_id)
|
||||
|
||||
return _build_issue_response(
|
||||
issue,
|
||||
project_name=project.name if project else None,
|
||||
project_slug=project.slug if project else None,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to create issue: {e!s}")
|
||||
raise ValidationException(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating issue: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/issues",
|
||||
response_model=PaginatedResponse[IssueResponse],
|
||||
summary="List Issues",
|
||||
description="Get paginated list of issues in a project with filtering",
|
||||
operation_id="list_issues",
|
||||
)
|
||||
@limiter.limit(f"{120 * RATE_MULTIPLIER}/minute")
|
||||
async def list_issues(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
pagination: PaginationParams = Depends(),
|
||||
status_filter: IssueStatus | None = Query(
|
||||
None, alias="status", description="Filter by issue status"
|
||||
),
|
||||
priority: IssuePriority | None = Query(None, description="Filter by priority"),
|
||||
labels: list[str] | None = Query(
|
||||
None, description="Filter by labels (comma-separated)"
|
||||
),
|
||||
sprint_id: UUID | None = Query(None, description="Filter by sprint ID"),
|
||||
assigned_agent_id: UUID | None = Query(
|
||||
None, description="Filter by assigned agent ID"
|
||||
),
|
||||
sync_status: SyncStatus | None = Query(None, description="Filter by sync status"),
|
||||
search: str | None = Query(
|
||||
None, min_length=1, max_length=100, description="Search in title and body"
|
||||
),
|
||||
sort_by: str = Query(
|
||||
"created_at",
|
||||
description="Field to sort by (created_at, updated_at, priority, status, title)",
|
||||
),
|
||||
sort_order: SortOrder = Query(SortOrder.DESC, description="Sort order"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List issues in a project with comprehensive filtering options.
|
||||
|
||||
Supports filtering by:
|
||||
- status: Issue status (open, in_progress, in_review, blocked, closed)
|
||||
- priority: Issue priority (low, medium, high, critical)
|
||||
- labels: Match issues containing any of the provided labels
|
||||
- sprint_id: Issues in a specific sprint
|
||||
- assigned_agent_id: Issues assigned to a specific agent
|
||||
- sync_status: External tracker sync status
|
||||
- search: Full-text search in title and body
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
project_id: Project UUID
|
||||
pagination: Pagination parameters
|
||||
status_filter: Optional status filter
|
||||
priority: Optional priority filter
|
||||
labels: Optional labels filter
|
||||
sprint_id: Optional sprint filter
|
||||
assigned_agent_id: Optional agent assignment filter
|
||||
sync_status: Optional sync status filter
|
||||
search: Optional search query
|
||||
sort_by: Field to sort by
|
||||
sort_order: Sort direction
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Paginated list of issues matching filters
|
||||
"""
|
||||
# Verify project access
|
||||
await verify_project_ownership(db, project_id, current_user)
|
||||
|
||||
try:
|
||||
# Get filtered issues
|
||||
issues, total = await issue_crud.get_by_project(
|
||||
db,
|
||||
project_id=project_id,
|
||||
status=status_filter,
|
||||
priority=priority,
|
||||
sprint_id=sprint_id,
|
||||
assigned_agent_id=assigned_agent_id,
|
||||
labels=labels,
|
||||
search=search,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order.value,
|
||||
)
|
||||
|
||||
# Build response objects
|
||||
issue_responses = [_build_issue_response(issue) for issue in issues]
|
||||
|
||||
pagination_meta = create_pagination_meta(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(issue_responses),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=issue_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error listing issues for project {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# ===== Issue Statistics Endpoint =====
|
||||
# NOTE: This endpoint MUST be defined before /{issue_id} routes
|
||||
# to prevent FastAPI from trying to parse "stats" as a UUID
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/issues/stats",
|
||||
response_model=IssueStats,
|
||||
summary="Get Issue Statistics",
|
||||
description="Get aggregated issue statistics for a project",
|
||||
operation_id="get_issue_stats",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def get_issue_stats(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get aggregated statistics for issues in a project.
|
||||
|
||||
Returns counts by status and priority, along with story point totals.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
project_id: Project UUID
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Issue statistics including counts by status/priority and story points
|
||||
|
||||
Raises:
|
||||
NotFoundError: If project not found
|
||||
AuthorizationError: If user lacks access
|
||||
"""
|
||||
# Verify project access
|
||||
await verify_project_ownership(db, project_id, current_user)
|
||||
|
||||
try:
|
||||
stats = await issue_crud.get_project_stats(db, project_id=project_id)
|
||||
return IssueStats(**stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issue stats for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/issues/{issue_id}",
|
||||
response_model=IssueResponse,
|
||||
summary="Get Issue",
|
||||
description="Get detailed information about a specific issue",
|
||||
operation_id="get_issue",
|
||||
)
|
||||
@limiter.limit(f"{120 * RATE_MULTIPLIER}/minute")
|
||||
async def get_issue(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
issue_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get detailed information about a specific issue.
|
||||
|
||||
Returns the issue with expanded relationship data including
|
||||
project name, sprint name, and assigned agent type name.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
project_id: Project UUID
|
||||
issue_id: Issue UUID
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Issue details with relationship data
|
||||
|
||||
Raises:
|
||||
NotFoundError: If project or issue not found
|
||||
AuthorizationError: If user lacks access
|
||||
"""
|
||||
# Verify project access
|
||||
await verify_project_ownership(db, project_id, current_user)
|
||||
|
||||
# Get issue with details
|
||||
issue_data = await issue_crud.get_with_details(db, issue_id=issue_id)
|
||||
|
||||
if not issue_data:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
issue = issue_data["issue"]
|
||||
|
||||
# Verify issue belongs to the project
|
||||
if issue.project_id != project_id:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found in project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
return _build_issue_response(
|
||||
issue,
|
||||
project_name=issue_data.get("project_name"),
|
||||
project_slug=issue_data.get("project_slug"),
|
||||
sprint_name=issue_data.get("sprint_name"),
|
||||
assigned_agent_type_name=issue_data.get("assigned_agent_type_name"),
|
||||
)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/projects/{project_id}/issues/{issue_id}",
|
||||
response_model=IssueResponse,
|
||||
summary="Update Issue",
|
||||
description="Update an existing issue",
|
||||
operation_id="update_issue",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def update_issue(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
issue_id: UUID,
|
||||
issue_in: IssueUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update an existing issue.
|
||||
|
||||
All fields are optional - only provided fields will be updated.
|
||||
Validates that assigned agent belongs to the same project.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
project_id: Project UUID
|
||||
issue_id: Issue UUID
|
||||
issue_in: Fields to update
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Updated issue details
|
||||
|
||||
Raises:
|
||||
NotFoundError: If project or issue not found
|
||||
AuthorizationError: If user lacks access
|
||||
ValidationException: If validation fails
|
||||
"""
|
||||
# Verify project access
|
||||
await verify_project_ownership(db, project_id, current_user)
|
||||
|
||||
# Get existing issue
|
||||
issue = await issue_crud.get(db, id=issue_id)
|
||||
if not issue:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify issue belongs to the project
|
||||
if issue.project_id != project_id:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found in project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Validate assigned agent if being updated
|
||||
if issue_in.assigned_agent_id is not None:
|
||||
agent = await agent_instance_crud.get(db, id=issue_in.assigned_agent_id)
|
||||
if not agent:
|
||||
raise NotFoundError(
|
||||
message=f"Agent instance {issue_in.assigned_agent_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
if agent.project_id != project_id:
|
||||
raise ValidationException(
|
||||
message="Agent instance does not belong to this project",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="assigned_agent_id",
|
||||
)
|
||||
if agent.status == AgentStatus.TERMINATED:
|
||||
raise ValidationException(
|
||||
message="Cannot assign issue to a terminated agent",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="assigned_agent_id",
|
||||
)
|
||||
|
||||
# Validate sprint if being updated (IDOR prevention and status validation)
|
||||
if issue_in.sprint_id is not None:
|
||||
sprint = await sprint_crud.get(db, id=issue_in.sprint_id)
|
||||
if not sprint:
|
||||
raise NotFoundError(
|
||||
message=f"Sprint {issue_in.sprint_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
if sprint.project_id != project_id:
|
||||
raise ValidationException(
|
||||
message="Sprint does not belong to this project",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="sprint_id",
|
||||
)
|
||||
# Cannot add issues to completed or cancelled sprints
|
||||
if sprint.status in [SprintStatus.COMPLETED, SprintStatus.CANCELLED]:
|
||||
raise ValidationException(
|
||||
message=f"Cannot add issues to sprint with status '{sprint.status.value}'",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="sprint_id",
|
||||
)
|
||||
|
||||
try:
|
||||
updated_issue = await issue_crud.update(db, db_obj=issue, obj_in=issue_in)
|
||||
logger.info(
|
||||
f"User {current_user.email} updated issue {issue_id} in project {project_id}"
|
||||
)
|
||||
|
||||
# Get full details for response
|
||||
issue_data = await issue_crud.get_with_details(db, issue_id=issue_id)
|
||||
|
||||
return _build_issue_response(
|
||||
updated_issue,
|
||||
project_name=issue_data.get("project_name") if issue_data else None,
|
||||
project_slug=issue_data.get("project_slug") if issue_data else None,
|
||||
sprint_name=issue_data.get("sprint_name") if issue_data else None,
|
||||
assigned_agent_type_name=issue_data.get("assigned_agent_type_name")
|
||||
if issue_data
|
||||
else None,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to update issue {issue_id}: {e!s}")
|
||||
raise ValidationException(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating issue {issue_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/projects/{project_id}/issues/{issue_id}",
|
||||
response_model=MessageResponse,
|
||||
summary="Delete Issue",
|
||||
description="Delete an issue permanently",
|
||||
operation_id="delete_issue",
|
||||
)
|
||||
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
|
||||
async def delete_issue(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
issue_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Delete an issue permanently.
|
||||
|
||||
The issue will be permanently removed from the database.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
project_id: Project UUID
|
||||
issue_id: Issue UUID
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
|
||||
Raises:
|
||||
NotFoundError: If project or issue not found
|
||||
AuthorizationError: If user lacks access
|
||||
"""
|
||||
# Verify project access
|
||||
await verify_project_ownership(db, project_id, current_user)
|
||||
|
||||
# Get existing issue
|
||||
issue = await issue_crud.get(db, id=issue_id)
|
||||
if not issue:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify issue belongs to the project
|
||||
if issue.project_id != project_id:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found in project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
try:
|
||||
issue_title = issue.title
|
||||
await issue_crud.remove(db, id=issue_id)
|
||||
logger.info(
|
||||
f"User {current_user.email} deleted issue {issue_id} "
|
||||
f"('{issue_title}') from project {project_id}"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Issue '{issue_title}' has been deleted",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting issue {issue_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# ===== Issue Assignment Endpoint =====
|
||||
|
||||
|
||||
@router.post(
|
||||
"/projects/{project_id}/issues/{issue_id}/assign",
|
||||
response_model=IssueResponse,
|
||||
summary="Assign Issue",
|
||||
description="Assign an issue to an agent or human",
|
||||
operation_id="assign_issue",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def assign_issue(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
issue_id: UUID,
|
||||
assignment: IssueAssign,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Assign an issue to an agent or human.
|
||||
|
||||
Only one type of assignment is allowed at a time:
|
||||
- assigned_agent_id: Assign to an AI agent instance
|
||||
- human_assignee: Assign to a human (name/email string)
|
||||
|
||||
To unassign, pass both as null/None.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
project_id: Project UUID
|
||||
issue_id: Issue UUID
|
||||
assignment: Assignment data
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Updated issue with assignment
|
||||
|
||||
Raises:
|
||||
NotFoundError: If project, issue, or agent not found
|
||||
AuthorizationError: If user lacks access
|
||||
ValidationException: If agent not in project
|
||||
"""
|
||||
# Verify project access
|
||||
await verify_project_ownership(db, project_id, current_user)
|
||||
|
||||
# Get existing issue
|
||||
issue = await issue_crud.get(db, id=issue_id)
|
||||
if not issue:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify issue belongs to the project
|
||||
if issue.project_id != project_id:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found in project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Process assignment based on type
|
||||
if assignment.assigned_agent_id:
|
||||
# Validate agent exists and belongs to project
|
||||
agent = await agent_instance_crud.get(db, id=assignment.assigned_agent_id)
|
||||
if not agent:
|
||||
raise NotFoundError(
|
||||
message=f"Agent instance {assignment.assigned_agent_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
if agent.project_id != project_id:
|
||||
raise ValidationException(
|
||||
message="Agent instance does not belong to this project",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="assigned_agent_id",
|
||||
)
|
||||
if agent.status == AgentStatus.TERMINATED:
|
||||
raise ValidationException(
|
||||
message="Cannot assign issue to a terminated agent",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="assigned_agent_id",
|
||||
)
|
||||
|
||||
updated_issue = await issue_crud.assign_to_agent(
|
||||
db, issue_id=issue_id, agent_id=assignment.assigned_agent_id
|
||||
)
|
||||
logger.info(
|
||||
f"User {current_user.email} assigned issue {issue_id} to agent {agent.name}"
|
||||
)
|
||||
|
||||
elif assignment.human_assignee:
|
||||
updated_issue = await issue_crud.assign_to_human(
|
||||
db, issue_id=issue_id, human_assignee=assignment.human_assignee
|
||||
)
|
||||
logger.info(
|
||||
f"User {current_user.email} assigned issue {issue_id} "
|
||||
f"to human '{assignment.human_assignee}'"
|
||||
)
|
||||
|
||||
else:
|
||||
# Unassign - clear both agent and human
|
||||
updated_issue = await issue_crud.assign_to_agent(
|
||||
db, issue_id=issue_id, agent_id=None
|
||||
)
|
||||
logger.info(f"User {current_user.email} unassigned issue {issue_id}")
|
||||
|
||||
if not updated_issue:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Get full details for response
|
||||
issue_data = await issue_crud.get_with_details(db, issue_id=issue_id)
|
||||
|
||||
return _build_issue_response(
|
||||
updated_issue,
|
||||
project_name=issue_data.get("project_name") if issue_data else None,
|
||||
project_slug=issue_data.get("project_slug") if issue_data else None,
|
||||
sprint_name=issue_data.get("sprint_name") if issue_data else None,
|
||||
assigned_agent_type_name=issue_data.get("assigned_agent_type_name")
|
||||
if issue_data
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/projects/{project_id}/issues/{issue_id}/assignment",
|
||||
response_model=IssueResponse,
|
||||
summary="Unassign Issue",
|
||||
description="""
|
||||
Remove agent/human assignment from an issue.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**: Project owner or superuser
|
||||
|
||||
This clears both agent and human assignee fields.
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="unassign_issue",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def unassign_issue(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
issue_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Remove assignment from an issue.
|
||||
|
||||
Clears both assigned_agent_id and human_assignee fields.
|
||||
"""
|
||||
# Verify project access
|
||||
await verify_project_ownership(db, project_id, current_user)
|
||||
|
||||
# Get existing issue
|
||||
issue = await issue_crud.get(db, id=issue_id)
|
||||
if not issue:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify issue belongs to project (IDOR prevention)
|
||||
if issue.project_id != project_id:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found in project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Unassign the issue
|
||||
updated_issue = await issue_crud.unassign(db, issue_id=issue_id)
|
||||
|
||||
if not updated_issue:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
logger.info(f"User {current_user.email} unassigned issue {issue_id}")
|
||||
|
||||
# Get full details for response
|
||||
issue_data = await issue_crud.get_with_details(db, issue_id=issue_id)
|
||||
|
||||
return _build_issue_response(
|
||||
updated_issue,
|
||||
project_name=issue_data.get("project_name") if issue_data else None,
|
||||
project_slug=issue_data.get("project_slug") if issue_data else None,
|
||||
sprint_name=issue_data.get("sprint_name") if issue_data else None,
|
||||
assigned_agent_type_name=issue_data.get("assigned_agent_type_name")
|
||||
if issue_data
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
# ===== Issue Sync Endpoint =====
|
||||
|
||||
|
||||
@router.post(
|
||||
"/projects/{project_id}/issues/{issue_id}/sync",
|
||||
response_model=MessageResponse,
|
||||
summary="Trigger Issue Sync",
|
||||
description="Trigger synchronization with external issue tracker",
|
||||
operation_id="sync_issue",
|
||||
)
|
||||
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
|
||||
async def sync_issue(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
issue_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Trigger synchronization of an issue with its external tracker.
|
||||
|
||||
This endpoint queues a sync task for the issue. The actual synchronization
|
||||
happens asynchronously via Celery.
|
||||
|
||||
Prerequisites:
|
||||
- Issue must have external_tracker_type configured
|
||||
- Project must have integration settings for the tracker
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
project_id: Project UUID
|
||||
issue_id: Issue UUID
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Message indicating sync has been triggered
|
||||
|
||||
Raises:
|
||||
NotFoundError: If project or issue not found
|
||||
AuthorizationError: If user lacks access
|
||||
ValidationException: If issue has no external tracker
|
||||
"""
|
||||
# Verify project access
|
||||
await verify_project_ownership(db, project_id, current_user)
|
||||
|
||||
# Get existing issue
|
||||
issue = await issue_crud.get(db, id=issue_id)
|
||||
if not issue:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify issue belongs to the project
|
||||
if issue.project_id != project_id:
|
||||
raise NotFoundError(
|
||||
message=f"Issue {issue_id} not found in project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Check if issue has external tracker configured
|
||||
if not issue.external_tracker_type:
|
||||
raise ValidationException(
|
||||
message="Issue does not have an external tracker configured",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="external_tracker_type",
|
||||
)
|
||||
|
||||
# Update sync status to pending
|
||||
await issue_crud.update_sync_status(
|
||||
db,
|
||||
issue_id=issue_id,
|
||||
sync_status=SyncStatus.PENDING,
|
||||
)
|
||||
|
||||
# TODO: Queue Celery task for actual sync
|
||||
# When Celery is set up, this will be:
|
||||
# from app.tasks.sync import sync_issue_task
|
||||
# sync_issue_task.delay(str(issue_id))
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.email} triggered sync for issue {issue_id} "
|
||||
f"(tracker: {issue.external_tracker_type})"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Sync triggered for issue '{issue.title}'. "
|
||||
f"Status will update when complete.",
|
||||
)
|
||||
446
backend/app/api/routes/mcp.py
Normal file
446
backend/app/api/routes/mcp.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) API Endpoints
|
||||
|
||||
Provides REST endpoints for managing MCP server connections
|
||||
and executing tool calls.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.models.user import User
|
||||
from app.services.mcp import (
|
||||
MCPCircuitOpenError,
|
||||
MCPClientManager,
|
||||
MCPConnectionError,
|
||||
MCPError,
|
||||
MCPServerNotFoundError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
get_mcp_client,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Server name validation pattern: alphanumeric, hyphens, underscores, 1-64 chars
|
||||
SERVER_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
|
||||
|
||||
# Type alias for validated server name path parameter
|
||||
ServerNamePath = Annotated[
|
||||
str,
|
||||
Path(
|
||||
description="MCP server name",
|
||||
min_length=1,
|
||||
max_length=64,
|
||||
pattern=r"^[a-zA-Z0-9_-]+$",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ServerInfo(BaseModel):
|
||||
"""Information about an MCP server."""
|
||||
|
||||
name: str = Field(..., description="Server name")
|
||||
url: str = Field(..., description="Server URL")
|
||||
enabled: bool = Field(..., description="Whether server is enabled")
|
||||
timeout: int = Field(..., description="Request timeout in seconds")
|
||||
transport: str = Field(..., description="Transport type (http, stdio, sse)")
|
||||
description: str | None = Field(None, description="Server description")
|
||||
|
||||
|
||||
class ServerListResponse(BaseModel):
|
||||
"""Response containing list of MCP servers."""
|
||||
|
||||
servers: list[ServerInfo]
|
||||
total: int
|
||||
|
||||
|
||||
class ToolInfoResponse(BaseModel):
|
||||
"""Information about an MCP tool."""
|
||||
|
||||
name: str = Field(..., description="Tool name")
|
||||
description: str | None = Field(None, description="Tool description")
|
||||
server_name: str | None = Field(None, description="Server providing the tool")
|
||||
input_schema: dict[str, Any] | None = Field(
|
||||
None, description="JSON schema for input"
|
||||
)
|
||||
|
||||
|
||||
class ToolListResponse(BaseModel):
|
||||
"""Response containing list of tools."""
|
||||
|
||||
tools: list[ToolInfoResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class ServerHealthStatus(BaseModel):
|
||||
"""Health status for a server."""
|
||||
|
||||
name: str
|
||||
healthy: bool
|
||||
state: str
|
||||
url: str
|
||||
error: str | None = None
|
||||
tools_count: int = 0
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
|
||||
"""Response containing health status of all servers."""
|
||||
|
||||
servers: dict[str, ServerHealthStatus]
|
||||
healthy_count: int
|
||||
unhealthy_count: int
|
||||
total: int
|
||||
|
||||
|
||||
class ToolCallRequest(BaseModel):
|
||||
"""Request to execute a tool."""
|
||||
|
||||
server: str = Field(..., description="MCP server name")
|
||||
tool: str = Field(..., description="Tool name to execute")
|
||||
arguments: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Tool arguments",
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
None,
|
||||
description="Optional timeout override in seconds",
|
||||
)
|
||||
|
||||
|
||||
class ToolCallResponse(BaseModel):
|
||||
"""Response from tool execution."""
|
||||
|
||||
success: bool
|
||||
data: Any | None = None
|
||||
error: str | None = None
|
||||
error_code: str | None = None
|
||||
tool_name: str | None = None
|
||||
server_name: str | None = None
|
||||
execution_time_ms: float = 0.0
|
||||
request_id: str | None = None
|
||||
|
||||
|
||||
class CircuitBreakerStatus(BaseModel):
|
||||
"""Status of a circuit breaker."""
|
||||
|
||||
server_name: str
|
||||
state: str
|
||||
failure_count: int
|
||||
|
||||
|
||||
class CircuitBreakerListResponse(BaseModel):
|
||||
"""Response containing circuit breaker statuses."""
|
||||
|
||||
circuit_breakers: list[CircuitBreakerStatus]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/servers",
|
||||
response_model=ServerListResponse,
|
||||
summary="List MCP Servers",
|
||||
description="Get list of all registered MCP servers with their configurations.",
|
||||
)
|
||||
async def list_servers(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ServerListResponse:
|
||||
"""List all registered MCP servers."""
|
||||
servers = []
|
||||
|
||||
for name in mcp.list_servers():
|
||||
try:
|
||||
config = mcp.get_server_config(name)
|
||||
servers.append(
|
||||
ServerInfo(
|
||||
name=name,
|
||||
url=config.url,
|
||||
enabled=config.enabled,
|
||||
timeout=config.timeout,
|
||||
transport=config.transport.value,
|
||||
description=config.description,
|
||||
)
|
||||
)
|
||||
except MCPServerNotFoundError:
|
||||
continue
|
||||
|
||||
return ServerListResponse(
|
||||
servers=servers,
|
||||
total=len(servers),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/servers/{server_name}/tools",
|
||||
response_model=ToolListResponse,
|
||||
summary="List Server Tools",
|
||||
description="Get list of tools available on a specific MCP server.",
|
||||
)
|
||||
async def list_server_tools(
|
||||
server_name: ServerNamePath,
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ToolListResponse:
|
||||
"""List all tools available on a specific server."""
|
||||
try:
|
||||
tools = await mcp.list_tools(server_name)
|
||||
return ToolListResponse(
|
||||
tools=[
|
||||
ToolInfoResponse(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
server_name=t.server_name,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
for t in tools
|
||||
],
|
||||
total=len(tools),
|
||||
)
|
||||
except MCPServerNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Server not found: {server_name}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tools",
|
||||
response_model=ToolListResponse,
|
||||
summary="List All Tools",
|
||||
description="Get list of all tools from all MCP servers.",
|
||||
)
|
||||
async def list_all_tools(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ToolListResponse:
|
||||
"""List all tools from all servers."""
|
||||
tools = await mcp.list_all_tools()
|
||||
return ToolListResponse(
|
||||
tools=[
|
||||
ToolInfoResponse(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
server_name=t.server_name,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
for t in tools
|
||||
],
|
||||
total=len(tools),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/health",
|
||||
response_model=HealthCheckResponse,
|
||||
summary="Health Check",
|
||||
description="Check health status of all MCP servers.",
|
||||
)
|
||||
async def health_check(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> HealthCheckResponse:
|
||||
"""Perform health check on all MCP servers."""
|
||||
health_results = await mcp.health_check()
|
||||
|
||||
servers = {
|
||||
name: ServerHealthStatus(
|
||||
name=status.name,
|
||||
healthy=status.healthy,
|
||||
state=status.state,
|
||||
url=status.url,
|
||||
error=status.error,
|
||||
tools_count=status.tools_count,
|
||||
)
|
||||
for name, status in health_results.items()
|
||||
}
|
||||
|
||||
healthy_count = sum(1 for s in servers.values() if s.healthy)
|
||||
unhealthy_count = len(servers) - healthy_count
|
||||
|
||||
return HealthCheckResponse(
|
||||
servers=servers,
|
||||
healthy_count=healthy_count,
|
||||
unhealthy_count=unhealthy_count,
|
||||
total=len(servers),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/call",
|
||||
response_model=ToolCallResponse,
|
||||
summary="Execute Tool (Admin Only)",
|
||||
description="Execute a tool on an MCP server. Requires superuser privileges.",
|
||||
)
|
||||
async def call_tool(
|
||||
request: ToolCallRequest,
|
||||
current_user: User = Depends(require_superuser),
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ToolCallResponse:
|
||||
"""
|
||||
Execute a tool on an MCP server.
|
||||
|
||||
This endpoint is restricted to superusers for direct tool execution.
|
||||
Normal tool execution should go through agent workflows.
|
||||
"""
|
||||
logger.info(
|
||||
"Tool call by user %s: %s.%s",
|
||||
current_user.id,
|
||||
request.server,
|
||||
request.tool,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await mcp.call_tool(
|
||||
server=request.server,
|
||||
tool=request.tool,
|
||||
args=request.arguments,
|
||||
timeout=request.timeout,
|
||||
)
|
||||
|
||||
return ToolCallResponse(
|
||||
success=result.success,
|
||||
data=result.data,
|
||||
error=result.error,
|
||||
error_code=result.error_code,
|
||||
tool_name=result.tool_name,
|
||||
server_name=result.server_name,
|
||||
execution_time_ms=result.execution_time_ms,
|
||||
request_id=result.request_id,
|
||||
)
|
||||
|
||||
except MCPCircuitOpenError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Server temporarily unavailable: {e.server_name}",
|
||||
) from e
|
||||
except MCPToolNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Tool not found: {e.tool_name}",
|
||||
) from e
|
||||
except MCPServerNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Server not found: {e.server_name}",
|
||||
) from e
|
||||
except MCPTimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except MCPConnectionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except MCPToolError as e:
|
||||
# Tool errors are returned in the response, not as HTTP errors
|
||||
return ToolCallResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_code=e.error_code,
|
||||
tool_name=e.tool_name,
|
||||
server_name=e.server_name,
|
||||
)
|
||||
except MCPError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/circuit-breakers",
|
||||
response_model=CircuitBreakerListResponse,
|
||||
summary="List Circuit Breakers",
|
||||
description="Get status of all circuit breakers.",
|
||||
)
|
||||
async def list_circuit_breakers(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> CircuitBreakerListResponse:
|
||||
"""Get status of all circuit breakers."""
|
||||
status_dict = mcp.get_circuit_breaker_status()
|
||||
|
||||
return CircuitBreakerListResponse(
|
||||
circuit_breakers=[
|
||||
CircuitBreakerStatus(
|
||||
server_name=name,
|
||||
state=info.get("state", "unknown"),
|
||||
failure_count=info.get("failure_count", 0),
|
||||
)
|
||||
for name, info in status_dict.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/circuit-breakers/{server_name}/reset",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Reset Circuit Breaker (Admin Only)",
|
||||
description="Manually reset a circuit breaker for a server.",
|
||||
)
|
||||
async def reset_circuit_breaker(
|
||||
server_name: ServerNamePath,
|
||||
current_user: User = Depends(require_superuser),
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> None:
|
||||
"""Manually reset a circuit breaker."""
|
||||
logger.info(
|
||||
"Circuit breaker reset by user %s for server %s",
|
||||
current_user.id,
|
||||
server_name,
|
||||
)
|
||||
|
||||
success = await mcp.reset_circuit_breaker(server_name)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No circuit breaker found for server: {server_name}",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/servers/{server_name}/reconnect",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Reconnect to Server (Admin Only)",
|
||||
description="Force reconnection to an MCP server.",
|
||||
)
|
||||
async def reconnect_server(
|
||||
server_name: ServerNamePath,
|
||||
current_user: User = Depends(require_superuser),
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> None:
|
||||
"""Force reconnection to an MCP server."""
|
||||
logger.info(
|
||||
"Reconnect requested by user %s for server %s",
|
||||
current_user.id,
|
||||
server_name,
|
||||
)
|
||||
|
||||
try:
|
||||
await mcp.disconnect(server_name)
|
||||
await mcp.connect(server_name)
|
||||
except MCPServerNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Server not found: {server_name}",
|
||||
) from e
|
||||
except MCPConnectionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Failed to reconnect: {e}",
|
||||
) from e
|
||||
659
backend/app/api/routes/projects.py
Normal file
659
backend/app/api/routes/projects.py
Normal file
@@ -0,0 +1,659 @@
|
||||
# app/api/routes/projects.py
|
||||
"""
|
||||
Project management API endpoints for Syndarix.
|
||||
|
||||
These endpoints allow users to manage their AI-powered software consulting projects.
|
||||
Users can create, read, update, and manage the lifecycle of their projects.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
AuthorizationError,
|
||||
DuplicateError,
|
||||
ErrorCode,
|
||||
NotFoundError,
|
||||
ValidationException,
|
||||
)
|
||||
from app.crud.syndarix.project import project as project_crud
|
||||
from app.models.syndarix.enums import ProjectStatus
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
MessageResponse,
|
||||
PaginatedResponse,
|
||||
PaginationParams,
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.syndarix.project import (
|
||||
ProjectCreate,
|
||||
ProjectResponse,
|
||||
ProjectUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize rate limiter
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Use higher rate limits in test environment
|
||||
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||
|
||||
|
||||
def _build_project_response(project_data: dict[str, Any]) -> ProjectResponse:
|
||||
"""
|
||||
Build a ProjectResponse from project data dictionary.
|
||||
|
||||
Args:
|
||||
project_data: Dictionary containing project and related counts
|
||||
|
||||
Returns:
|
||||
ProjectResponse with all fields populated
|
||||
"""
|
||||
project = project_data["project"]
|
||||
return ProjectResponse(
|
||||
id=project.id,
|
||||
name=project.name,
|
||||
slug=project.slug,
|
||||
description=project.description,
|
||||
autonomy_level=project.autonomy_level,
|
||||
status=project.status,
|
||||
settings=project.settings,
|
||||
owner_id=project.owner_id,
|
||||
created_at=project.created_at,
|
||||
updated_at=project.updated_at,
|
||||
agent_count=project_data.get("agent_count", 0),
|
||||
issue_count=project_data.get("issue_count", 0),
|
||||
active_sprint_name=project_data.get("active_sprint_name"),
|
||||
)
|
||||
|
||||
|
||||
def _check_project_ownership(project: Any, current_user: User) -> None:
|
||||
"""
|
||||
Check if the current user owns the project or is a superuser.
|
||||
|
||||
Args:
|
||||
project: The project to check ownership of
|
||||
current_user: The authenticated user
|
||||
|
||||
Raises:
|
||||
AuthorizationError: If user doesn't own the project and isn't a superuser
|
||||
"""
|
||||
if not current_user.is_superuser and project.owner_id != current_user.id:
|
||||
raise AuthorizationError(
|
||||
message="You do not have permission to access this project",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Project CRUD Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=ProjectResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create Project",
|
||||
description="""
|
||||
Create a new project for the current user.
|
||||
|
||||
The project will be owned by the authenticated user.
|
||||
A unique slug is required for URL-friendly project identification.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="create_project",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def create_project(
|
||||
request: Request,
|
||||
project_in: ProjectCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Create a new project.
|
||||
|
||||
The authenticated user becomes the owner of the project.
|
||||
"""
|
||||
try:
|
||||
# Set the owner to the current user
|
||||
project_data = ProjectCreate(
|
||||
name=project_in.name,
|
||||
slug=project_in.slug,
|
||||
description=project_in.description,
|
||||
autonomy_level=project_in.autonomy_level,
|
||||
status=project_in.status,
|
||||
settings=project_in.settings,
|
||||
owner_id=current_user.id,
|
||||
)
|
||||
|
||||
project = await project_crud.create(db, obj_in=project_data)
|
||||
logger.info(f"User {current_user.email} created project {project.slug}")
|
||||
|
||||
return ProjectResponse(
|
||||
id=project.id,
|
||||
name=project.name,
|
||||
slug=project.slug,
|
||||
description=project.description,
|
||||
autonomy_level=project.autonomy_level,
|
||||
status=project.status,
|
||||
settings=project.settings,
|
||||
owner_id=project.owner_id,
|
||||
created_at=project.created_at,
|
||||
updated_at=project.updated_at,
|
||||
agent_count=0,
|
||||
issue_count=0,
|
||||
active_sprint_name=None,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = str(e)
|
||||
if "already exists" in error_msg.lower():
|
||||
logger.warning(f"Duplicate project slug attempted: {project_in.slug}")
|
||||
raise DuplicateError(
|
||||
message=error_msg,
|
||||
error_code=ErrorCode.DUPLICATE_ENTRY,
|
||||
field="slug",
|
||||
)
|
||||
logger.error(f"Error creating project: {error_msg}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error creating project: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=PaginatedResponse[ProjectResponse],
|
||||
summary="List Projects",
|
||||
description="""
|
||||
List projects for the current user with filtering and pagination.
|
||||
|
||||
Regular users see only their own projects.
|
||||
Superusers can see all projects by setting `all_projects=true`.
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="list_projects",
|
||||
)
|
||||
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
|
||||
async def list_projects(
|
||||
request: Request,
|
||||
pagination: PaginationParams = Depends(),
|
||||
status_filter: ProjectStatus | None = Query(
|
||||
None, alias="status", description="Filter by project status"
|
||||
),
|
||||
search: str | None = Query(
|
||||
None, description="Search by name, slug, or description"
|
||||
),
|
||||
all_projects: bool = Query(False, description="Show all projects (superuser only)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List projects with filtering, search, and pagination.
|
||||
|
||||
Regular users only see their own projects.
|
||||
Superusers can view all projects if all_projects is true.
|
||||
"""
|
||||
try:
|
||||
# Determine owner filter based on user role and request
|
||||
owner_id = (
|
||||
None if (current_user.is_superuser and all_projects) else current_user.id
|
||||
)
|
||||
|
||||
projects_data, total = await project_crud.get_multi_with_counts(
|
||||
db,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
status=status_filter,
|
||||
owner_id=owner_id,
|
||||
search=search,
|
||||
)
|
||||
|
||||
# Build response objects
|
||||
project_responses = [_build_project_response(data) for data in projects_data]
|
||||
|
||||
pagination_meta = create_pagination_meta(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(project_responses),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=project_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing projects: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}",
|
||||
response_model=ProjectResponse,
|
||||
summary="Get Project",
|
||||
description="""
|
||||
Get detailed information about a specific project.
|
||||
|
||||
Users can only access their own projects unless they are superusers.
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_project",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def get_project(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get detailed information about a project by ID.
|
||||
|
||||
Includes agent count, issue count, and active sprint name.
|
||||
"""
|
||||
try:
|
||||
project_data = await project_crud.get_with_counts(db, project_id=project_id)
|
||||
|
||||
if not project_data:
|
||||
raise NotFoundError(
|
||||
message=f"Project {project_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
project = project_data["project"]
|
||||
_check_project_ownership(project, current_user)
|
||||
|
||||
return _build_project_response(project_data)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting project {project_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/slug/{slug}",
|
||||
response_model=ProjectResponse,
|
||||
summary="Get Project by Slug",
|
||||
description="""
|
||||
Get detailed information about a project by its slug.
|
||||
|
||||
Users can only access their own projects unless they are superusers.
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_project_by_slug",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def get_project_by_slug(
|
||||
request: Request,
|
||||
slug: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get detailed information about a project by slug.
|
||||
|
||||
Includes agent count, issue count, and active sprint name.
|
||||
"""
|
||||
try:
|
||||
project = await project_crud.get_by_slug(db, slug=slug)
|
||||
|
||||
if not project:
|
||||
raise NotFoundError(
|
||||
message=f"Project with slug '{slug}' not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
_check_project_ownership(project, current_user)
|
||||
|
||||
# Get project with counts
|
||||
project_data = await project_crud.get_with_counts(db, project_id=project.id)
|
||||
|
||||
if not project_data:
|
||||
raise NotFoundError(
|
||||
message=f"Project with slug '{slug}' not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
return _build_project_response(project_data)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting project by slug {slug}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{project_id}",
|
||||
response_model=ProjectResponse,
|
||||
summary="Update Project",
|
||||
description="""
|
||||
Update an existing project.
|
||||
|
||||
Only the project owner or a superuser can update a project.
|
||||
Only provided fields will be updated.
|
||||
|
||||
**Rate Limit**: 20 requests/minute
|
||||
""",
|
||||
operation_id="update_project",
|
||||
)
|
||||
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
|
||||
async def update_project(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
project_in: ProjectUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update a project's information.
|
||||
|
||||
Only the project owner or superusers can perform updates.
|
||||
"""
|
||||
try:
|
||||
project = await project_crud.get(db, id=project_id)
|
||||
|
||||
if not project:
|
||||
raise NotFoundError(
|
||||
message=f"Project {project_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
_check_project_ownership(project, current_user)
|
||||
|
||||
# Update the project
|
||||
updated_project = await project_crud.update(
|
||||
db, db_obj=project, obj_in=project_in
|
||||
)
|
||||
logger.info(f"User {current_user.email} updated project {updated_project.slug}")
|
||||
|
||||
# Get updated project with counts
|
||||
project_data = await project_crud.get_with_counts(
|
||||
db, project_id=updated_project.id
|
||||
)
|
||||
|
||||
if not project_data:
|
||||
# This shouldn't happen, but handle gracefully
|
||||
raise NotFoundError(
|
||||
message=f"Project {project_id} not found after update",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
return _build_project_response(project_data)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except ValueError as e:
|
||||
error_msg = str(e)
|
||||
if "already exists" in error_msg.lower():
|
||||
logger.warning(f"Duplicate project slug attempted: {project_in.slug}")
|
||||
raise DuplicateError(
|
||||
message=error_msg,
|
||||
error_code=ErrorCode.DUPLICATE_ENTRY,
|
||||
field="slug",
|
||||
)
|
||||
logger.error(f"Error updating project: {error_msg}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating project {project_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{project_id}",
|
||||
response_model=MessageResponse,
|
||||
summary="Archive Project",
|
||||
description="""
|
||||
Archive a project (soft delete).
|
||||
|
||||
Only the project owner or a superuser can archive a project.
|
||||
Archived projects are not deleted but are no longer accessible for active work.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="archive_project",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def archive_project(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Archive a project by setting its status to ARCHIVED.
|
||||
|
||||
This is a soft delete operation. The project data is preserved.
|
||||
"""
|
||||
try:
|
||||
project = await project_crud.get(db, id=project_id)
|
||||
|
||||
if not project:
|
||||
raise NotFoundError(
|
||||
message=f"Project {project_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
_check_project_ownership(project, current_user)
|
||||
|
||||
# Check if project is already archived
|
||||
if project.status == ProjectStatus.ARCHIVED:
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Project '{project.name}' is already archived",
|
||||
)
|
||||
|
||||
archived_project = await project_crud.archive_project(db, project_id=project_id)
|
||||
|
||||
if not archived_project:
|
||||
raise NotFoundError(
|
||||
message=f"Failed to archive project {project_id}",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
logger.info(f"User {current_user.email} archived project {project.slug}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Project '{archived_project.name}' has been archived",
|
||||
)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error archiving project {project_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Project Lifecycle Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{project_id}/pause",
|
||||
response_model=ProjectResponse,
|
||||
summary="Pause Project",
|
||||
description="""
|
||||
Pause an active project.
|
||||
|
||||
Only ACTIVE projects can be paused.
|
||||
Only the project owner or a superuser can pause a project.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="pause_project",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def pause_project(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Pause an active project.
|
||||
|
||||
Sets the project status to PAUSED. Only ACTIVE projects can be paused.
|
||||
"""
|
||||
try:
|
||||
project = await project_crud.get(db, id=project_id)
|
||||
|
||||
if not project:
|
||||
raise NotFoundError(
|
||||
message=f"Project {project_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
_check_project_ownership(project, current_user)
|
||||
|
||||
# Validate current status (business logic validation, not authorization)
|
||||
if project.status == ProjectStatus.PAUSED:
|
||||
raise ValidationException(
|
||||
message="Project is already paused",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="status",
|
||||
)
|
||||
|
||||
if project.status == ProjectStatus.ARCHIVED:
|
||||
raise ValidationException(
|
||||
message="Cannot pause an archived project",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="status",
|
||||
)
|
||||
|
||||
if project.status == ProjectStatus.COMPLETED:
|
||||
raise ValidationException(
|
||||
message="Cannot pause a completed project",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="status",
|
||||
)
|
||||
|
||||
# Update status to PAUSED
|
||||
updated_project = await project_crud.update(
|
||||
db, db_obj=project, obj_in=ProjectUpdate(status=ProjectStatus.PAUSED)
|
||||
)
|
||||
logger.info(f"User {current_user.email} paused project {project.slug}")
|
||||
|
||||
# Get project with counts
|
||||
project_data = await project_crud.get_with_counts(
|
||||
db, project_id=updated_project.id
|
||||
)
|
||||
|
||||
if not project_data:
|
||||
raise NotFoundError(
|
||||
message=f"Project {project_id} not found after update",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
return _build_project_response(project_data)
|
||||
|
||||
except (NotFoundError, AuthorizationError, ValidationException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error pausing project {project_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{project_id}/resume",
|
||||
response_model=ProjectResponse,
|
||||
summary="Resume Project",
|
||||
description="""
|
||||
Resume a paused project.
|
||||
|
||||
Only PAUSED projects can be resumed.
|
||||
Only the project owner or a superuser can resume a project.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="resume_project",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def resume_project(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Resume a paused project.
|
||||
|
||||
Sets the project status back to ACTIVE. Only PAUSED projects can be resumed.
|
||||
"""
|
||||
try:
|
||||
project = await project_crud.get(db, id=project_id)
|
||||
|
||||
if not project:
|
||||
raise NotFoundError(
|
||||
message=f"Project {project_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
_check_project_ownership(project, current_user)
|
||||
|
||||
# Validate current status (business logic validation, not authorization)
|
||||
if project.status == ProjectStatus.ACTIVE:
|
||||
raise ValidationException(
|
||||
message="Project is already active",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="status",
|
||||
)
|
||||
|
||||
if project.status == ProjectStatus.ARCHIVED:
|
||||
raise ValidationException(
|
||||
message="Cannot resume an archived project",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="status",
|
||||
)
|
||||
|
||||
if project.status == ProjectStatus.COMPLETED:
|
||||
raise ValidationException(
|
||||
message="Cannot resume a completed project",
|
||||
error_code=ErrorCode.VALIDATION_ERROR,
|
||||
field="status",
|
||||
)
|
||||
|
||||
# Update status to ACTIVE
|
||||
updated_project = await project_crud.update(
|
||||
db, db_obj=project, obj_in=ProjectUpdate(status=ProjectStatus.ACTIVE)
|
||||
)
|
||||
logger.info(f"User {current_user.email} resumed project {project.slug}")
|
||||
|
||||
# Get project with counts
|
||||
project_data = await project_crud.get_with_counts(
|
||||
db, project_id=updated_project.id
|
||||
)
|
||||
|
||||
if not project_data:
|
||||
raise NotFoundError(
|
||||
message=f"Project {project_id} not found after update",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
return _build_project_response(project_data)
|
||||
|
||||
except (NotFoundError, AuthorizationError, ValidationException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error resuming project {project_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
1186
backend/app/api/routes/sprints.py
Normal file
1186
backend/app/api/routes/sprints.py
Normal file
File diff suppressed because it is too large
Load Diff
116
backend/app/celery_app.py
Normal file
116
backend/app/celery_app.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# app/celery_app.py
|
||||
"""
|
||||
Celery application configuration for Syndarix.
|
||||
|
||||
This module configures the Celery app for background task processing:
|
||||
- Agent execution tasks (LLM calls, tool execution)
|
||||
- Git operations (clone, commit, push, PR creation)
|
||||
- Issue synchronization with external trackers
|
||||
- Workflow state management
|
||||
- Cost tracking and budget monitoring
|
||||
|
||||
Architecture:
|
||||
- Redis as message broker and result backend
|
||||
- Queue routing for task isolation
|
||||
- JSON serialization for cross-language compatibility
|
||||
- Beat scheduler for periodic tasks
|
||||
"""
|
||||
|
||||
from celery import Celery
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Create Celery application instance
|
||||
celery_app = Celery(
|
||||
"syndarix",
|
||||
broker=settings.celery_broker_url,
|
||||
backend=settings.celery_result_backend,
|
||||
)
|
||||
|
||||
# Define task queues with their own exchanges and routing keys
|
||||
TASK_QUEUES = {
|
||||
"agent": {"exchange": "agent", "routing_key": "agent"},
|
||||
"git": {"exchange": "git", "routing_key": "git"},
|
||||
"sync": {"exchange": "sync", "routing_key": "sync"},
|
||||
"default": {"exchange": "default", "routing_key": "default"},
|
||||
}
|
||||
|
||||
# Configure Celery
|
||||
celery_app.conf.update(
|
||||
# Serialization
|
||||
task_serializer="json",
|
||||
accept_content=["json"],
|
||||
result_serializer="json",
|
||||
# Timezone
|
||||
timezone="UTC",
|
||||
enable_utc=True,
|
||||
# Task imports for auto-discovery
|
||||
imports=("app.tasks",),
|
||||
# Default queue
|
||||
task_default_queue="default",
|
||||
# Task queues configuration
|
||||
task_queues=TASK_QUEUES,
|
||||
# Task routing - route tasks to appropriate queues
|
||||
task_routes={
|
||||
"app.tasks.agent.*": {"queue": "agent"},
|
||||
"app.tasks.git.*": {"queue": "git"},
|
||||
"app.tasks.sync.*": {"queue": "sync"},
|
||||
"app.tasks.*": {"queue": "default"},
|
||||
},
|
||||
# Time limits per ADR-003
|
||||
task_soft_time_limit=300, # 5 minutes soft limit
|
||||
task_time_limit=600, # 10 minutes hard limit
|
||||
# Result expiration - 24 hours
|
||||
result_expires=86400,
|
||||
# Broker connection retry
|
||||
broker_connection_retry_on_startup=True,
|
||||
# Retry configuration per ADR-003 (built-in retry with backoff)
|
||||
task_autoretry_for=(Exception,), # Retry on all exceptions
|
||||
task_retry_kwargs={"max_retries": 3, "countdown": 5}, # Initial 5s delay
|
||||
task_retry_backoff=True, # Enable exponential backoff
|
||||
task_retry_backoff_max=600, # Max 10 minutes between retries
|
||||
task_retry_jitter=True, # Add jitter to prevent thundering herd
|
||||
# Beat schedule for periodic tasks
|
||||
beat_schedule={
|
||||
# Cost aggregation every hour per ADR-012
|
||||
"aggregate-daily-costs": {
|
||||
"task": "app.tasks.cost.aggregate_daily_costs",
|
||||
"schedule": 3600.0, # 1 hour in seconds
|
||||
},
|
||||
# Reset daily budget counters at midnight UTC
|
||||
"reset-daily-budget-counters": {
|
||||
"task": "app.tasks.cost.reset_daily_budget_counters",
|
||||
"schedule": 86400.0, # 24 hours in seconds
|
||||
},
|
||||
# Check for stale workflows every 5 minutes
|
||||
"recover-stale-workflows": {
|
||||
"task": "app.tasks.workflow.recover_stale_workflows",
|
||||
"schedule": 300.0, # 5 minutes in seconds
|
||||
},
|
||||
# Incremental issue sync every minute per ADR-011
|
||||
"sync-issues-incremental": {
|
||||
"task": "app.tasks.sync.sync_issues_incremental",
|
||||
"schedule": 60.0, # 1 minute in seconds
|
||||
},
|
||||
# Full issue reconciliation every 15 minutes per ADR-011
|
||||
"sync-issues-full": {
|
||||
"task": "app.tasks.sync.sync_issues_full",
|
||||
"schedule": 900.0, # 15 minutes in seconds
|
||||
},
|
||||
},
|
||||
# Task execution settings
|
||||
task_acks_late=True, # Acknowledge tasks after execution
|
||||
task_reject_on_worker_lost=True, # Reject tasks if worker dies
|
||||
worker_prefetch_multiplier=1, # Fair task distribution
|
||||
)
|
||||
|
||||
# Auto-discover tasks from task modules
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"app.tasks.agent",
|
||||
"app.tasks.git",
|
||||
"app.tasks.sync",
|
||||
"app.tasks.workflow",
|
||||
"app.tasks.cost",
|
||||
]
|
||||
)
|
||||
@@ -5,7 +5,7 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "PragmaStack"
|
||||
PROJECT_NAME: str = "Syndarix"
|
||||
VERSION: str = "1.0.0"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
@@ -39,6 +39,32 @@ class Settings(BaseSettings):
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
|
||||
|
||||
# Redis configuration (Syndarix: cache, pub/sub, Celery broker)
|
||||
REDIS_URL: str = Field(
|
||||
default="redis://localhost:6379/0",
|
||||
description="Redis URL for cache, pub/sub, and Celery broker",
|
||||
)
|
||||
|
||||
# Celery configuration (Syndarix: background task processing)
|
||||
CELERY_BROKER_URL: str | None = Field(
|
||||
default=None,
|
||||
description="Celery broker URL (defaults to REDIS_URL if not set)",
|
||||
)
|
||||
CELERY_RESULT_BACKEND: str | None = Field(
|
||||
default=None,
|
||||
description="Celery result backend URL (defaults to REDIS_URL if not set)",
|
||||
)
|
||||
|
||||
@property
|
||||
def celery_broker_url(self) -> str:
|
||||
"""Get Celery broker URL, defaulting to Redis."""
|
||||
return self.CELERY_BROKER_URL or self.REDIS_URL
|
||||
|
||||
@property
|
||||
def celery_result_backend(self) -> str:
|
||||
"""Get Celery result backend URL, defaulting to Redis."""
|
||||
return self.CELERY_RESULT_BACKEND or self.REDIS_URL
|
||||
|
||||
# SQL debugging (disable in production)
|
||||
sql_echo: bool = False # Log SQL statements
|
||||
sql_echo_pool: bool = False # Log connection pool events
|
||||
|
||||
474
backend/app/core/redis.py
Normal file
474
backend/app/core/redis.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# app/core/redis.py
|
||||
"""
|
||||
Redis client configuration for caching and pub/sub.
|
||||
|
||||
This module provides async Redis connectivity with connection pooling
|
||||
for FastAPI endpoints and background tasks.
|
||||
|
||||
Features:
|
||||
- Connection pooling for efficient resource usage
|
||||
- Cache operations (get, set, delete, expire)
|
||||
- Pub/sub operations (publish, subscribe)
|
||||
- Health check for monitoring
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from redis.asyncio import ConnectionPool, Redis
|
||||
from redis.asyncio.client import PubSub
|
||||
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default TTL for cache entries (1 hour)
|
||||
DEFAULT_CACHE_TTL = 3600
|
||||
|
||||
# Connection pool settings
|
||||
POOL_MAX_CONNECTIONS = 50
|
||||
POOL_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
class RedisClient:
|
||||
"""
|
||||
Async Redis client with connection pooling.
|
||||
|
||||
Provides high-level operations for caching and pub/sub
|
||||
with proper error handling and connection management.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str | None = None) -> None:
|
||||
"""
|
||||
Initialize Redis client.
|
||||
|
||||
Args:
|
||||
url: Redis connection URL. Defaults to settings.REDIS_URL.
|
||||
"""
|
||||
self._url = url or settings.REDIS_URL
|
||||
self._pool: ConnectionPool | None = None
|
||||
self._client: Redis | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _ensure_pool(self) -> ConnectionPool:
|
||||
"""Ensure connection pool is initialized (thread-safe)."""
|
||||
if self._pool is None:
|
||||
async with self._lock:
|
||||
# Double-check after acquiring lock
|
||||
if self._pool is None:
|
||||
self._pool = ConnectionPool.from_url(
|
||||
self._url,
|
||||
max_connections=POOL_MAX_CONNECTIONS,
|
||||
socket_timeout=POOL_TIMEOUT,
|
||||
socket_connect_timeout=POOL_TIMEOUT,
|
||||
decode_responses=True,
|
||||
health_check_interval=30,
|
||||
)
|
||||
logger.info("Redis connection pool initialized")
|
||||
return self._pool
|
||||
|
||||
async def _get_client(self) -> Redis:
|
||||
"""Get Redis client instance from pool."""
|
||||
pool = await self._ensure_pool()
|
||||
if self._client is None:
|
||||
self._client = Redis(connection_pool=pool)
|
||||
return self._client
|
||||
|
||||
# =========================================================================
|
||||
# Cache Operations
|
||||
# =========================================================================
|
||||
|
||||
async def cache_get(self, key: str) -> str | None:
|
||||
"""
|
||||
Get a value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
value = await client.get(key)
|
||||
if value is not None:
|
||||
logger.debug(f"Cache hit for key: {key}")
|
||||
else:
|
||||
logger.debug(f"Cache miss for key: {key}")
|
||||
return value
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_get failed for key '{key}': {e}")
|
||||
return None
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_get for key '{key}': {e}")
|
||||
return None
|
||||
|
||||
async def cache_get_json(self, key: str) -> Any | None:
|
||||
"""
|
||||
Get a JSON-serialized value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
Deserialized value or None if not found.
|
||||
"""
|
||||
value = await self.cache_get(key)
|
||||
if value is not None:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to decode JSON for key '{key}': {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
async def cache_set(
|
||||
self,
|
||||
key: str,
|
||||
value: str,
|
||||
ttl: int | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Set a value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
value: Value to cache.
|
||||
ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
ttl = ttl if ttl is not None else DEFAULT_CACHE_TTL
|
||||
await client.set(key, value, ex=ttl)
|
||||
logger.debug(f"Cache set for key: {key} (TTL: {ttl}s)")
|
||||
return True
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_set failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_set for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_set_json(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl: int | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Set a JSON-serialized value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
value: Value to serialize and cache.
|
||||
ttl: Time-to-live in seconds.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
serialized = json.dumps(value)
|
||||
return await self.cache_set(key, serialized, ttl)
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.error(f"Failed to serialize value for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a key from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key to delete.
|
||||
|
||||
Returns:
|
||||
True if key was deleted, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.delete(key)
|
||||
logger.debug(f"Cache delete for key: {key} (deleted: {result > 0})")
|
||||
return result > 0
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_delete failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_delete for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_delete_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
Delete all keys matching a pattern.
|
||||
|
||||
Args:
|
||||
pattern: Glob-style pattern (e.g., "user:*").
|
||||
|
||||
Returns:
|
||||
Number of keys deleted.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
deleted = 0
|
||||
async for key in client.scan_iter(pattern):
|
||||
await client.delete(key)
|
||||
deleted += 1
|
||||
logger.debug(f"Cache delete pattern '{pattern}': {deleted} keys deleted")
|
||||
return deleted
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_delete_pattern failed for '{pattern}': {e}")
|
||||
return 0
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_delete_pattern for '{pattern}': {e}")
|
||||
return 0
|
||||
|
||||
async def cache_expire(self, key: str, ttl: int) -> bool:
|
||||
"""
|
||||
Set or update TTL for a key.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
ttl: New TTL in seconds.
|
||||
|
||||
Returns:
|
||||
True if TTL was set, False if key doesn't exist.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.expire(key, ttl)
|
||||
logger.debug(
|
||||
f"Cache expire for key: {key} (TTL: {ttl}s, success: {result})"
|
||||
)
|
||||
return result
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_expire failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_expire for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if a key exists in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
True if key exists, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.exists(key)
|
||||
return result > 0
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_exists failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_exists for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_ttl(self, key: str) -> int:
|
||||
"""
|
||||
Get remaining TTL for a key.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
TTL in seconds, -1 if no TTL, -2 if key doesn't exist.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
return await client.ttl(key)
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_ttl failed for key '{key}': {e}")
|
||||
return -2
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_ttl for key '{key}': {e}")
|
||||
return -2
|
||||
|
||||
# =========================================================================
|
||||
# Pub/Sub Operations
|
||||
# =========================================================================
|
||||
|
||||
async def publish(self, channel: str, message: str | dict) -> int:
|
||||
"""
|
||||
Publish a message to a channel.
|
||||
|
||||
Args:
|
||||
channel: Channel name.
|
||||
message: Message to publish (string or dict for JSON serialization).
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
if isinstance(message, dict):
|
||||
message = json.dumps(message)
|
||||
result = await client.publish(channel, message)
|
||||
logger.debug(f"Published to channel '{channel}': {result} subscribers")
|
||||
return result
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis publish failed for channel '{channel}': {e}")
|
||||
return 0
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in publish for channel '{channel}': {e}")
|
||||
return 0
|
||||
|
||||
@asynccontextmanager
|
||||
async def subscribe(self, *channels: str) -> AsyncGenerator[PubSub, None]:
|
||||
"""
|
||||
Subscribe to one or more channels.
|
||||
|
||||
Usage:
|
||||
async with redis_client.subscribe("channel1", "channel2") as pubsub:
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
print(message["data"])
|
||||
|
||||
Args:
|
||||
channels: Channel names to subscribe to.
|
||||
|
||||
Yields:
|
||||
PubSub instance for receiving messages.
|
||||
"""
|
||||
client = await self._get_client()
|
||||
pubsub = client.pubsub()
|
||||
try:
|
||||
await pubsub.subscribe(*channels)
|
||||
logger.debug(f"Subscribed to channels: {channels}")
|
||||
yield pubsub
|
||||
finally:
|
||||
await pubsub.unsubscribe(*channels)
|
||||
await pubsub.close()
|
||||
logger.debug(f"Unsubscribed from channels: {channels}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def psubscribe(self, *patterns: str) -> AsyncGenerator[PubSub, None]:
|
||||
"""
|
||||
Subscribe to channels matching patterns.
|
||||
|
||||
Usage:
|
||||
async with redis_client.psubscribe("user:*") as pubsub:
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "pmessage":
|
||||
print(message["pattern"], message["channel"], message["data"])
|
||||
|
||||
Args:
|
||||
patterns: Glob-style patterns to subscribe to.
|
||||
|
||||
Yields:
|
||||
PubSub instance for receiving messages.
|
||||
"""
|
||||
client = await self._get_client()
|
||||
pubsub = client.pubsub()
|
||||
try:
|
||||
await pubsub.psubscribe(*patterns)
|
||||
logger.debug(f"Pattern subscribed: {patterns}")
|
||||
yield pubsub
|
||||
finally:
|
||||
await pubsub.punsubscribe(*patterns)
|
||||
await pubsub.close()
|
||||
logger.debug(f"Pattern unsubscribed: {patterns}")
|
||||
|
||||
# =========================================================================
|
||||
# Health & Connection Management
|
||||
# =========================================================================
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
Check if Redis connection is healthy.
|
||||
|
||||
Returns:
|
||||
True if connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.ping()
|
||||
return result is True
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis health check failed: {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis health check error: {e}")
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close Redis connections and cleanup resources.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
logger.debug("Redis client closed")
|
||||
|
||||
if self._pool:
|
||||
await self._pool.disconnect()
|
||||
self._pool = None
|
||||
logger.info("Redis connection pool closed")
|
||||
|
||||
async def get_pool_info(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get connection pool statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with pool information.
|
||||
"""
|
||||
if self._pool is None:
|
||||
return {"status": "not_initialized"}
|
||||
|
||||
return {
|
||||
"status": "active",
|
||||
"max_connections": POOL_MAX_CONNECTIONS,
|
||||
"url": self._url.split("@")[-1] if "@" in self._url else self._url,
|
||||
}
|
||||
|
||||
|
||||
# Global Redis client instance
|
||||
redis_client = RedisClient()
|
||||
|
||||
|
||||
# FastAPI dependency for Redis client
|
||||
async def get_redis() -> AsyncGenerator[RedisClient, None]:
|
||||
"""
|
||||
FastAPI dependency that provides the Redis client.
|
||||
|
||||
Usage:
|
||||
@router.get("/cached-data")
|
||||
async def get_data(redis: RedisClient = Depends(get_redis)):
|
||||
cached = await redis.cache_get("my-key")
|
||||
...
|
||||
"""
|
||||
yield redis_client
|
||||
|
||||
|
||||
# Health check function for use in /health endpoint
|
||||
async def check_redis_health() -> bool:
|
||||
"""
|
||||
Check if Redis connection is healthy.
|
||||
|
||||
Returns:
|
||||
True if connection is successful, False otherwise.
|
||||
"""
|
||||
return await redis_client.health_check()
|
||||
|
||||
|
||||
# Cleanup function for application shutdown
|
||||
async def close_redis() -> None:
|
||||
"""
|
||||
Close Redis connections.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
await redis_client.close()
|
||||
20
backend/app/crud/syndarix/__init__.py
Normal file
20
backend/app/crud/syndarix/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# app/crud/syndarix/__init__.py
|
||||
"""
|
||||
Syndarix CRUD operations.
|
||||
|
||||
This package contains CRUD operations for all Syndarix domain entities.
|
||||
"""
|
||||
|
||||
from .agent_instance import agent_instance
|
||||
from .agent_type import agent_type
|
||||
from .issue import issue
|
||||
from .project import project
|
||||
from .sprint import sprint
|
||||
|
||||
__all__ = [
|
||||
"agent_instance",
|
||||
"agent_type",
|
||||
"issue",
|
||||
"project",
|
||||
"sprint",
|
||||
]
|
||||
394
backend/app/crud/syndarix/agent_instance.py
Normal file
394
backend/app/crud/syndarix/agent_instance.py
Normal file
@@ -0,0 +1,394 @@
|
||||
# app/crud/syndarix/agent_instance.py
|
||||
"""Async CRUD operations for AgentInstance model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import AgentInstance, Issue
|
||||
from app.models.syndarix.enums import AgentStatus
|
||||
from app.schemas.syndarix import AgentInstanceCreate, AgentInstanceUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDAgentInstance(
|
||||
CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstanceUpdate]
|
||||
):
|
||||
"""Async CRUD operations for AgentInstance model."""
|
||||
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: AgentInstanceCreate
|
||||
) -> AgentInstance:
|
||||
"""Create a new agent instance with error handling."""
|
||||
try:
|
||||
db_obj = AgentInstance(
|
||||
agent_type_id=obj_in.agent_type_id,
|
||||
project_id=obj_in.project_id,
|
||||
name=obj_in.name,
|
||||
status=obj_in.status,
|
||||
current_task=obj_in.current_task,
|
||||
short_term_memory=obj_in.short_term_memory,
|
||||
long_term_memory_ref=obj_in.long_term_memory_ref,
|
||||
session_id=obj_in.session_id,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Integrity error creating agent instance: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Unexpected error creating agent instance: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
instance_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get an agent instance with full details including related entities.
|
||||
|
||||
Returns:
|
||||
Dictionary with instance and related entity details
|
||||
"""
|
||||
try:
|
||||
# Get instance with joined relationships
|
||||
result = await db.execute(
|
||||
select(AgentInstance)
|
||||
.options(
|
||||
joinedload(AgentInstance.agent_type),
|
||||
joinedload(AgentInstance.project),
|
||||
)
|
||||
.where(AgentInstance.id == instance_id)
|
||||
)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
# Get assigned issues count
|
||||
issues_count_result = await db.execute(
|
||||
select(func.count(Issue.id)).where(
|
||||
Issue.assigned_agent_id == instance_id
|
||||
)
|
||||
)
|
||||
assigned_issues_count = issues_count_result.scalar_one()
|
||||
|
||||
return {
|
||||
"instance": instance,
|
||||
"agent_type_name": instance.agent_type.name
|
||||
if instance.agent_type
|
||||
else None,
|
||||
"agent_type_slug": instance.agent_type.slug
|
||||
if instance.agent_type
|
||||
else None,
|
||||
"project_name": instance.project.name if instance.project else None,
|
||||
"project_slug": instance.project.slug if instance.project else None,
|
||||
"assigned_issues_count": assigned_issues_count,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting agent instance with details {instance_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
status: AgentStatus | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> tuple[list[AgentInstance], int]:
|
||||
"""Get agent instances for a specific project."""
|
||||
try:
|
||||
query = select(AgentInstance).where(AgentInstance.project_id == project_id)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(AgentInstance.status == status)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination
|
||||
query = query.order_by(AgentInstance.created_at.desc())
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
instances = list(result.scalars().all())
|
||||
|
||||
return instances, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting instances by project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_agent_type(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
agent_type_id: UUID,
|
||||
status: AgentStatus | None = None,
|
||||
) -> list[AgentInstance]:
|
||||
"""Get all instances of a specific agent type."""
|
||||
try:
|
||||
query = select(AgentInstance).where(
|
||||
AgentInstance.agent_type_id == agent_type_id
|
||||
)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(AgentInstance.status == status)
|
||||
|
||||
query = query.order_by(AgentInstance.created_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting instances by agent type {agent_type_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def update_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
instance_id: UUID,
|
||||
status: AgentStatus,
|
||||
current_task: str | None = None,
|
||||
) -> AgentInstance | None:
|
||||
"""Update the status of an agent instance."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentInstance).where(AgentInstance.id == instance_id)
|
||||
)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
instance.status = status
|
||||
instance.last_activity_at = datetime.now(UTC)
|
||||
if current_task is not None:
|
||||
instance.current_task = current_task
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(instance)
|
||||
return instance
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error updating instance status {instance_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def terminate(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
instance_id: UUID,
|
||||
) -> AgentInstance | None:
|
||||
"""Terminate an agent instance.
|
||||
|
||||
Also unassigns all issues from this agent to prevent orphaned assignments.
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentInstance).where(AgentInstance.id == instance_id)
|
||||
)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
# Unassign all issues from this agent before terminating
|
||||
await db.execute(
|
||||
update(Issue)
|
||||
.where(Issue.assigned_agent_id == instance_id)
|
||||
.values(assigned_agent_id=None)
|
||||
)
|
||||
|
||||
instance.status = AgentStatus.TERMINATED
|
||||
instance.terminated_at = datetime.now(UTC)
|
||||
instance.current_task = None
|
||||
instance.session_id = None
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(instance)
|
||||
return instance
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error terminating instance {instance_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def record_task_completion(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
instance_id: UUID,
|
||||
tokens_used: int,
|
||||
cost_incurred: Decimal,
|
||||
) -> AgentInstance | None:
|
||||
"""Record a completed task and update metrics.
|
||||
|
||||
Uses atomic SQL UPDATE to prevent lost updates under concurrent load.
|
||||
This avoids the read-modify-write race condition that occurs when
|
||||
multiple task completions happen simultaneously.
|
||||
"""
|
||||
try:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Use atomic SQL UPDATE to increment counters without race conditions
|
||||
# This is safe for concurrent updates - no read-modify-write pattern
|
||||
result = await db.execute(
|
||||
update(AgentInstance)
|
||||
.where(AgentInstance.id == instance_id)
|
||||
.values(
|
||||
tasks_completed=AgentInstance.tasks_completed + 1,
|
||||
tokens_used=AgentInstance.tokens_used + tokens_used,
|
||||
cost_incurred=AgentInstance.cost_incurred + cost_incurred,
|
||||
last_activity_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(AgentInstance)
|
||||
)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
await db.commit()
|
||||
return instance
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error recording task completion {instance_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_project_metrics(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> dict[str, Any]:
|
||||
"""Get aggregated metrics for all agents in a project."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(
|
||||
func.count(AgentInstance.id).label("total_instances"),
|
||||
func.count(AgentInstance.id)
|
||||
.filter(AgentInstance.status == AgentStatus.WORKING)
|
||||
.label("active_instances"),
|
||||
func.count(AgentInstance.id)
|
||||
.filter(AgentInstance.status == AgentStatus.IDLE)
|
||||
.label("idle_instances"),
|
||||
func.sum(AgentInstance.tasks_completed).label("total_tasks"),
|
||||
func.sum(AgentInstance.tokens_used).label("total_tokens"),
|
||||
func.sum(AgentInstance.cost_incurred).label("total_cost"),
|
||||
).where(AgentInstance.project_id == project_id)
|
||||
)
|
||||
row = result.one()
|
||||
|
||||
return {
|
||||
"total_instances": row.total_instances or 0,
|
||||
"active_instances": row.active_instances or 0,
|
||||
"idle_instances": row.idle_instances or 0,
|
||||
"total_tasks_completed": row.total_tasks or 0,
|
||||
"total_tokens_used": row.total_tokens or 0,
|
||||
"total_cost_incurred": row.total_cost or Decimal("0.0000"),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting project metrics {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def bulk_terminate_by_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> int:
|
||||
"""Terminate all active instances in a project.
|
||||
|
||||
Also unassigns all issues from these agents to prevent orphaned assignments.
|
||||
"""
|
||||
try:
|
||||
# First, unassign all issues from agents in this project
|
||||
# Get all agent IDs that will be terminated
|
||||
agents_to_terminate = await db.execute(
|
||||
select(AgentInstance.id).where(
|
||||
AgentInstance.project_id == project_id,
|
||||
AgentInstance.status != AgentStatus.TERMINATED,
|
||||
)
|
||||
)
|
||||
agent_ids = [row[0] for row in agents_to_terminate.fetchall()]
|
||||
|
||||
# Unassign issues from these agents
|
||||
if agent_ids:
|
||||
await db.execute(
|
||||
update(Issue)
|
||||
.where(Issue.assigned_agent_id.in_(agent_ids))
|
||||
.values(assigned_agent_id=None)
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(AgentInstance)
|
||||
.where(
|
||||
AgentInstance.project_id == project_id,
|
||||
AgentInstance.status != AgentStatus.TERMINATED,
|
||||
)
|
||||
.values(
|
||||
status=AgentStatus.TERMINATED,
|
||||
terminated_at=now,
|
||||
current_task=None,
|
||||
session_id=None,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
terminated_count = result.rowcount
|
||||
logger.info(
|
||||
f"Bulk terminated {terminated_count} instances in project {project_id}"
|
||||
)
|
||||
return terminated_count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error bulk terminating instances for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
agent_instance = CRUDAgentInstance(AgentInstance)
|
||||
265
backend/app/crud/syndarix/agent_type.py
Normal file
265
backend/app/crud/syndarix/agent_type.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# app/crud/syndarix/agent_type.py
|
||||
"""Async CRUD operations for AgentType model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import AgentInstance, AgentType
|
||||
from app.schemas.syndarix import AgentTypeCreate, AgentTypeUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
"""Async CRUD operations for AgentType model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> AgentType | None:
|
||||
"""Get agent type by slug."""
|
||||
try:
|
||||
result = await db.execute(select(AgentType).where(AgentType.slug == slug))
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent type by slug {slug}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: AgentTypeCreate) -> AgentType:
|
||||
"""Create a new agent type with error handling."""
|
||||
try:
|
||||
db_obj = AgentType(
|
||||
name=obj_in.name,
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
expertise=obj_in.expertise,
|
||||
personality_prompt=obj_in.personality_prompt,
|
||||
primary_model=obj_in.primary_model,
|
||||
fallback_models=obj_in.fallback_models,
|
||||
model_params=obj_in.model_params,
|
||||
mcp_servers=obj_in.mcp_servers,
|
||||
tool_permissions=obj_in.tool_permissions,
|
||||
is_active=obj_in.is_active,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise ValueError(f"Agent type with slug '{obj_in.slug}' already exists")
|
||||
logger.error(f"Integrity error creating agent type: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating agent type: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[AgentType], int]:
|
||||
"""
|
||||
Get multiple agent types with filtering, searching, and sorting.
|
||||
|
||||
Returns:
|
||||
Tuple of (agent types list, total count)
|
||||
"""
|
||||
try:
|
||||
query = select(AgentType)
|
||||
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.where(AgentType.is_active == is_active)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
AgentType.name.ilike(f"%{search}%"),
|
||||
AgentType.slug.ilike(f"%{search}%"),
|
||||
AgentType.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(AgentType, sort_by, AgentType.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
agent_types = list(result.scalars().all())
|
||||
|
||||
return agent_types, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent types with filters: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_with_instance_count(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
agent_type_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a single agent type with its instance count.
|
||||
|
||||
Returns:
|
||||
Dictionary with agent_type and instance_count
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentType).where(AgentType.id == agent_type_id)
|
||||
)
|
||||
agent_type = result.scalar_one_or_none()
|
||||
|
||||
if not agent_type:
|
||||
return None
|
||||
|
||||
# Get instance count
|
||||
count_result = await db.execute(
|
||||
select(func.count(AgentInstance.id)).where(
|
||||
AgentInstance.agent_type_id == agent_type_id
|
||||
)
|
||||
)
|
||||
instance_count = count_result.scalar_one()
|
||||
|
||||
return {
|
||||
"agent_type": agent_type,
|
||||
"instance_count": instance_count,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting agent type with count {agent_type_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_instance_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Get agent types with instance counts in optimized queries.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of dicts with agent_type and instance_count, total count)
|
||||
"""
|
||||
try:
|
||||
# Get filtered agent types
|
||||
agent_types, total = await self.get_multi_with_filters(
|
||||
db,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
search=search,
|
||||
)
|
||||
|
||||
if not agent_types:
|
||||
return [], 0
|
||||
|
||||
agent_type_ids = [at.id for at in agent_types]
|
||||
|
||||
# Get instance counts in bulk
|
||||
counts_result = await db.execute(
|
||||
select(
|
||||
AgentInstance.agent_type_id,
|
||||
func.count(AgentInstance.id).label("count"),
|
||||
)
|
||||
.where(AgentInstance.agent_type_id.in_(agent_type_ids))
|
||||
.group_by(AgentInstance.agent_type_id)
|
||||
)
|
||||
counts = {row.agent_type_id: row.count for row in counts_result}
|
||||
|
||||
# Combine results
|
||||
results = [
|
||||
{
|
||||
"agent_type": agent_type,
|
||||
"instance_count": counts.get(agent_type.id, 0),
|
||||
}
|
||||
for agent_type in agent_types
|
||||
]
|
||||
|
||||
return results, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent types with counts: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_by_expertise(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
expertise: str,
|
||||
is_active: bool = True,
|
||||
) -> list[AgentType]:
|
||||
"""Get agent types that have a specific expertise."""
|
||||
try:
|
||||
# Use PostgreSQL JSONB contains operator
|
||||
query = select(AgentType).where(
|
||||
AgentType.expertise.contains([expertise.lower()]),
|
||||
AgentType.is_active == is_active,
|
||||
)
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting agent types by expertise {expertise}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def deactivate(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
agent_type_id: UUID,
|
||||
) -> AgentType | None:
|
||||
"""Deactivate an agent type (soft delete)."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentType).where(AgentType.id == agent_type_id)
|
||||
)
|
||||
agent_type = result.scalar_one_or_none()
|
||||
|
||||
if not agent_type:
|
||||
return None
|
||||
|
||||
agent_type.is_active = False
|
||||
await db.commit()
|
||||
await db.refresh(agent_type)
|
||||
return agent_type
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error deactivating agent type {agent_type_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
agent_type = CRUDAgentType(AgentType)
|
||||
525
backend/app/crud/syndarix/issue.py
Normal file
525
backend/app/crud/syndarix/issue.py
Normal file
@@ -0,0 +1,525 @@
|
||||
# app/crud/syndarix/issue.py
|
||||
"""Async CRUD operations for Issue model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import AgentInstance, Issue
|
||||
from app.models.syndarix.enums import IssuePriority, IssueStatus, SyncStatus
|
||||
from app.schemas.syndarix import IssueCreate, IssueUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDIssue(CRUDBase[Issue, IssueCreate, IssueUpdate]):
|
||||
"""Async CRUD operations for Issue model."""
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: IssueCreate) -> Issue:
|
||||
"""Create a new issue with error handling."""
|
||||
try:
|
||||
db_obj = Issue(
|
||||
project_id=obj_in.project_id,
|
||||
title=obj_in.title,
|
||||
body=obj_in.body,
|
||||
status=obj_in.status,
|
||||
priority=obj_in.priority,
|
||||
labels=obj_in.labels,
|
||||
assigned_agent_id=obj_in.assigned_agent_id,
|
||||
human_assignee=obj_in.human_assignee,
|
||||
sprint_id=obj_in.sprint_id,
|
||||
story_points=obj_in.story_points,
|
||||
external_tracker_type=obj_in.external_tracker_type,
|
||||
external_issue_id=obj_in.external_issue_id,
|
||||
remote_url=obj_in.remote_url,
|
||||
external_issue_number=obj_in.external_issue_number,
|
||||
sync_status=SyncStatus.SYNCED,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Integrity error creating issue: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating issue: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get an issue with full details including related entity names.
|
||||
|
||||
Returns:
|
||||
Dictionary with issue and related entity details
|
||||
"""
|
||||
try:
|
||||
# Get issue with joined relationships
|
||||
result = await db.execute(
|
||||
select(Issue)
|
||||
.options(
|
||||
joinedload(Issue.project),
|
||||
joinedload(Issue.sprint),
|
||||
joinedload(Issue.assigned_agent).joinedload(
|
||||
AgentInstance.agent_type
|
||||
),
|
||||
)
|
||||
.where(Issue.id == issue_id)
|
||||
)
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
return {
|
||||
"issue": issue,
|
||||
"project_name": issue.project.name if issue.project else None,
|
||||
"project_slug": issue.project.slug if issue.project else None,
|
||||
"sprint_name": issue.sprint.name if issue.sprint else None,
|
||||
"assigned_agent_type_name": (
|
||||
issue.assigned_agent.agent_type.name
|
||||
if issue.assigned_agent and issue.assigned_agent.agent_type
|
||||
else None
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issue with details {issue_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
status: IssueStatus | None = None,
|
||||
priority: IssuePriority | None = None,
|
||||
sprint_id: UUID | None = None,
|
||||
assigned_agent_id: UUID | None = None,
|
||||
labels: list[str] | None = None,
|
||||
search: str | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[Issue], int]:
|
||||
"""Get issues for a specific project with filters."""
|
||||
try:
|
||||
query = select(Issue).where(Issue.project_id == project_id)
|
||||
|
||||
# Apply filters
|
||||
if status is not None:
|
||||
query = query.where(Issue.status == status)
|
||||
|
||||
if priority is not None:
|
||||
query = query.where(Issue.priority == priority)
|
||||
|
||||
if sprint_id is not None:
|
||||
query = query.where(Issue.sprint_id == sprint_id)
|
||||
|
||||
if assigned_agent_id is not None:
|
||||
query = query.where(Issue.assigned_agent_id == assigned_agent_id)
|
||||
|
||||
if labels:
|
||||
# Match any of the provided labels
|
||||
for label in labels:
|
||||
query = query.where(Issue.labels.contains([label.lower()]))
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Issue.title.ilike(f"%{search}%"),
|
||||
Issue.body.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(Issue, sort_by, Issue.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
issues = list(result.scalars().all())
|
||||
|
||||
return issues, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issues by project {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
status: IssueStatus | None = None,
|
||||
) -> list[Issue]:
|
||||
"""Get all issues in a sprint."""
|
||||
try:
|
||||
query = select(Issue).where(Issue.sprint_id == sprint_id)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(Issue.status == status)
|
||||
|
||||
query = query.order_by(Issue.priority.desc(), Issue.created_at.asc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issues by sprint {sprint_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def assign_to_agent(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
agent_id: UUID | None,
|
||||
) -> Issue | None:
|
||||
"""Assign an issue to an agent (or unassign if agent_id is None)."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.assigned_agent_id = agent_id
|
||||
issue.human_assignee = None # Clear human assignee when assigning to agent
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error assigning issue {issue_id} to agent {agent_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def assign_to_human(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
human_assignee: str | None,
|
||||
) -> Issue | None:
|
||||
"""Assign an issue to a human (or unassign if human_assignee is None)."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.human_assignee = human_assignee
|
||||
issue.assigned_agent_id = None # Clear agent when assigning to human
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error assigning issue {issue_id} to human {human_assignee}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def close_issue(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
) -> Issue | None:
|
||||
"""Close an issue by setting status and closed_at timestamp."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.status = IssueStatus.CLOSED
|
||||
issue.closed_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error closing issue {issue_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def reopen_issue(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
) -> Issue | None:
|
||||
"""Reopen a closed issue."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.status = IssueStatus.OPEN
|
||||
issue.closed_at = None
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error reopening issue {issue_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_sync_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
sync_status: SyncStatus,
|
||||
last_synced_at: datetime | None = None,
|
||||
external_updated_at: datetime | None = None,
|
||||
) -> Issue | None:
|
||||
"""Update the sync status of an issue."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.sync_status = sync_status
|
||||
if last_synced_at:
|
||||
issue.last_synced_at = last_synced_at
|
||||
if external_updated_at:
|
||||
issue.external_updated_at = external_updated_at
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error updating sync status for issue {issue_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_project_stats(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> dict[str, Any]:
|
||||
"""Get issue statistics for a project."""
|
||||
try:
|
||||
# Get counts by status
|
||||
status_counts = await db.execute(
|
||||
select(Issue.status, func.count(Issue.id).label("count"))
|
||||
.where(Issue.project_id == project_id)
|
||||
.group_by(Issue.status)
|
||||
)
|
||||
by_status = {row.status.value: row.count for row in status_counts}
|
||||
|
||||
# Get counts by priority
|
||||
priority_counts = await db.execute(
|
||||
select(Issue.priority, func.count(Issue.id).label("count"))
|
||||
.where(Issue.project_id == project_id)
|
||||
.group_by(Issue.priority)
|
||||
)
|
||||
by_priority = {row.priority.value: row.count for row in priority_counts}
|
||||
|
||||
# Get story points
|
||||
points_result = await db.execute(
|
||||
select(
|
||||
func.sum(Issue.story_points).label("total"),
|
||||
func.sum(Issue.story_points)
|
||||
.filter(Issue.status == IssueStatus.CLOSED)
|
||||
.label("completed"),
|
||||
).where(Issue.project_id == project_id)
|
||||
)
|
||||
points_row = points_result.one()
|
||||
|
||||
total_issues = sum(by_status.values())
|
||||
|
||||
return {
|
||||
"total": total_issues,
|
||||
"open": by_status.get("open", 0),
|
||||
"in_progress": by_status.get("in_progress", 0),
|
||||
"in_review": by_status.get("in_review", 0),
|
||||
"blocked": by_status.get("blocked", 0),
|
||||
"closed": by_status.get("closed", 0),
|
||||
"by_priority": by_priority,
|
||||
"total_story_points": points_row.total,
|
||||
"completed_story_points": points_row.completed,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issue stats for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_external_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
external_tracker_type: str,
|
||||
external_issue_id: str,
|
||||
) -> Issue | None:
|
||||
"""Get an issue by its external tracker ID."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Issue).where(
|
||||
Issue.external_tracker_type == external_tracker_type,
|
||||
Issue.external_issue_id == external_issue_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issue by external ID {external_tracker_type}:{external_issue_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_pending_sync(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[Issue]:
|
||||
"""Get issues that need to be synced with external tracker."""
|
||||
try:
|
||||
query = select(Issue).where(
|
||||
Issue.external_tracker_type.isnot(None),
|
||||
Issue.sync_status.in_([SyncStatus.PENDING, SyncStatus.ERROR]),
|
||||
)
|
||||
|
||||
if project_id:
|
||||
query = query.where(Issue.project_id == project_id)
|
||||
|
||||
query = query.order_by(Issue.updated_at.asc()).limit(limit)
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pending sync issues: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def remove_sprint_from_issues(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
) -> int:
|
||||
"""Remove sprint assignment from all issues in a sprint.
|
||||
|
||||
Used when deleting a sprint to clean up references.
|
||||
|
||||
Returns:
|
||||
Number of issues updated
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import update
|
||||
|
||||
result = await db.execute(
|
||||
update(Issue).where(Issue.sprint_id == sprint_id).values(sprint_id=None)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error removing sprint {sprint_id} from issues: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def unassign(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
) -> Issue | None:
|
||||
"""Remove agent assignment from an issue.
|
||||
|
||||
Returns:
|
||||
Updated issue or None if not found
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.assigned_agent_id = None
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error unassigning issue {issue_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def remove_from_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
) -> Issue | None:
|
||||
"""Remove an issue from its current sprint.
|
||||
|
||||
Returns:
|
||||
Updated issue or None if not found
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.sprint_id = None
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error removing issue {issue_id} from sprint: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
issue = CRUDIssue(Issue)
|
||||
362
backend/app/crud/syndarix/project.py
Normal file
362
backend/app/crud/syndarix/project.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# app/crud/syndarix/project.py
|
||||
"""Async CRUD operations for Project model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import AgentInstance, Issue, Project, Sprint
|
||||
from app.models.syndarix.enums import AgentStatus, ProjectStatus, SprintStatus
|
||||
from app.schemas.syndarix import ProjectCreate, ProjectUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDProject(CRUDBase[Project, ProjectCreate, ProjectUpdate]):
|
||||
"""Async CRUD operations for Project model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Project | None:
|
||||
"""Get project by slug."""
|
||||
try:
|
||||
result = await db.execute(select(Project).where(Project.slug == slug))
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting project by slug {slug}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: ProjectCreate) -> Project:
|
||||
"""Create a new project with error handling."""
|
||||
try:
|
||||
db_obj = Project(
|
||||
name=obj_in.name,
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
autonomy_level=obj_in.autonomy_level,
|
||||
status=obj_in.status,
|
||||
settings=obj_in.settings or {},
|
||||
owner_id=obj_in.owner_id,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise ValueError(f"Project with slug '{obj_in.slug}' already exists")
|
||||
logger.error(f"Integrity error creating project: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating project: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
status: ProjectStatus | None = None,
|
||||
owner_id: UUID | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[Project], int]:
|
||||
"""
|
||||
Get multiple projects with filtering, searching, and sorting.
|
||||
|
||||
Returns:
|
||||
Tuple of (projects list, total count)
|
||||
"""
|
||||
try:
|
||||
query = select(Project)
|
||||
|
||||
# Apply filters
|
||||
if status is not None:
|
||||
query = query.where(Project.status == status)
|
||||
|
||||
if owner_id is not None:
|
||||
query = query.where(Project.owner_id == owner_id)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Project.name.ilike(f"%{search}%"),
|
||||
Project.slug.ilike(f"%{search}%"),
|
||||
Project.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(Project, sort_by, Project.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
projects = list(result.scalars().all())
|
||||
|
||||
return projects, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting projects with filters: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_with_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a single project with agent and issue counts.
|
||||
|
||||
Returns:
|
||||
Dictionary with project, agent_count, issue_count, active_sprint_name
|
||||
"""
|
||||
try:
|
||||
# Get project
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
return None
|
||||
|
||||
# Get agent count
|
||||
agent_count_result = await db.execute(
|
||||
select(func.count(AgentInstance.id)).where(
|
||||
AgentInstance.project_id == project_id
|
||||
)
|
||||
)
|
||||
agent_count = agent_count_result.scalar_one()
|
||||
|
||||
# Get issue count
|
||||
issue_count_result = await db.execute(
|
||||
select(func.count(Issue.id)).where(Issue.project_id == project_id)
|
||||
)
|
||||
issue_count = issue_count_result.scalar_one()
|
||||
|
||||
# Get active sprint name
|
||||
active_sprint_result = await db.execute(
|
||||
select(Sprint.name).where(
|
||||
Sprint.project_id == project_id,
|
||||
Sprint.status == SprintStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
active_sprint_name = active_sprint_result.scalar_one_or_none()
|
||||
|
||||
return {
|
||||
"project": project,
|
||||
"agent_count": agent_count,
|
||||
"issue_count": issue_count,
|
||||
"active_sprint_name": active_sprint_name,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting project with counts {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
status: ProjectStatus | None = None,
|
||||
owner_id: UUID | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Get projects with agent/issue counts in optimized queries.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of dicts with project and counts, total count)
|
||||
"""
|
||||
try:
|
||||
# Get filtered projects
|
||||
projects, total = await self.get_multi_with_filters(
|
||||
db,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
status=status,
|
||||
owner_id=owner_id,
|
||||
search=search,
|
||||
)
|
||||
|
||||
if not projects:
|
||||
return [], 0
|
||||
|
||||
project_ids = [p.id for p in projects]
|
||||
|
||||
# Get agent counts in bulk
|
||||
agent_counts_result = await db.execute(
|
||||
select(
|
||||
AgentInstance.project_id,
|
||||
func.count(AgentInstance.id).label("count"),
|
||||
)
|
||||
.where(AgentInstance.project_id.in_(project_ids))
|
||||
.group_by(AgentInstance.project_id)
|
||||
)
|
||||
agent_counts = {row.project_id: row.count for row in agent_counts_result}
|
||||
|
||||
# Get issue counts in bulk
|
||||
issue_counts_result = await db.execute(
|
||||
select(
|
||||
Issue.project_id,
|
||||
func.count(Issue.id).label("count"),
|
||||
)
|
||||
.where(Issue.project_id.in_(project_ids))
|
||||
.group_by(Issue.project_id)
|
||||
)
|
||||
issue_counts = {row.project_id: row.count for row in issue_counts_result}
|
||||
|
||||
# Get active sprint names
|
||||
active_sprints_result = await db.execute(
|
||||
select(Sprint.project_id, Sprint.name).where(
|
||||
Sprint.project_id.in_(project_ids),
|
||||
Sprint.status == SprintStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
active_sprints = {row.project_id: row.name for row in active_sprints_result}
|
||||
|
||||
# Combine results
|
||||
results = [
|
||||
{
|
||||
"project": project,
|
||||
"agent_count": agent_counts.get(project.id, 0),
|
||||
"issue_count": issue_counts.get(project.id, 0),
|
||||
"active_sprint_name": active_sprints.get(project.id),
|
||||
}
|
||||
for project in projects
|
||||
]
|
||||
|
||||
return results, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting projects with counts: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_projects_by_owner(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
status: ProjectStatus | None = None,
|
||||
) -> list[Project]:
|
||||
"""Get all projects owned by a specific user."""
|
||||
try:
|
||||
query = select(Project).where(Project.owner_id == owner_id)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(Project.status == status)
|
||||
|
||||
query = query.order_by(Project.created_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting projects by owner {owner_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def archive_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> Project | None:
|
||||
"""Archive a project by setting status to ARCHIVED.
|
||||
|
||||
This also performs cascading cleanup:
|
||||
- Terminates all active agent instances
|
||||
- Cancels all planned/active sprints
|
||||
- Unassigns issues from terminated agents
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
return None
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# 1. Get all agent IDs that will be terminated
|
||||
agents_to_terminate = await db.execute(
|
||||
select(AgentInstance.id).where(
|
||||
AgentInstance.project_id == project_id,
|
||||
AgentInstance.status != AgentStatus.TERMINATED,
|
||||
)
|
||||
)
|
||||
agent_ids = [row[0] for row in agents_to_terminate.fetchall()]
|
||||
|
||||
# 2. Unassign issues from these agents to prevent orphaned assignments
|
||||
if agent_ids:
|
||||
await db.execute(
|
||||
update(Issue)
|
||||
.where(Issue.assigned_agent_id.in_(agent_ids))
|
||||
.values(assigned_agent_id=None)
|
||||
)
|
||||
|
||||
# 3. Terminate all active agents
|
||||
await db.execute(
|
||||
update(AgentInstance)
|
||||
.where(
|
||||
AgentInstance.project_id == project_id,
|
||||
AgentInstance.status != AgentStatus.TERMINATED,
|
||||
)
|
||||
.values(
|
||||
status=AgentStatus.TERMINATED,
|
||||
terminated_at=now,
|
||||
current_task=None,
|
||||
session_id=None,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
|
||||
# 4. Cancel all planned/active sprints
|
||||
await db.execute(
|
||||
update(Sprint)
|
||||
.where(
|
||||
Sprint.project_id == project_id,
|
||||
Sprint.status.in_([SprintStatus.PLANNED, SprintStatus.ACTIVE]),
|
||||
)
|
||||
.values(
|
||||
status=SprintStatus.CANCELLED,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
|
||||
# 5. Archive the project
|
||||
project.status = ProjectStatus.ARCHIVED
|
||||
await db.commit()
|
||||
await db.refresh(project)
|
||||
|
||||
logger.info(
|
||||
f"Archived project {project_id}: terminated agents={len(agent_ids)}"
|
||||
)
|
||||
|
||||
return project
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error archiving project {project_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
project = CRUDProject(Project)
|
||||
439
backend/app/crud/syndarix/sprint.py
Normal file
439
backend/app/crud/syndarix/sprint.py
Normal file
@@ -0,0 +1,439 @@
|
||||
# app/crud/syndarix/sprint.py
|
||||
"""Async CRUD operations for Sprint model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import Issue, Sprint
|
||||
from app.models.syndarix.enums import IssueStatus, SprintStatus
|
||||
from app.schemas.syndarix import SprintCreate, SprintUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDSprint(CRUDBase[Sprint, SprintCreate, SprintUpdate]):
|
||||
"""Async CRUD operations for Sprint model."""
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: SprintCreate) -> Sprint:
|
||||
"""Create a new sprint with error handling."""
|
||||
try:
|
||||
db_obj = Sprint(
|
||||
project_id=obj_in.project_id,
|
||||
name=obj_in.name,
|
||||
number=obj_in.number,
|
||||
goal=obj_in.goal,
|
||||
start_date=obj_in.start_date,
|
||||
end_date=obj_in.end_date,
|
||||
status=obj_in.status,
|
||||
planned_points=obj_in.planned_points,
|
||||
velocity=obj_in.velocity,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Integrity error creating sprint: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating sprint: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a sprint with full details including issue counts.
|
||||
|
||||
Returns:
|
||||
Dictionary with sprint and related details
|
||||
"""
|
||||
try:
|
||||
# Get sprint with joined project
|
||||
result = await db.execute(
|
||||
select(Sprint)
|
||||
.options(joinedload(Sprint.project))
|
||||
.where(Sprint.id == sprint_id)
|
||||
)
|
||||
sprint = result.scalar_one_or_none()
|
||||
|
||||
if not sprint:
|
||||
return None
|
||||
|
||||
# Get issue counts
|
||||
issue_counts = await db.execute(
|
||||
select(
|
||||
func.count(Issue.id).label("total"),
|
||||
func.count(Issue.id)
|
||||
.filter(Issue.status == IssueStatus.OPEN)
|
||||
.label("open"),
|
||||
func.count(Issue.id)
|
||||
.filter(Issue.status == IssueStatus.CLOSED)
|
||||
.label("completed"),
|
||||
).where(Issue.sprint_id == sprint_id)
|
||||
)
|
||||
counts = issue_counts.one()
|
||||
|
||||
return {
|
||||
"sprint": sprint,
|
||||
"project_name": sprint.project.name if sprint.project else None,
|
||||
"project_slug": sprint.project.slug if sprint.project else None,
|
||||
"issue_count": counts.total,
|
||||
"open_issues": counts.open,
|
||||
"completed_issues": counts.completed,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting sprint with details {sprint_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
status: SprintStatus | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> tuple[list[Sprint], int]:
|
||||
"""Get sprints for a specific project."""
|
||||
try:
|
||||
query = select(Sprint).where(Sprint.project_id == project_id)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(Sprint.status == status)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting (by number descending - newest first)
|
||||
query = query.order_by(Sprint.number.desc())
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
sprints = list(result.scalars().all())
|
||||
|
||||
return sprints, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting sprints by project {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_active_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> Sprint | None:
|
||||
"""Get the currently active sprint for a project."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Sprint).where(
|
||||
Sprint.project_id == project_id,
|
||||
Sprint.status == SprintStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting active sprint for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_next_sprint_number(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> int:
|
||||
"""Get the next sprint number for a project."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(func.max(Sprint.number)).where(Sprint.project_id == project_id)
|
||||
)
|
||||
max_number = result.scalar_one_or_none()
|
||||
return (max_number or 0) + 1
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting next sprint number for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def start_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
start_date: date | None = None,
|
||||
) -> Sprint | None:
|
||||
"""Start a planned sprint.
|
||||
|
||||
Uses row-level locking (SELECT FOR UPDATE) to prevent race conditions
|
||||
when multiple requests try to start sprints concurrently.
|
||||
"""
|
||||
try:
|
||||
# Lock the sprint row to prevent concurrent modifications
|
||||
result = await db.execute(
|
||||
select(Sprint).where(Sprint.id == sprint_id).with_for_update()
|
||||
)
|
||||
sprint = result.scalar_one_or_none()
|
||||
|
||||
if not sprint:
|
||||
return None
|
||||
|
||||
if sprint.status != SprintStatus.PLANNED:
|
||||
raise ValueError(
|
||||
f"Cannot start sprint with status {sprint.status.value}"
|
||||
)
|
||||
|
||||
# Check for existing active sprint with lock to prevent race condition
|
||||
# Lock all sprints for this project to ensure atomic check-and-update
|
||||
active_check = await db.execute(
|
||||
select(Sprint)
|
||||
.where(
|
||||
Sprint.project_id == sprint.project_id,
|
||||
Sprint.status == SprintStatus.ACTIVE,
|
||||
)
|
||||
.with_for_update()
|
||||
)
|
||||
active_sprint = active_check.scalar_one_or_none()
|
||||
if active_sprint:
|
||||
raise ValueError(
|
||||
f"Project already has an active sprint: {active_sprint.name}"
|
||||
)
|
||||
|
||||
sprint.status = SprintStatus.ACTIVE
|
||||
if start_date:
|
||||
sprint.start_date = start_date
|
||||
|
||||
# Calculate planned points from issues
|
||||
points_result = await db.execute(
|
||||
select(func.sum(Issue.story_points)).where(Issue.sprint_id == sprint_id)
|
||||
)
|
||||
sprint.planned_points = points_result.scalar_one_or_none() or 0
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(sprint)
|
||||
return sprint
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error starting sprint {sprint_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def complete_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
) -> Sprint | None:
|
||||
"""Complete an active sprint and calculate completed points.
|
||||
|
||||
Uses row-level locking (SELECT FOR UPDATE) to prevent race conditions
|
||||
when velocity is being calculated and other operations might modify issues.
|
||||
"""
|
||||
try:
|
||||
# Lock the sprint row to prevent concurrent modifications
|
||||
result = await db.execute(
|
||||
select(Sprint).where(Sprint.id == sprint_id).with_for_update()
|
||||
)
|
||||
sprint = result.scalar_one_or_none()
|
||||
|
||||
if not sprint:
|
||||
return None
|
||||
|
||||
if sprint.status != SprintStatus.ACTIVE:
|
||||
raise ValueError(
|
||||
f"Cannot complete sprint with status {sprint.status.value}"
|
||||
)
|
||||
|
||||
sprint.status = SprintStatus.COMPLETED
|
||||
|
||||
# Calculate velocity (completed points) from closed issues
|
||||
# Note: Issues are not locked, but sprint lock ensures this sprint's
|
||||
# completion is atomic and prevents concurrent completion attempts
|
||||
points_result = await db.execute(
|
||||
select(func.sum(Issue.story_points)).where(
|
||||
Issue.sprint_id == sprint_id,
|
||||
Issue.status == IssueStatus.CLOSED,
|
||||
)
|
||||
)
|
||||
sprint.velocity = points_result.scalar_one_or_none() or 0
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(sprint)
|
||||
return sprint
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error completing sprint {sprint_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def cancel_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
) -> Sprint | None:
|
||||
"""Cancel a sprint (only PLANNED or ACTIVE sprints can be cancelled).
|
||||
|
||||
Uses row-level locking to prevent race conditions with concurrent
|
||||
sprint status modifications.
|
||||
"""
|
||||
try:
|
||||
# Lock the sprint row to prevent concurrent modifications
|
||||
result = await db.execute(
|
||||
select(Sprint).where(Sprint.id == sprint_id).with_for_update()
|
||||
)
|
||||
sprint = result.scalar_one_or_none()
|
||||
|
||||
if not sprint:
|
||||
return None
|
||||
|
||||
if sprint.status not in [SprintStatus.PLANNED, SprintStatus.ACTIVE]:
|
||||
raise ValueError(
|
||||
f"Cannot cancel sprint with status {sprint.status.value}"
|
||||
)
|
||||
|
||||
sprint.status = SprintStatus.CANCELLED
|
||||
await db.commit()
|
||||
await db.refresh(sprint)
|
||||
return sprint
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error cancelling sprint {sprint_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_velocity(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
limit: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get velocity data for completed sprints."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Sprint)
|
||||
.where(
|
||||
Sprint.project_id == project_id,
|
||||
Sprint.status == SprintStatus.COMPLETED,
|
||||
)
|
||||
.order_by(Sprint.number.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
sprints = list(result.scalars().all())
|
||||
|
||||
velocity_data = []
|
||||
for sprint in reversed(sprints): # Return in chronological order
|
||||
velocity_ratio = None
|
||||
if sprint.planned_points and sprint.planned_points > 0:
|
||||
velocity_ratio = (sprint.velocity or 0) / sprint.planned_points
|
||||
velocity_data.append(
|
||||
{
|
||||
"sprint_number": sprint.number,
|
||||
"sprint_name": sprint.name,
|
||||
"planned_points": sprint.planned_points,
|
||||
"velocity": sprint.velocity,
|
||||
"velocity_ratio": velocity_ratio,
|
||||
}
|
||||
)
|
||||
|
||||
return velocity_data
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting velocity for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_sprints_with_issue_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get sprints with issue counts in optimized queries."""
|
||||
try:
|
||||
# Get sprints
|
||||
sprints, total = await self.get_by_project(
|
||||
db, project_id=project_id, skip=skip, limit=limit
|
||||
)
|
||||
|
||||
if not sprints:
|
||||
return [], 0
|
||||
|
||||
sprint_ids = [s.id for s in sprints]
|
||||
|
||||
# Get issue counts in bulk
|
||||
issue_counts = await db.execute(
|
||||
select(
|
||||
Issue.sprint_id,
|
||||
func.count(Issue.id).label("total"),
|
||||
func.count(Issue.id)
|
||||
.filter(Issue.status == IssueStatus.OPEN)
|
||||
.label("open"),
|
||||
func.count(Issue.id)
|
||||
.filter(Issue.status == IssueStatus.CLOSED)
|
||||
.label("completed"),
|
||||
)
|
||||
.where(Issue.sprint_id.in_(sprint_ids))
|
||||
.group_by(Issue.sprint_id)
|
||||
)
|
||||
counts_map = {
|
||||
row.sprint_id: {
|
||||
"issue_count": row.total,
|
||||
"open_issues": row.open,
|
||||
"completed_issues": row.completed,
|
||||
}
|
||||
for row in issue_counts
|
||||
}
|
||||
|
||||
# Combine results
|
||||
results = [
|
||||
{
|
||||
"sprint": sprint,
|
||||
**counts_map.get(
|
||||
sprint.id,
|
||||
{"issue_count": 0, "open_issues": 0, "completed_issues": 0},
|
||||
),
|
||||
}
|
||||
for sprint in sprints
|
||||
]
|
||||
|
||||
return results, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting sprints with counts for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
sprint = CRUDSprint(Sprint)
|
||||
@@ -18,13 +18,26 @@ from .oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
|
||||
from .oauth_state import OAuthState
|
||||
from .organization import Organization
|
||||
|
||||
# Syndarix domain models
|
||||
from .syndarix import (
|
||||
AgentInstance,
|
||||
AgentType,
|
||||
Issue,
|
||||
Project,
|
||||
Sprint,
|
||||
)
|
||||
|
||||
# Import models
|
||||
from .user import User
|
||||
from .user_organization import OrganizationRole, UserOrganization
|
||||
from .user_session import UserSession
|
||||
|
||||
__all__ = [
|
||||
# Syndarix models
|
||||
"AgentInstance",
|
||||
"AgentType",
|
||||
"Base",
|
||||
"Issue",
|
||||
"OAuthAccount",
|
||||
"OAuthAuthorizationCode",
|
||||
"OAuthClient",
|
||||
@@ -33,6 +46,8 @@ __all__ = [
|
||||
"OAuthState",
|
||||
"Organization",
|
||||
"OrganizationRole",
|
||||
"Project",
|
||||
"Sprint",
|
||||
"TimestampMixin",
|
||||
"UUIDMixin",
|
||||
"User",
|
||||
|
||||
47
backend/app/models/syndarix/__init__.py
Normal file
47
backend/app/models/syndarix/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# app/models/syndarix/__init__.py
|
||||
"""
|
||||
Syndarix domain models.
|
||||
|
||||
This package contains all the core entities for the Syndarix AI consulting platform:
|
||||
- Project: Client engagements with autonomy settings
|
||||
- AgentType: Templates for AI agent capabilities
|
||||
- AgentInstance: Spawned agents working on projects
|
||||
- Issue: Units of work with external tracker sync
|
||||
- Sprint: Time-boxed iterations for organizing work
|
||||
"""
|
||||
|
||||
from .agent_instance import AgentInstance
|
||||
from .agent_type import AgentType
|
||||
from .enums import (
|
||||
AgentStatus,
|
||||
AutonomyLevel,
|
||||
ClientMode,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
IssueType,
|
||||
ProjectComplexity,
|
||||
ProjectStatus,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
from .issue import Issue
|
||||
from .project import Project
|
||||
from .sprint import Sprint
|
||||
|
||||
__all__ = [
|
||||
"AgentInstance",
|
||||
"AgentStatus",
|
||||
"AgentType",
|
||||
"AutonomyLevel",
|
||||
"ClientMode",
|
||||
"Issue",
|
||||
"IssuePriority",
|
||||
"IssueStatus",
|
||||
"IssueType",
|
||||
"Project",
|
||||
"ProjectComplexity",
|
||||
"ProjectStatus",
|
||||
"Sprint",
|
||||
"SprintStatus",
|
||||
"SyncStatus",
|
||||
]
|
||||
111
backend/app/models/syndarix/agent_instance.py
Normal file
111
backend/app/models/syndarix/agent_instance.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# app/models/syndarix/agent_instance.py
|
||||
"""
|
||||
AgentInstance model for Syndarix AI consulting platform.
|
||||
|
||||
An AgentInstance is a spawned instance of an AgentType, assigned to a
|
||||
specific project to perform work.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import AgentStatus
|
||||
|
||||
|
||||
class AgentInstance(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
AgentInstance model representing a spawned agent working on a project.
|
||||
|
||||
Tracks:
|
||||
- Current status and task
|
||||
- Memory (short-term in DB, long-term reference to vector store)
|
||||
- Session information for MCP connections
|
||||
- Usage metrics (tasks completed, tokens, cost)
|
||||
"""
|
||||
|
||||
__tablename__ = "agent_instances"
|
||||
|
||||
# Foreign keys
|
||||
agent_type_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_types.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Agent instance name (e.g., "Dave", "Eve") for personality
|
||||
name = Column(String(100), nullable=False, index=True)
|
||||
|
||||
# Status tracking
|
||||
status: Column[AgentStatus] = Column(
|
||||
Enum(AgentStatus),
|
||||
default=AgentStatus.IDLE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Current task description (brief summary of what agent is doing)
|
||||
current_task = Column(Text, nullable=True)
|
||||
|
||||
# Short-term memory stored in database (conversation context, recent decisions)
|
||||
short_term_memory = Column(JSONB, default=dict, nullable=False)
|
||||
|
||||
# Reference to long-term memory in vector store (e.g., "project-123/agent-456")
|
||||
long_term_memory_ref = Column(String(500), nullable=True)
|
||||
|
||||
# Session ID for active MCP connections
|
||||
session_id = Column(String(255), nullable=True, index=True)
|
||||
|
||||
# Activity tracking
|
||||
last_activity_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
terminated_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Usage metrics
|
||||
tasks_completed = Column(Integer, default=0, nullable=False)
|
||||
tokens_used = Column(BigInteger, default=0, nullable=False)
|
||||
cost_incurred = Column(Numeric(precision=10, scale=4), default=0, nullable=False)
|
||||
|
||||
# Relationships
|
||||
agent_type = relationship("AgentType", back_populates="instances")
|
||||
project = relationship("Project", back_populates="agent_instances")
|
||||
assigned_issues = relationship(
|
||||
"Issue",
|
||||
back_populates="assigned_agent",
|
||||
foreign_keys="Issue.assigned_agent_id",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_agent_instances_project_status", "project_id", "status"),
|
||||
Index("ix_agent_instances_type_status", "agent_type_id", "status"),
|
||||
Index("ix_agent_instances_project_type", "project_id", "agent_type_id"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<AgentInstance {self.name} ({self.id}) type={self.agent_type_id} "
|
||||
f"project={self.project_id} status={self.status.value}>"
|
||||
)
|
||||
72
backend/app/models/syndarix/agent_type.py
Normal file
72
backend/app/models/syndarix/agent_type.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# app/models/syndarix/agent_type.py
|
||||
"""
|
||||
AgentType model for Syndarix AI consulting platform.
|
||||
|
||||
An AgentType is a template that defines the capabilities, personality,
|
||||
and model configuration for agent instances.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Boolean, Column, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class AgentType(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
AgentType model representing a template for agent instances.
|
||||
|
||||
Each agent type defines:
|
||||
- Expertise areas and personality prompt
|
||||
- Model configuration (primary, fallback, parameters)
|
||||
- MCP server access and tool permissions
|
||||
|
||||
Examples: ProductOwner, Architect, BackendEngineer, QAEngineer
|
||||
"""
|
||||
|
||||
__tablename__ = "agent_types"
|
||||
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
slug = Column(String(255), unique=True, nullable=False, index=True)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Areas of expertise for this agent type (e.g., ["python", "fastapi", "databases"])
|
||||
expertise = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# System prompt defining the agent's personality and behavior
|
||||
personality_prompt = Column(Text, nullable=False)
|
||||
|
||||
# Primary LLM model to use (e.g., "claude-opus-4-5-20251101")
|
||||
primary_model = Column(String(100), nullable=False)
|
||||
|
||||
# Fallback models in order of preference
|
||||
fallback_models = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Model parameters (temperature, max_tokens, etc.)
|
||||
model_params = Column(JSONB, default=dict, nullable=False)
|
||||
|
||||
# List of MCP servers this agent can connect to
|
||||
mcp_servers = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Tool permissions configuration
|
||||
# Structure: {"allowed": ["*"], "denied": [], "require_approval": ["gitea:create_pr"]}
|
||||
tool_permissions = Column(JSONB, default=dict, nullable=False)
|
||||
|
||||
# Whether this agent type is available for new instances
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Relationships
|
||||
instances = relationship(
|
||||
"AgentInstance",
|
||||
back_populates="agent_type",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_agent_types_slug_active", "slug", "is_active"),
|
||||
Index("ix_agent_types_name_active", "name", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AgentType {self.name} ({self.slug}) active={self.is_active}>"
|
||||
169
backend/app/models/syndarix/enums.py
Normal file
169
backend/app/models/syndarix/enums.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# app/models/syndarix/enums.py
|
||||
"""
|
||||
Enums for Syndarix domain models.
|
||||
|
||||
These enums represent the core state machines and categorizations
|
||||
used throughout the Syndarix AI consulting platform.
|
||||
"""
|
||||
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
|
||||
class AutonomyLevel(str, PyEnum):
|
||||
"""
|
||||
Defines how much control the human has over agent actions.
|
||||
|
||||
FULL_CONTROL: Human must approve every agent action
|
||||
MILESTONE: Human approves at sprint boundaries and major decisions
|
||||
AUTONOMOUS: Agents work independently, only escalating critical issues
|
||||
"""
|
||||
|
||||
FULL_CONTROL = "full_control"
|
||||
MILESTONE = "milestone"
|
||||
AUTONOMOUS = "autonomous"
|
||||
|
||||
|
||||
class ProjectComplexity(str, PyEnum):
|
||||
"""
|
||||
Project complexity level for estimation and planning.
|
||||
|
||||
SCRIPT: Simple automation or script-level work
|
||||
SIMPLE: Straightforward feature or fix
|
||||
MEDIUM: Standard complexity with some architectural considerations
|
||||
COMPLEX: Large-scale feature requiring significant design work
|
||||
"""
|
||||
|
||||
SCRIPT = "script"
|
||||
SIMPLE = "simple"
|
||||
MEDIUM = "medium"
|
||||
COMPLEX = "complex"
|
||||
|
||||
|
||||
class ClientMode(str, PyEnum):
|
||||
"""
|
||||
How the client prefers to interact with agents.
|
||||
|
||||
TECHNICAL: Client is technical and prefers detailed updates
|
||||
AUTO: Agents automatically determine communication level
|
||||
"""
|
||||
|
||||
TECHNICAL = "technical"
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class ProjectStatus(str, PyEnum):
|
||||
"""
|
||||
Project lifecycle status.
|
||||
|
||||
ACTIVE: Project is actively being worked on
|
||||
PAUSED: Project is temporarily on hold
|
||||
COMPLETED: Project has been delivered successfully
|
||||
ARCHIVED: Project is no longer accessible for work
|
||||
"""
|
||||
|
||||
ACTIVE = "active"
|
||||
PAUSED = "paused"
|
||||
COMPLETED = "completed"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class AgentStatus(str, PyEnum):
|
||||
"""
|
||||
Current operational status of an agent instance.
|
||||
|
||||
IDLE: Agent is available but not currently working
|
||||
WORKING: Agent is actively processing a task
|
||||
WAITING: Agent is waiting for external input or approval
|
||||
PAUSED: Agent has been manually paused
|
||||
TERMINATED: Agent instance has been shut down
|
||||
"""
|
||||
|
||||
IDLE = "idle"
|
||||
WORKING = "working"
|
||||
WAITING = "waiting"
|
||||
PAUSED = "paused"
|
||||
TERMINATED = "terminated"
|
||||
|
||||
|
||||
class IssueType(str, PyEnum):
|
||||
"""
|
||||
Issue type for categorization and hierarchy.
|
||||
|
||||
EPIC: Large feature or body of work containing stories
|
||||
STORY: User-facing feature or requirement
|
||||
TASK: Technical work item
|
||||
BUG: Defect or issue to be fixed
|
||||
"""
|
||||
|
||||
EPIC = "epic"
|
||||
STORY = "story"
|
||||
TASK = "task"
|
||||
BUG = "bug"
|
||||
|
||||
|
||||
class IssueStatus(str, PyEnum):
|
||||
"""
|
||||
Issue workflow status.
|
||||
|
||||
OPEN: Issue is ready to be worked on
|
||||
IN_PROGRESS: Agent or human is actively working on the issue
|
||||
IN_REVIEW: Work is complete, awaiting review
|
||||
BLOCKED: Issue cannot proceed due to dependencies or blockers
|
||||
CLOSED: Issue has been completed or cancelled
|
||||
"""
|
||||
|
||||
OPEN = "open"
|
||||
IN_PROGRESS = "in_progress"
|
||||
IN_REVIEW = "in_review"
|
||||
BLOCKED = "blocked"
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class IssuePriority(str, PyEnum):
|
||||
"""
|
||||
Issue priority levels.
|
||||
|
||||
LOW: Nice to have, can be deferred
|
||||
MEDIUM: Standard priority, should be done
|
||||
HIGH: Important, should be prioritized
|
||||
CRITICAL: Must be done immediately, blocking other work
|
||||
"""
|
||||
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class SyncStatus(str, PyEnum):
|
||||
"""
|
||||
External issue tracker synchronization status.
|
||||
|
||||
SYNCED: Local and remote are in sync
|
||||
PENDING: Local changes waiting to be pushed
|
||||
CONFLICT: Merge conflict between local and remote
|
||||
ERROR: Synchronization failed due to an error
|
||||
"""
|
||||
|
||||
SYNCED = "synced"
|
||||
PENDING = "pending"
|
||||
CONFLICT = "conflict"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class SprintStatus(str, PyEnum):
|
||||
"""
|
||||
Sprint lifecycle status.
|
||||
|
||||
PLANNED: Sprint has been created but not started
|
||||
ACTIVE: Sprint is currently in progress
|
||||
IN_REVIEW: Sprint work is done, demo/review pending
|
||||
COMPLETED: Sprint has been finished successfully
|
||||
CANCELLED: Sprint was cancelled before completion
|
||||
"""
|
||||
|
||||
PLANNED = "planned"
|
||||
ACTIVE = "active"
|
||||
IN_REVIEW = "in_review"
|
||||
COMPLETED = "completed"
|
||||
CANCELLED = "cancelled"
|
||||
176
backend/app/models/syndarix/issue.py
Normal file
176
backend/app/models/syndarix/issue.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# app/models/syndarix/issue.py
|
||||
"""
|
||||
Issue model for Syndarix AI consulting platform.
|
||||
|
||||
An Issue represents a unit of work that can be assigned to agents or humans,
|
||||
with optional synchronization to external issue trackers (Gitea, GitHub, GitLab).
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Date,
|
||||
DateTime,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import IssuePriority, IssueStatus, IssueType, SyncStatus
|
||||
|
||||
|
||||
class Issue(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Issue model representing a unit of work in a project.
|
||||
|
||||
Features:
|
||||
- Standard issue fields (title, body, status, priority)
|
||||
- Assignment to agent instances or human assignees
|
||||
- Sprint association for backlog management
|
||||
- External tracker synchronization (Gitea, GitHub, GitLab)
|
||||
"""
|
||||
|
||||
__tablename__ = "issues"
|
||||
|
||||
# Foreign key to project
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Parent issue for hierarchy (Epic -> Story -> Task)
|
||||
parent_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("issues.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Issue type (Epic, Story, Task, Bug)
|
||||
type: Column[IssueType] = Column(
|
||||
Enum(IssueType),
|
||||
default=IssueType.TASK,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Reporter (who created this issue - can be user or agent)
|
||||
reporter_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
nullable=True, # System-generated issues may have no reporter
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Issue content
|
||||
title = Column(String(500), nullable=False)
|
||||
body = Column(Text, nullable=False, default="")
|
||||
|
||||
# Status and priority
|
||||
status: Column[IssueStatus] = Column(
|
||||
Enum(IssueStatus),
|
||||
default=IssueStatus.OPEN,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
priority: Column[IssuePriority] = Column(
|
||||
Enum(IssuePriority),
|
||||
default=IssuePriority.MEDIUM,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Labels for categorization (e.g., ["bug", "frontend", "urgent"])
|
||||
labels = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Assignment - either to an agent or a human (mutually exclusive)
|
||||
assigned_agent_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_instances.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Human assignee (username or email, not a FK to allow external users)
|
||||
human_assignee = Column(String(255), nullable=True, index=True)
|
||||
|
||||
# Sprint association
|
||||
sprint_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("sprints.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Story points for estimation
|
||||
story_points = Column(Integer, nullable=True)
|
||||
|
||||
# Due date for the issue
|
||||
due_date = Column(Date, nullable=True, index=True)
|
||||
|
||||
# External tracker integration
|
||||
external_tracker_type = Column(
|
||||
String(50),
|
||||
nullable=True,
|
||||
index=True,
|
||||
) # 'gitea', 'github', 'gitlab'
|
||||
|
||||
external_issue_id = Column(String(255), nullable=True) # External system's ID
|
||||
remote_url = Column(String(1000), nullable=True) # Link to external issue
|
||||
external_issue_number = Column(Integer, nullable=True) # Issue number (e.g., #123)
|
||||
|
||||
# Sync status with external tracker
|
||||
sync_status: Column[SyncStatus] = Column(
|
||||
Enum(SyncStatus),
|
||||
default=SyncStatus.SYNCED,
|
||||
nullable=False,
|
||||
# Note: Index defined in __table_args__ as ix_issues_sync_status
|
||||
)
|
||||
|
||||
last_synced_at = Column(DateTime(timezone=True), nullable=True)
|
||||
external_updated_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Lifecycle timestamp
|
||||
closed_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="issues")
|
||||
assigned_agent = relationship(
|
||||
"AgentInstance",
|
||||
back_populates="assigned_issues",
|
||||
foreign_keys=[assigned_agent_id],
|
||||
)
|
||||
sprint = relationship("Sprint", back_populates="issues")
|
||||
parent = relationship("Issue", remote_side="Issue.id", backref="children")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_issues_project_status", "project_id", "status"),
|
||||
Index("ix_issues_project_priority", "project_id", "priority"),
|
||||
Index("ix_issues_project_sprint", "project_id", "sprint_id"),
|
||||
Index(
|
||||
"ix_issues_external_tracker_id",
|
||||
"external_tracker_type",
|
||||
"external_issue_id",
|
||||
),
|
||||
Index("ix_issues_sync_status", "sync_status"),
|
||||
Index("ix_issues_project_agent", "project_id", "assigned_agent_id"),
|
||||
Index("ix_issues_project_type", "project_id", "type"),
|
||||
Index("ix_issues_project_status_priority", "project_id", "status", "priority"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Issue {self.id} title='{self.title[:30]}...' "
|
||||
f"status={self.status.value} priority={self.priority.value}>"
|
||||
)
|
||||
103
backend/app/models/syndarix/project.py
Normal file
103
backend/app/models/syndarix/project.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# app/models/syndarix/project.py
|
||||
"""
|
||||
Project model for Syndarix AI consulting platform.
|
||||
|
||||
A Project represents a client engagement where AI agents collaborate
|
||||
to deliver software solutions.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Enum, ForeignKey, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import AutonomyLevel, ClientMode, ProjectComplexity, ProjectStatus
|
||||
|
||||
|
||||
class Project(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Project model representing a client engagement.
|
||||
|
||||
A project contains:
|
||||
- Configuration for how autonomous agents should operate
|
||||
- Settings for MCP server integrations
|
||||
- Relationship to assigned agents, issues, and sprints
|
||||
"""
|
||||
|
||||
__tablename__ = "projects"
|
||||
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
slug = Column(String(255), unique=True, nullable=False, index=True)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
autonomy_level: Column[AutonomyLevel] = Column(
|
||||
Enum(AutonomyLevel),
|
||||
default=AutonomyLevel.MILESTONE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
status: Column[ProjectStatus] = Column(
|
||||
Enum(ProjectStatus),
|
||||
default=ProjectStatus.ACTIVE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
complexity: Column[ProjectComplexity] = Column(
|
||||
Enum(ProjectComplexity),
|
||||
default=ProjectComplexity.MEDIUM,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
client_mode: Column[ClientMode] = Column(
|
||||
Enum(ClientMode),
|
||||
default=ClientMode.AUTO,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# JSON field for flexible project configuration
|
||||
# Can include: mcp_servers, webhook_urls, notification_settings, etc.
|
||||
settings = Column(JSONB, default=dict, nullable=False)
|
||||
|
||||
# Foreign key to the User who owns this project
|
||||
owner_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
owner = relationship("User", foreign_keys=[owner_id])
|
||||
agent_instances = relationship(
|
||||
"AgentInstance",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
issues = relationship(
|
||||
"Issue",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
sprints = relationship(
|
||||
"Sprint",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_projects_slug_status", "slug", "status"),
|
||||
Index("ix_projects_owner_status", "owner_id", "status"),
|
||||
Index("ix_projects_autonomy_status", "autonomy_level", "status"),
|
||||
Index("ix_projects_complexity_status", "complexity", "status"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Project {self.name} ({self.slug}) status={self.status.value}>"
|
||||
86
backend/app/models/syndarix/sprint.py
Normal file
86
backend/app/models/syndarix/sprint.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# app/models/syndarix/sprint.py
|
||||
"""
|
||||
Sprint model for Syndarix AI consulting platform.
|
||||
|
||||
A Sprint represents a time-boxed iteration for organizing and delivering work.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Date,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import SprintStatus
|
||||
|
||||
|
||||
class Sprint(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Sprint model representing a time-boxed iteration.
|
||||
|
||||
Tracks:
|
||||
- Sprint metadata (name, number, goal)
|
||||
- Date range (start/end)
|
||||
- Progress metrics (planned vs completed points)
|
||||
"""
|
||||
|
||||
__tablename__ = "sprints"
|
||||
|
||||
# Foreign key to project
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Sprint identification
|
||||
name = Column(String(255), nullable=False)
|
||||
number = Column(Integer, nullable=False) # Sprint number within project
|
||||
|
||||
# Sprint goal (what we aim to achieve)
|
||||
goal = Column(Text, nullable=True)
|
||||
|
||||
# Date range
|
||||
start_date = Column(Date, nullable=False, index=True)
|
||||
end_date = Column(Date, nullable=False, index=True)
|
||||
|
||||
# Status
|
||||
status: Column[SprintStatus] = Column(
|
||||
Enum(SprintStatus),
|
||||
default=SprintStatus.PLANNED,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Progress metrics
|
||||
planned_points = Column(Integer, nullable=True) # Sum of story points at start
|
||||
velocity = Column(Integer, nullable=True) # Sum of completed story points
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="sprints")
|
||||
issues = relationship("Issue", back_populates="sprint")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_sprints_project_status", "project_id", "status"),
|
||||
Index("ix_sprints_project_number", "project_id", "number"),
|
||||
Index("ix_sprints_date_range", "start_date", "end_date"),
|
||||
# Ensure sprint numbers are unique within a project
|
||||
UniqueConstraint("project_id", "number", name="uq_sprint_project_number"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Sprint {self.name} (#{self.number}) "
|
||||
f"project={self.project_id} status={self.status.value}>"
|
||||
)
|
||||
273
backend/app/schemas/events.py
Normal file
273
backend/app/schemas/events.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Event schemas for the Syndarix EventBus (Redis Pub/Sub).
|
||||
|
||||
This module defines event types and payload schemas for real-time communication
|
||||
between services, agents, and the frontend.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EventType(str, Enum):
|
||||
"""
|
||||
Event types for the EventBus.
|
||||
|
||||
Naming convention: {domain}.{action}
|
||||
"""
|
||||
|
||||
# Agent Events
|
||||
AGENT_SPAWNED = "agent.spawned"
|
||||
AGENT_STATUS_CHANGED = "agent.status_changed"
|
||||
AGENT_MESSAGE = "agent.message"
|
||||
AGENT_TERMINATED = "agent.terminated"
|
||||
|
||||
# Issue Events
|
||||
ISSUE_CREATED = "issue.created"
|
||||
ISSUE_UPDATED = "issue.updated"
|
||||
ISSUE_ASSIGNED = "issue.assigned"
|
||||
ISSUE_CLOSED = "issue.closed"
|
||||
|
||||
# Sprint Events
|
||||
SPRINT_STARTED = "sprint.started"
|
||||
SPRINT_COMPLETED = "sprint.completed"
|
||||
|
||||
# Approval Events
|
||||
APPROVAL_REQUESTED = "approval.requested"
|
||||
APPROVAL_GRANTED = "approval.granted"
|
||||
APPROVAL_DENIED = "approval.denied"
|
||||
|
||||
# Project Events
|
||||
PROJECT_CREATED = "project.created"
|
||||
PROJECT_UPDATED = "project.updated"
|
||||
PROJECT_ARCHIVED = "project.archived"
|
||||
|
||||
# Workflow Events
|
||||
WORKFLOW_STARTED = "workflow.started"
|
||||
WORKFLOW_STEP_COMPLETED = "workflow.step_completed"
|
||||
WORKFLOW_COMPLETED = "workflow.completed"
|
||||
WORKFLOW_FAILED = "workflow.failed"
|
||||
|
||||
|
||||
ActorType = Literal["agent", "user", "system"]
|
||||
|
||||
|
||||
class Event(BaseModel):
|
||||
"""
|
||||
Base event schema for the EventBus.
|
||||
|
||||
All events published to the EventBus must conform to this schema.
|
||||
"""
|
||||
|
||||
id: str = Field(
|
||||
...,
|
||||
description="Unique event identifier (UUID string)",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440000"],
|
||||
)
|
||||
type: EventType = Field(
|
||||
...,
|
||||
description="Event type enum value",
|
||||
examples=[EventType.AGENT_MESSAGE],
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
...,
|
||||
description="When the event occurred (UTC)",
|
||||
examples=["2024-01-15T10:30:00Z"],
|
||||
)
|
||||
project_id: UUID = Field(
|
||||
...,
|
||||
description="Project this event belongs to",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440001"],
|
||||
)
|
||||
actor_id: UUID | None = Field(
|
||||
default=None,
|
||||
description="ID of the agent or user who triggered the event",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440002"],
|
||||
)
|
||||
actor_type: ActorType = Field(
|
||||
...,
|
||||
description="Type of actor: 'agent', 'user', or 'system'",
|
||||
examples=["agent"],
|
||||
)
|
||||
payload: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Event-specific payload data",
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"type": "agent.message",
|
||||
"timestamp": "2024-01-15T10:30:00Z",
|
||||
"project_id": "550e8400-e29b-41d4-a716-446655440001",
|
||||
"actor_id": "550e8400-e29b-41d4-a716-446655440002",
|
||||
"actor_type": "agent",
|
||||
"payload": {"message": "Processing task...", "progress": 50},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Specific payload schemas for type safety
|
||||
|
||||
|
||||
class AgentSpawnedPayload(BaseModel):
|
||||
"""Payload for AGENT_SPAWNED events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the spawned agent instance")
|
||||
agent_type_id: UUID = Field(..., description="ID of the agent type")
|
||||
agent_name: str = Field(..., description="Human-readable name of the agent")
|
||||
role: str = Field(..., description="Agent role (e.g., 'product_owner', 'engineer')")
|
||||
|
||||
|
||||
class AgentStatusChangedPayload(BaseModel):
|
||||
"""Payload for AGENT_STATUS_CHANGED events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||
previous_status: str = Field(..., description="Previous status")
|
||||
new_status: str = Field(..., description="New status")
|
||||
reason: str | None = Field(default=None, description="Reason for status change")
|
||||
|
||||
|
||||
class AgentMessagePayload(BaseModel):
|
||||
"""Payload for AGENT_MESSAGE events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||
message: str = Field(..., description="Message content")
|
||||
message_type: str = Field(
|
||||
default="info",
|
||||
description="Message type: 'info', 'warning', 'error', 'debug'",
|
||||
)
|
||||
metadata: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Additional metadata (e.g., token usage, model info)",
|
||||
)
|
||||
|
||||
|
||||
class AgentTerminatedPayload(BaseModel):
|
||||
"""Payload for AGENT_TERMINATED events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||
termination_reason: str = Field(..., description="Reason for termination")
|
||||
final_status: str = Field(..., description="Final status at termination")
|
||||
|
||||
|
||||
class IssueCreatedPayload(BaseModel):
|
||||
"""Payload for ISSUE_CREATED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
title: str = Field(..., description="Issue title")
|
||||
priority: str | None = Field(default=None, description="Issue priority")
|
||||
labels: list[str] = Field(default_factory=list, description="Issue labels")
|
||||
|
||||
|
||||
class IssueUpdatedPayload(BaseModel):
|
||||
"""Payload for ISSUE_UPDATED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
changes: dict = Field(..., description="Dictionary of field changes")
|
||||
|
||||
|
||||
class IssueAssignedPayload(BaseModel):
|
||||
"""Payload for ISSUE_ASSIGNED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
assignee_id: UUID | None = Field(
|
||||
default=None, description="Agent or user assigned to"
|
||||
)
|
||||
assignee_name: str | None = Field(default=None, description="Assignee name")
|
||||
|
||||
|
||||
class IssueClosedPayload(BaseModel):
|
||||
"""Payload for ISSUE_CLOSED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
resolution: str = Field(..., description="Resolution status")
|
||||
|
||||
|
||||
class SprintStartedPayload(BaseModel):
|
||||
"""Payload for SPRINT_STARTED events."""
|
||||
|
||||
sprint_id: UUID = Field(..., description="Sprint ID")
|
||||
sprint_name: str = Field(..., description="Sprint name")
|
||||
goal: str | None = Field(default=None, description="Sprint goal")
|
||||
issue_count: int = Field(default=0, description="Number of issues in sprint")
|
||||
|
||||
|
||||
class SprintCompletedPayload(BaseModel):
|
||||
"""Payload for SPRINT_COMPLETED events."""
|
||||
|
||||
sprint_id: UUID = Field(..., description="Sprint ID")
|
||||
sprint_name: str = Field(..., description="Sprint name")
|
||||
completed_issues: int = Field(default=0, description="Number of completed issues")
|
||||
incomplete_issues: int = Field(default=0, description="Number of incomplete issues")
|
||||
|
||||
|
||||
class ApprovalRequestedPayload(BaseModel):
|
||||
"""Payload for APPROVAL_REQUESTED events."""
|
||||
|
||||
approval_id: UUID = Field(..., description="Approval request ID")
|
||||
approval_type: str = Field(..., description="Type of approval needed")
|
||||
description: str = Field(..., description="Description of what needs approval")
|
||||
requested_by: UUID | None = Field(
|
||||
default=None, description="Agent/user requesting approval"
|
||||
)
|
||||
timeout_minutes: int | None = Field(
|
||||
default=None, description="Minutes before auto-escalation"
|
||||
)
|
||||
|
||||
|
||||
class ApprovalGrantedPayload(BaseModel):
|
||||
"""Payload for APPROVAL_GRANTED events."""
|
||||
|
||||
approval_id: UUID = Field(..., description="Approval request ID")
|
||||
approved_by: UUID = Field(..., description="User who granted approval")
|
||||
comments: str | None = Field(default=None, description="Approval comments")
|
||||
|
||||
|
||||
class ApprovalDeniedPayload(BaseModel):
|
||||
"""Payload for APPROVAL_DENIED events."""
|
||||
|
||||
approval_id: UUID = Field(..., description="Approval request ID")
|
||||
denied_by: UUID = Field(..., description="User who denied approval")
|
||||
reason: str = Field(..., description="Reason for denial")
|
||||
|
||||
|
||||
class WorkflowStartedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_STARTED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
workflow_type: str = Field(..., description="Type of workflow")
|
||||
total_steps: int = Field(default=0, description="Total number of steps")
|
||||
|
||||
|
||||
class WorkflowStepCompletedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_STEP_COMPLETED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
step_name: str = Field(..., description="Name of completed step")
|
||||
step_number: int = Field(..., description="Step number (1-indexed)")
|
||||
total_steps: int = Field(..., description="Total number of steps")
|
||||
result: dict = Field(default_factory=dict, description="Step result data")
|
||||
|
||||
|
||||
class WorkflowCompletedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_COMPLETED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
duration_seconds: float = Field(..., description="Total execution duration")
|
||||
result: dict = Field(default_factory=dict, description="Workflow result data")
|
||||
|
||||
|
||||
class WorkflowFailedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_FAILED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
error_message: str = Field(..., description="Error message")
|
||||
failed_step: str | None = Field(default=None, description="Step that failed")
|
||||
recoverable: bool = Field(default=False, description="Whether error is recoverable")
|
||||
113
backend/app/schemas/syndarix/__init__.py
Normal file
113
backend/app/schemas/syndarix/__init__.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# app/schemas/syndarix/__init__.py
|
||||
"""
|
||||
Syndarix domain schemas.
|
||||
|
||||
This package contains Pydantic schemas for validating and serializing
|
||||
Syndarix domain entities.
|
||||
"""
|
||||
|
||||
from .agent_instance import (
|
||||
AgentInstanceCreate,
|
||||
AgentInstanceInDB,
|
||||
AgentInstanceListResponse,
|
||||
AgentInstanceMetrics,
|
||||
AgentInstanceResponse,
|
||||
AgentInstanceTerminate,
|
||||
AgentInstanceUpdate,
|
||||
)
|
||||
from .agent_type import (
|
||||
AgentTypeCreate,
|
||||
AgentTypeInDB,
|
||||
AgentTypeListResponse,
|
||||
AgentTypeResponse,
|
||||
AgentTypeUpdate,
|
||||
)
|
||||
from .enums import (
|
||||
AgentStatus,
|
||||
AutonomyLevel,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
ProjectStatus,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
from .issue import (
|
||||
IssueAssign,
|
||||
IssueClose,
|
||||
IssueCreate,
|
||||
IssueInDB,
|
||||
IssueListResponse,
|
||||
IssueResponse,
|
||||
IssueStats,
|
||||
IssueSyncUpdate,
|
||||
IssueUpdate,
|
||||
)
|
||||
from .project import (
|
||||
ProjectCreate,
|
||||
ProjectInDB,
|
||||
ProjectListResponse,
|
||||
ProjectResponse,
|
||||
ProjectUpdate,
|
||||
)
|
||||
from .sprint import (
|
||||
SprintBurndown,
|
||||
SprintComplete,
|
||||
SprintCreate,
|
||||
SprintInDB,
|
||||
SprintListResponse,
|
||||
SprintResponse,
|
||||
SprintStart,
|
||||
SprintUpdate,
|
||||
SprintVelocity,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# AgentInstance schemas
|
||||
"AgentInstanceCreate",
|
||||
"AgentInstanceInDB",
|
||||
"AgentInstanceListResponse",
|
||||
"AgentInstanceMetrics",
|
||||
"AgentInstanceResponse",
|
||||
"AgentInstanceTerminate",
|
||||
"AgentInstanceUpdate",
|
||||
# Enums
|
||||
"AgentStatus",
|
||||
# AgentType schemas
|
||||
"AgentTypeCreate",
|
||||
"AgentTypeInDB",
|
||||
"AgentTypeListResponse",
|
||||
"AgentTypeResponse",
|
||||
"AgentTypeUpdate",
|
||||
"AutonomyLevel",
|
||||
# Issue schemas
|
||||
"IssueAssign",
|
||||
"IssueClose",
|
||||
"IssueCreate",
|
||||
"IssueInDB",
|
||||
"IssueListResponse",
|
||||
"IssuePriority",
|
||||
"IssueResponse",
|
||||
"IssueStats",
|
||||
"IssueStatus",
|
||||
"IssueSyncUpdate",
|
||||
"IssueUpdate",
|
||||
# Project schemas
|
||||
"ProjectCreate",
|
||||
"ProjectInDB",
|
||||
"ProjectListResponse",
|
||||
"ProjectResponse",
|
||||
"ProjectStatus",
|
||||
"ProjectUpdate",
|
||||
# Sprint schemas
|
||||
"SprintBurndown",
|
||||
"SprintComplete",
|
||||
"SprintCreate",
|
||||
"SprintInDB",
|
||||
"SprintListResponse",
|
||||
"SprintResponse",
|
||||
"SprintStart",
|
||||
"SprintStatus",
|
||||
"SprintUpdate",
|
||||
"SprintVelocity",
|
||||
"SyncStatus",
|
||||
]
|
||||
124
backend/app/schemas/syndarix/agent_instance.py
Normal file
124
backend/app/schemas/syndarix/agent_instance.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# app/schemas/syndarix/agent_instance.py
|
||||
"""
|
||||
Pydantic schemas for AgentInstance entity.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from .enums import AgentStatus
|
||||
|
||||
|
||||
class AgentInstanceBase(BaseModel):
|
||||
"""Base agent instance schema with common fields."""
|
||||
|
||||
agent_type_id: UUID
|
||||
project_id: UUID
|
||||
status: AgentStatus = AgentStatus.IDLE
|
||||
current_task: str | None = None
|
||||
short_term_memory: dict[str, Any] = Field(default_factory=dict)
|
||||
long_term_memory_ref: str | None = Field(None, max_length=500)
|
||||
session_id: str | None = Field(None, max_length=255)
|
||||
|
||||
|
||||
class AgentInstanceCreate(BaseModel):
|
||||
"""Schema for creating a new agent instance."""
|
||||
|
||||
agent_type_id: UUID
|
||||
project_id: UUID
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
status: AgentStatus = AgentStatus.IDLE
|
||||
current_task: str | None = None
|
||||
short_term_memory: dict[str, Any] = Field(default_factory=dict)
|
||||
long_term_memory_ref: str | None = Field(None, max_length=500)
|
||||
session_id: str | None = Field(None, max_length=255)
|
||||
|
||||
|
||||
class AgentInstanceUpdate(BaseModel):
|
||||
"""Schema for updating an agent instance."""
|
||||
|
||||
status: AgentStatus | None = None
|
||||
current_task: str | None = None
|
||||
short_term_memory: dict[str, Any] | None = None
|
||||
long_term_memory_ref: str | None = None
|
||||
session_id: str | None = None
|
||||
last_activity_at: datetime | None = None
|
||||
tasks_completed: int | None = Field(None, ge=0)
|
||||
tokens_used: int | None = Field(None, ge=0)
|
||||
cost_incurred: Decimal | None = Field(None, ge=0)
|
||||
|
||||
|
||||
class AgentInstanceTerminate(BaseModel):
|
||||
"""Schema for terminating an agent instance."""
|
||||
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class AgentInstanceInDB(AgentInstanceBase):
|
||||
"""Schema for agent instance in database."""
|
||||
|
||||
id: UUID
|
||||
last_activity_at: datetime | None = None
|
||||
terminated_at: datetime | None = None
|
||||
tasks_completed: int = 0
|
||||
tokens_used: int = 0
|
||||
cost_incurred: Decimal = Decimal("0.0000")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentInstanceResponse(BaseModel):
|
||||
"""Schema for agent instance API responses."""
|
||||
|
||||
id: UUID
|
||||
agent_type_id: UUID
|
||||
project_id: UUID
|
||||
name: str
|
||||
status: AgentStatus
|
||||
current_task: str | None = None
|
||||
short_term_memory: dict[str, Any] = Field(default_factory=dict)
|
||||
long_term_memory_ref: str | None = None
|
||||
session_id: str | None = None
|
||||
last_activity_at: datetime | None = None
|
||||
terminated_at: datetime | None = None
|
||||
tasks_completed: int = 0
|
||||
tokens_used: int = 0
|
||||
cost_incurred: Decimal = Decimal("0.0000")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# Expanded fields from relationships
|
||||
agent_type_name: str | None = None
|
||||
agent_type_slug: str | None = None
|
||||
project_name: str | None = None
|
||||
project_slug: str | None = None
|
||||
assigned_issues_count: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentInstanceListResponse(BaseModel):
|
||||
"""Schema for paginated agent instance list responses."""
|
||||
|
||||
agent_instances: list[AgentInstanceResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class AgentInstanceMetrics(BaseModel):
|
||||
"""Schema for agent instance metrics summary."""
|
||||
|
||||
total_instances: int
|
||||
active_instances: int
|
||||
idle_instances: int
|
||||
total_tasks_completed: int
|
||||
total_tokens_used: int
|
||||
total_cost_incurred: Decimal
|
||||
151
backend/app/schemas/syndarix/agent_type.py
Normal file
151
backend/app/schemas/syndarix/agent_type.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# app/schemas/syndarix/agent_type.py
|
||||
"""
|
||||
Pydantic schemas for AgentType entity.
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class AgentTypeBase(BaseModel):
|
||||
"""Base agent type schema with common fields."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
expertise: list[str] = Field(default_factory=list)
|
||||
personality_prompt: str = Field(..., min_length=1)
|
||||
primary_model: str = Field(..., min_length=1, max_length=100)
|
||||
fallback_models: list[str] = Field(default_factory=list)
|
||||
model_params: dict[str, Any] = Field(default_factory=dict)
|
||||
mcp_servers: list[str] = Field(default_factory=list)
|
||||
tool_permissions: dict[str, Any] = Field(default_factory=dict)
|
||||
is_active: bool = True
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate agent type name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Agent type name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@field_validator("expertise")
|
||||
@classmethod
|
||||
def validate_expertise(cls, v: list[str]) -> list[str]:
|
||||
"""Validate and normalize expertise list."""
|
||||
return [e.strip().lower() for e in v if e.strip()]
|
||||
|
||||
@field_validator("mcp_servers")
|
||||
@classmethod
|
||||
def validate_mcp_servers(cls, v: list[str]) -> list[str]:
|
||||
"""Validate MCP server list."""
|
||||
return [s.strip() for s in v if s.strip()]
|
||||
|
||||
|
||||
class AgentTypeCreate(AgentTypeBase):
|
||||
"""Schema for creating a new agent type."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255)
|
||||
personality_prompt: str = Field(..., min_length=1)
|
||||
primary_model: str = Field(..., min_length=1, max_length=100)
|
||||
|
||||
|
||||
class AgentTypeUpdate(BaseModel):
|
||||
"""Schema for updating an agent type."""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
expertise: list[str] | None = None
|
||||
personality_prompt: str | None = None
|
||||
primary_model: str | None = Field(None, min_length=1, max_length=100)
|
||||
fallback_models: list[str] | None = None
|
||||
model_params: dict[str, Any] | None = None
|
||||
mcp_servers: list[str] | None = None
|
||||
tool_permissions: dict[str, Any] | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate agent type name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError("Agent type name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
@field_validator("expertise")
|
||||
@classmethod
|
||||
def validate_expertise(cls, v: list[str] | None) -> list[str] | None:
|
||||
"""Validate and normalize expertise list."""
|
||||
if v is None:
|
||||
return v
|
||||
return [e.strip().lower() for e in v if e.strip()]
|
||||
|
||||
|
||||
class AgentTypeInDB(AgentTypeBase):
|
||||
"""Schema for agent type in database."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentTypeResponse(AgentTypeBase):
|
||||
"""Schema for agent type API responses."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
instance_count: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentTypeListResponse(BaseModel):
|
||||
"""Schema for paginated agent type list responses."""
|
||||
|
||||
agent_types: list[AgentTypeResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
26
backend/app/schemas/syndarix/enums.py
Normal file
26
backend/app/schemas/syndarix/enums.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# app/schemas/syndarix/enums.py
|
||||
"""
|
||||
Re-export enums from models for use in schemas.
|
||||
|
||||
This allows schemas to import enums without depending on SQLAlchemy models directly.
|
||||
"""
|
||||
|
||||
from app.models.syndarix.enums import (
|
||||
AgentStatus,
|
||||
AutonomyLevel,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
ProjectStatus,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentStatus",
|
||||
"AutonomyLevel",
|
||||
"IssuePriority",
|
||||
"IssueStatus",
|
||||
"ProjectStatus",
|
||||
"SprintStatus",
|
||||
"SyncStatus",
|
||||
]
|
||||
191
backend/app/schemas/syndarix/issue.py
Normal file
191
backend/app/schemas/syndarix/issue.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# app/schemas/syndarix/issue.py
|
||||
"""
|
||||
Pydantic schemas for Issue entity.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from .enums import IssuePriority, IssueStatus, SyncStatus
|
||||
|
||||
|
||||
class IssueBase(BaseModel):
|
||||
"""Base issue schema with common fields."""
|
||||
|
||||
title: str = Field(..., min_length=1, max_length=500)
|
||||
body: str = ""
|
||||
status: IssueStatus = IssueStatus.OPEN
|
||||
priority: IssuePriority = IssuePriority.MEDIUM
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
story_points: int | None = Field(None, ge=0, le=100)
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def validate_title(cls, v: str) -> str:
|
||||
"""Validate issue title."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Issue title cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@field_validator("labels")
|
||||
@classmethod
|
||||
def validate_labels(cls, v: list[str]) -> list[str]:
|
||||
"""Validate and normalize labels."""
|
||||
return [label.strip().lower() for label in v if label.strip()]
|
||||
|
||||
|
||||
class IssueCreate(IssueBase):
|
||||
"""Schema for creating a new issue."""
|
||||
|
||||
project_id: UUID
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = Field(None, max_length=255)
|
||||
sprint_id: UUID | None = None
|
||||
|
||||
# External tracker fields (optional, for importing from external systems)
|
||||
external_tracker_type: Literal["gitea", "github", "gitlab"] | None = None
|
||||
external_issue_id: str | None = Field(None, max_length=255)
|
||||
remote_url: str | None = Field(None, max_length=1000)
|
||||
external_issue_number: int | None = None
|
||||
|
||||
|
||||
class IssueUpdate(BaseModel):
|
||||
"""Schema for updating an issue."""
|
||||
|
||||
title: str | None = Field(None, min_length=1, max_length=500)
|
||||
body: str | None = None
|
||||
status: IssueStatus | None = None
|
||||
priority: IssuePriority | None = None
|
||||
labels: list[str] | None = None
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = Field(None, max_length=255)
|
||||
sprint_id: UUID | None = None
|
||||
story_points: int | None = Field(None, ge=0, le=100)
|
||||
sync_status: SyncStatus | None = None
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def validate_title(cls, v: str | None) -> str | None:
|
||||
"""Validate issue title."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError("Issue title cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
@field_validator("labels")
|
||||
@classmethod
|
||||
def validate_labels(cls, v: list[str] | None) -> list[str] | None:
|
||||
"""Validate and normalize labels."""
|
||||
if v is None:
|
||||
return v
|
||||
return [label.strip().lower() for label in v if label.strip()]
|
||||
|
||||
|
||||
class IssueClose(BaseModel):
|
||||
"""Schema for closing an issue."""
|
||||
|
||||
resolution: str | None = None # Optional resolution note
|
||||
|
||||
|
||||
class IssueAssign(BaseModel):
|
||||
"""Schema for assigning an issue."""
|
||||
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = Field(None, max_length=255)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_assignment(self) -> "IssueAssign":
|
||||
"""Ensure only one type of assignee is set."""
|
||||
if self.assigned_agent_id and self.human_assignee:
|
||||
raise ValueError("Cannot assign to both an agent and a human. Choose one.")
|
||||
return self
|
||||
|
||||
|
||||
class IssueSyncUpdate(BaseModel):
|
||||
"""Schema for updating sync-related fields."""
|
||||
|
||||
sync_status: SyncStatus
|
||||
last_synced_at: datetime | None = None
|
||||
external_updated_at: datetime | None = None
|
||||
|
||||
|
||||
class IssueInDB(IssueBase):
|
||||
"""Schema for issue in database."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = None
|
||||
sprint_id: UUID | None = None
|
||||
external_tracker_type: str | None = None
|
||||
external_issue_id: str | None = None
|
||||
remote_url: str | None = None
|
||||
external_issue_number: int | None = None
|
||||
sync_status: SyncStatus = SyncStatus.SYNCED
|
||||
last_synced_at: datetime | None = None
|
||||
external_updated_at: datetime | None = None
|
||||
closed_at: datetime | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class IssueResponse(BaseModel):
|
||||
"""Schema for issue API responses."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
title: str
|
||||
body: str
|
||||
status: IssueStatus
|
||||
priority: IssuePriority
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = None
|
||||
sprint_id: UUID | None = None
|
||||
story_points: int | None = None
|
||||
external_tracker_type: str | None = None
|
||||
external_issue_id: str | None = None
|
||||
remote_url: str | None = None
|
||||
external_issue_number: int | None = None
|
||||
sync_status: SyncStatus = SyncStatus.SYNCED
|
||||
last_synced_at: datetime | None = None
|
||||
external_updated_at: datetime | None = None
|
||||
closed_at: datetime | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# Expanded fields from relationships
|
||||
project_name: str | None = None
|
||||
project_slug: str | None = None
|
||||
sprint_name: str | None = None
|
||||
assigned_agent_type_name: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class IssueListResponse(BaseModel):
|
||||
"""Schema for paginated issue list responses."""
|
||||
|
||||
issues: list[IssueResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class IssueStats(BaseModel):
|
||||
"""Schema for issue statistics."""
|
||||
|
||||
total: int
|
||||
open: int
|
||||
in_progress: int
|
||||
in_review: int
|
||||
blocked: int
|
||||
closed: int
|
||||
by_priority: dict[str, int]
|
||||
total_story_points: int | None = None
|
||||
completed_story_points: int | None = None
|
||||
131
backend/app/schemas/syndarix/project.py
Normal file
131
backend/app/schemas/syndarix/project.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# app/schemas/syndarix/project.py
|
||||
"""
|
||||
Pydantic schemas for Project entity.
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from .enums import AutonomyLevel, ProjectStatus
|
||||
|
||||
|
||||
class ProjectBase(BaseModel):
|
||||
"""Base project schema with common fields."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE
|
||||
status: ProjectStatus = ProjectStatus.ACTIVE
|
||||
settings: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate project name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Project name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
"""Schema for creating a new project."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255)
|
||||
owner_id: UUID | None = None
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
"""Schema for updating a project.
|
||||
|
||||
Note: owner_id is intentionally excluded to prevent IDOR vulnerabilities.
|
||||
Project ownership transfer should be done via a dedicated endpoint with
|
||||
proper authorization checks.
|
||||
"""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
autonomy_level: AutonomyLevel | None = None
|
||||
status: ProjectStatus | None = None
|
||||
settings: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate project name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError("Project name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
|
||||
class ProjectInDB(ProjectBase):
|
||||
"""Schema for project in database."""
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ProjectResponse(ProjectBase):
|
||||
"""Schema for project API responses."""
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
agent_count: int | None = 0
|
||||
issue_count: int | None = 0
|
||||
active_sprint_name: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ProjectListResponse(BaseModel):
|
||||
"""Schema for paginated project list responses."""
|
||||
|
||||
projects: list[ProjectResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
135
backend/app/schemas/syndarix/sprint.py
Normal file
135
backend/app/schemas/syndarix/sprint.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# app/schemas/syndarix/sprint.py
|
||||
"""
|
||||
Pydantic schemas for Sprint entity.
|
||||
"""
|
||||
|
||||
from datetime import date, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from .enums import SprintStatus
|
||||
|
||||
|
||||
class SprintBase(BaseModel):
|
||||
"""Base sprint schema with common fields."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
number: int = Field(..., ge=1)
|
||||
goal: str | None = None
|
||||
start_date: date
|
||||
end_date: date
|
||||
status: SprintStatus = SprintStatus.PLANNED
|
||||
planned_points: int | None = Field(None, ge=0)
|
||||
velocity: int | None = Field(None, ge=0)
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate sprint name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Sprint name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_dates(self) -> "SprintBase":
|
||||
"""Validate that end_date is after start_date."""
|
||||
if self.end_date < self.start_date:
|
||||
raise ValueError("End date must be after or equal to start date")
|
||||
return self
|
||||
|
||||
|
||||
class SprintCreate(SprintBase):
|
||||
"""Schema for creating a new sprint."""
|
||||
|
||||
project_id: UUID
|
||||
|
||||
|
||||
class SprintUpdate(BaseModel):
|
||||
"""Schema for updating a sprint."""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
goal: str | None = None
|
||||
start_date: date | None = None
|
||||
end_date: date | None = None
|
||||
status: SprintStatus | None = None
|
||||
planned_points: int | None = Field(None, ge=0)
|
||||
velocity: int | None = Field(None, ge=0)
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate sprint name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError("Sprint name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
|
||||
class SprintStart(BaseModel):
|
||||
"""Schema for starting a sprint."""
|
||||
|
||||
start_date: date | None = None # Optionally override start date
|
||||
|
||||
|
||||
class SprintComplete(BaseModel):
|
||||
"""Schema for completing a sprint."""
|
||||
|
||||
velocity: int | None = Field(None, ge=0)
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
class SprintInDB(SprintBase):
|
||||
"""Schema for sprint in database."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class SprintResponse(SprintBase):
|
||||
"""Schema for sprint API responses."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# Expanded fields from relationships
|
||||
project_name: str | None = None
|
||||
project_slug: str | None = None
|
||||
issue_count: int | None = 0
|
||||
open_issues: int | None = 0
|
||||
completed_issues: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class SprintListResponse(BaseModel):
|
||||
"""Schema for paginated sprint list responses."""
|
||||
|
||||
sprints: list[SprintResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class SprintVelocity(BaseModel):
|
||||
"""Schema for sprint velocity metrics."""
|
||||
|
||||
sprint_number: int
|
||||
sprint_name: str
|
||||
planned_points: int | None
|
||||
velocity: int | None # Sum of completed story points
|
||||
velocity_ratio: float | None # velocity/planned ratio
|
||||
|
||||
|
||||
class SprintBurndown(BaseModel):
|
||||
"""Schema for sprint burndown data point."""
|
||||
|
||||
date: date
|
||||
remaining_points: int
|
||||
ideal_remaining: float
|
||||
178
backend/app/services/context/__init__.py
Normal file
178
backend/app/services/context/__init__.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Context Management Engine
|
||||
|
||||
Sophisticated context assembly and optimization for LLM requests.
|
||||
Provides intelligent context selection, token budget management,
|
||||
and model-specific formatting.
|
||||
|
||||
Usage:
|
||||
from app.services.context import (
|
||||
ContextSettings,
|
||||
get_context_settings,
|
||||
SystemContext,
|
||||
KnowledgeContext,
|
||||
ConversationContext,
|
||||
TaskContext,
|
||||
ToolContext,
|
||||
TokenBudget,
|
||||
BudgetAllocator,
|
||||
TokenCalculator,
|
||||
)
|
||||
|
||||
# Get settings
|
||||
settings = get_context_settings()
|
||||
|
||||
# Create budget for a model
|
||||
allocator = BudgetAllocator(settings)
|
||||
budget = allocator.create_budget_for_model("claude-3-sonnet")
|
||||
|
||||
# Create context instances
|
||||
system_ctx = SystemContext.create_persona(
|
||||
name="Code Assistant",
|
||||
description="You are a helpful code assistant.",
|
||||
capabilities=["Write code", "Debug issues"],
|
||||
)
|
||||
"""
|
||||
|
||||
# Budget Management
|
||||
# Adapters
|
||||
from .adapters import (
|
||||
ClaudeAdapter,
|
||||
DefaultAdapter,
|
||||
ModelAdapter,
|
||||
OpenAIAdapter,
|
||||
get_adapter,
|
||||
)
|
||||
|
||||
# Assembly
|
||||
from .assembly import (
|
||||
ContextPipeline,
|
||||
PipelineMetrics,
|
||||
)
|
||||
from .budget import (
|
||||
BudgetAllocator,
|
||||
TokenBudget,
|
||||
TokenCalculator,
|
||||
)
|
||||
|
||||
# Cache
|
||||
from .cache import ContextCache
|
||||
|
||||
# Compression
|
||||
from .compression import (
|
||||
ContextCompressor,
|
||||
TruncationResult,
|
||||
TruncationStrategy,
|
||||
)
|
||||
|
||||
# Configuration
|
||||
from .config import (
|
||||
ContextSettings,
|
||||
get_context_settings,
|
||||
get_default_settings,
|
||||
reset_context_settings,
|
||||
)
|
||||
|
||||
# Engine
|
||||
from .engine import ContextEngine, create_context_engine
|
||||
|
||||
# Exceptions
|
||||
from .exceptions import (
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
CacheError,
|
||||
CompressionError,
|
||||
ContextError,
|
||||
ContextNotFoundError,
|
||||
FormattingError,
|
||||
InvalidContextError,
|
||||
ScoringError,
|
||||
TokenCountError,
|
||||
)
|
||||
|
||||
# Prioritization
|
||||
from .prioritization import (
|
||||
ContextRanker,
|
||||
RankingResult,
|
||||
)
|
||||
|
||||
# Scoring
|
||||
from .scoring import (
|
||||
BaseScorer,
|
||||
CompositeScorer,
|
||||
PriorityScorer,
|
||||
RecencyScorer,
|
||||
RelevanceScorer,
|
||||
ScoredContext,
|
||||
)
|
||||
|
||||
# Types
|
||||
from .types import (
|
||||
AssembledContext,
|
||||
BaseContext,
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskComplexity,
|
||||
TaskContext,
|
||||
TaskStatus,
|
||||
ToolContext,
|
||||
ToolResultStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AssembledContext",
|
||||
"AssemblyTimeoutError",
|
||||
"BaseContext",
|
||||
"BaseScorer",
|
||||
"BudgetAllocator",
|
||||
"BudgetExceededError",
|
||||
"CacheError",
|
||||
"ClaudeAdapter",
|
||||
"CompositeScorer",
|
||||
"CompressionError",
|
||||
"ContextCache",
|
||||
"ContextCompressor",
|
||||
"ContextEngine",
|
||||
"ContextError",
|
||||
"ContextNotFoundError",
|
||||
"ContextPipeline",
|
||||
"ContextPriority",
|
||||
"ContextRanker",
|
||||
"ContextSettings",
|
||||
"ContextType",
|
||||
"ConversationContext",
|
||||
"DefaultAdapter",
|
||||
"FormattingError",
|
||||
"InvalidContextError",
|
||||
"KnowledgeContext",
|
||||
"MessageRole",
|
||||
"ModelAdapter",
|
||||
"OpenAIAdapter",
|
||||
"PipelineMetrics",
|
||||
"PriorityScorer",
|
||||
"RankingResult",
|
||||
"RecencyScorer",
|
||||
"RelevanceScorer",
|
||||
"ScoredContext",
|
||||
"ScoringError",
|
||||
"SystemContext",
|
||||
"TaskComplexity",
|
||||
"TaskContext",
|
||||
"TaskStatus",
|
||||
"TokenBudget",
|
||||
"TokenCalculator",
|
||||
"TokenCountError",
|
||||
"ToolContext",
|
||||
"ToolResultStatus",
|
||||
"TruncationResult",
|
||||
"TruncationStrategy",
|
||||
"create_context_engine",
|
||||
"get_adapter",
|
||||
"get_context_settings",
|
||||
"get_default_settings",
|
||||
"reset_context_settings",
|
||||
]
|
||||
35
backend/app/services/context/adapters/__init__.py
Normal file
35
backend/app/services/context/adapters/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Model Adapters Module.
|
||||
|
||||
Provides model-specific context formatting adapters.
|
||||
"""
|
||||
|
||||
from .base import DefaultAdapter, ModelAdapter
|
||||
from .claude import ClaudeAdapter
|
||||
from .openai import OpenAIAdapter
|
||||
|
||||
|
||||
def get_adapter(model: str) -> ModelAdapter:
|
||||
"""
|
||||
Get the appropriate adapter for a model.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Adapter instance for the model
|
||||
"""
|
||||
if ClaudeAdapter.matches_model(model):
|
||||
return ClaudeAdapter()
|
||||
elif OpenAIAdapter.matches_model(model):
|
||||
return OpenAIAdapter()
|
||||
return DefaultAdapter()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ClaudeAdapter",
|
||||
"DefaultAdapter",
|
||||
"ModelAdapter",
|
||||
"OpenAIAdapter",
|
||||
"get_adapter",
|
||||
]
|
||||
178
backend/app/services/context/adapters/base.py
Normal file
178
backend/app/services/context/adapters/base.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Base Model Adapter.
|
||||
|
||||
Abstract base class for model-specific context formatting.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
|
||||
|
||||
class ModelAdapter(ABC):
|
||||
"""
|
||||
Abstract base adapter for model-specific context formatting.
|
||||
|
||||
Each adapter knows how to format contexts for optimal
|
||||
understanding by a specific LLM family (Claude, OpenAI, etc.).
|
||||
"""
|
||||
|
||||
# Model name patterns this adapter handles
|
||||
MODEL_PATTERNS: ClassVar[list[str]] = []
|
||||
|
||||
@classmethod
|
||||
def matches_model(cls, model: str) -> bool:
|
||||
"""
|
||||
Check if this adapter handles the given model.
|
||||
|
||||
Args:
|
||||
model: Model name to check
|
||||
|
||||
Returns:
|
||||
True if this adapter handles the model
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
return any(pattern in model_lower for pattern in cls.MODEL_PATTERNS)
|
||||
|
||||
@abstractmethod
|
||||
def format(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts for the target model.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to format
|
||||
**kwargs: Additional formatting options
|
||||
|
||||
Returns:
|
||||
Formatted context string
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def format_type(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
context_type: ContextType,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts of a specific type.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts of the same type
|
||||
context_type: The type of contexts
|
||||
**kwargs: Additional formatting options
|
||||
|
||||
Returns:
|
||||
Formatted string for this context type
|
||||
"""
|
||||
...
|
||||
|
||||
def get_type_order(self) -> list[ContextType]:
|
||||
"""
|
||||
Get the preferred order of context types.
|
||||
|
||||
Returns:
|
||||
List of context types in preferred order
|
||||
"""
|
||||
return [
|
||||
ContextType.SYSTEM,
|
||||
ContextType.TASK,
|
||||
ContextType.KNOWLEDGE,
|
||||
ContextType.CONVERSATION,
|
||||
ContextType.TOOL,
|
||||
]
|
||||
|
||||
def group_by_type(
|
||||
self, contexts: list[BaseContext]
|
||||
) -> dict[ContextType, list[BaseContext]]:
|
||||
"""
|
||||
Group contexts by their type.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to group
|
||||
|
||||
Returns:
|
||||
Dictionary mapping context type to list of contexts
|
||||
"""
|
||||
by_type: dict[ContextType, list[BaseContext]] = {}
|
||||
for context in contexts:
|
||||
ct = context.get_type()
|
||||
if ct not in by_type:
|
||||
by_type[ct] = []
|
||||
by_type[ct].append(context)
|
||||
return by_type
|
||||
|
||||
def get_separator(self) -> str:
|
||||
"""
|
||||
Get the separator between context sections.
|
||||
|
||||
Returns:
|
||||
Separator string
|
||||
"""
|
||||
return "\n\n"
|
||||
|
||||
|
||||
class DefaultAdapter(ModelAdapter):
|
||||
"""
|
||||
Default adapter for unknown models.
|
||||
|
||||
Uses simple plain-text formatting with minimal structure.
|
||||
"""
|
||||
|
||||
MODEL_PATTERNS: ClassVar[list[str]] = [] # Fallback adapter
|
||||
|
||||
@classmethod
|
||||
def matches_model(cls, model: str) -> bool:
|
||||
"""Always returns True as fallback."""
|
||||
return True
|
||||
|
||||
def format(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Format contexts as plain text."""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
by_type = self.group_by_type(contexts)
|
||||
parts: list[str] = []
|
||||
|
||||
for ct in self.get_type_order():
|
||||
if ct in by_type:
|
||||
formatted = self.format_type(by_type[ct], ct, **kwargs)
|
||||
if formatted:
|
||||
parts.append(formatted)
|
||||
|
||||
return self.get_separator().join(parts)
|
||||
|
||||
def format_type(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
context_type: ContextType,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Format contexts of a type as plain text."""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
|
||||
if context_type == ContextType.SYSTEM:
|
||||
return content
|
||||
elif context_type == ContextType.TASK:
|
||||
return f"Task:\n{content}"
|
||||
elif context_type == ContextType.KNOWLEDGE:
|
||||
return f"Reference Information:\n{content}"
|
||||
elif context_type == ContextType.CONVERSATION:
|
||||
return f"Previous Conversation:\n{content}"
|
||||
elif context_type == ContextType.TOOL:
|
||||
return f"Tool Results:\n{content}"
|
||||
|
||||
return content
|
||||
212
backend/app/services/context/adapters/claude.py
Normal file
212
backend/app/services/context/adapters/claude.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Claude Model Adapter.
|
||||
|
||||
Provides Claude-specific context formatting using XML tags
|
||||
which Claude models understand natively.
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
from .base import ModelAdapter
|
||||
|
||||
|
||||
class ClaudeAdapter(ModelAdapter):
|
||||
"""
|
||||
Claude-specific context formatting adapter.
|
||||
|
||||
Claude models have native understanding of XML structure,
|
||||
so we use XML tags for clear delineation of context types.
|
||||
|
||||
Features:
|
||||
- XML tags for each context type
|
||||
- Document structure for knowledge contexts
|
||||
- Role-based message formatting for conversations
|
||||
- Tool result wrapping with tool names
|
||||
"""
|
||||
|
||||
MODEL_PATTERNS: ClassVar[list[str]] = ["claude", "anthropic"]
|
||||
|
||||
def format(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts for Claude models.
|
||||
|
||||
Uses XML tags for structured content that Claude
|
||||
understands natively.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to format
|
||||
**kwargs: Additional formatting options
|
||||
|
||||
Returns:
|
||||
XML-structured context string
|
||||
"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
by_type = self.group_by_type(contexts)
|
||||
parts: list[str] = []
|
||||
|
||||
for ct in self.get_type_order():
|
||||
if ct in by_type:
|
||||
formatted = self.format_type(by_type[ct], ct, **kwargs)
|
||||
if formatted:
|
||||
parts.append(formatted)
|
||||
|
||||
return self.get_separator().join(parts)
|
||||
|
||||
def format_type(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
context_type: ContextType,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts of a specific type for Claude.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts of the same type
|
||||
context_type: The type of contexts
|
||||
**kwargs: Additional formatting options
|
||||
|
||||
Returns:
|
||||
XML-formatted string for this context type
|
||||
"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
if context_type == ContextType.SYSTEM:
|
||||
return self._format_system(contexts)
|
||||
elif context_type == ContextType.TASK:
|
||||
return self._format_task(contexts)
|
||||
elif context_type == ContextType.KNOWLEDGE:
|
||||
return self._format_knowledge(contexts)
|
||||
elif context_type == ContextType.CONVERSATION:
|
||||
return self._format_conversation(contexts)
|
||||
elif context_type == ContextType.TOOL:
|
||||
return self._format_tool(contexts)
|
||||
|
||||
# Fallback for any unhandled context types - still escape content
|
||||
# to prevent XML injection if new types are added without updating adapter
|
||||
return "\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||
|
||||
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format system contexts."""
|
||||
# System prompts are typically admin-controlled, but escape for safety
|
||||
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||
return f"<system_instructions>\n{content}\n</system_instructions>"
|
||||
|
||||
def _format_task(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format task contexts."""
|
||||
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||
return f"<current_task>\n{content}\n</current_task>"
|
||||
|
||||
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
|
||||
"""
|
||||
Format knowledge contexts as structured documents.
|
||||
|
||||
Each knowledge context becomes a document with source attribution.
|
||||
All content is XML-escaped to prevent injection attacks.
|
||||
"""
|
||||
parts = ["<reference_documents>"]
|
||||
|
||||
for ctx in contexts:
|
||||
source = self._escape_xml(ctx.source)
|
||||
# Escape content to prevent XML injection
|
||||
content = self._escape_xml_content(ctx.content)
|
||||
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||
|
||||
if score:
|
||||
# Escape score to prevent XML injection via metadata
|
||||
escaped_score = self._escape_xml(str(score))
|
||||
parts.append(
|
||||
f'<document source="{source}" relevance="{escaped_score}">'
|
||||
)
|
||||
else:
|
||||
parts.append(f'<document source="{source}">')
|
||||
|
||||
parts.append(content)
|
||||
parts.append("</document>")
|
||||
|
||||
parts.append("</reference_documents>")
|
||||
return "\n".join(parts)
|
||||
|
||||
def _format_conversation(self, contexts: list[BaseContext]) -> str:
|
||||
"""
|
||||
Format conversation contexts as message history.
|
||||
|
||||
Uses role-based message tags for clear turn delineation.
|
||||
All content is XML-escaped to prevent prompt injection.
|
||||
"""
|
||||
parts = ["<conversation_history>"]
|
||||
|
||||
for ctx in contexts:
|
||||
role = self._escape_xml(ctx.metadata.get("role", "user"))
|
||||
# Escape content to prevent prompt injection via fake XML tags
|
||||
content = self._escape_xml_content(ctx.content)
|
||||
parts.append(f'<message role="{role}">')
|
||||
parts.append(content)
|
||||
parts.append("</message>")
|
||||
|
||||
parts.append("</conversation_history>")
|
||||
return "\n".join(parts)
|
||||
|
||||
def _format_tool(self, contexts: list[BaseContext]) -> str:
|
||||
"""
|
||||
Format tool contexts as tool results.
|
||||
|
||||
Each tool result is wrapped with the tool name.
|
||||
All content is XML-escaped to prevent injection.
|
||||
"""
|
||||
parts = ["<tool_results>"]
|
||||
|
||||
for ctx in contexts:
|
||||
tool_name = self._escape_xml(ctx.metadata.get("tool_name", "unknown"))
|
||||
status = ctx.metadata.get("status", "")
|
||||
|
||||
if status:
|
||||
parts.append(
|
||||
f'<tool_result name="{tool_name}" status="{self._escape_xml(status)}">'
|
||||
)
|
||||
else:
|
||||
parts.append(f'<tool_result name="{tool_name}">')
|
||||
|
||||
# Escape content to prevent injection
|
||||
parts.append(self._escape_xml_content(ctx.content))
|
||||
parts.append("</tool_result>")
|
||||
|
||||
parts.append("</tool_results>")
|
||||
return "\n".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _escape_xml(text: str) -> str:
|
||||
"""Escape XML special characters in attribute values."""
|
||||
return (
|
||||
text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
.replace('"', """)
|
||||
.replace("'", "'")
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _escape_xml_content(text: str) -> str:
|
||||
"""
|
||||
Escape XML special characters in element content.
|
||||
|
||||
This prevents XML injection attacks where malicious content
|
||||
could break out of XML tags or inject fake tags for prompt injection.
|
||||
|
||||
Only escapes &, <, > since quotes don't need escaping in content.
|
||||
|
||||
Args:
|
||||
text: Content text to escape
|
||||
|
||||
Returns:
|
||||
XML-safe content string
|
||||
"""
|
||||
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
160
backend/app/services/context/adapters/openai.py
Normal file
160
backend/app/services/context/adapters/openai.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
OpenAI Model Adapter.
|
||||
|
||||
Provides OpenAI-specific context formatting using markdown
|
||||
which GPT models understand well.
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
from .base import ModelAdapter
|
||||
|
||||
|
||||
class OpenAIAdapter(ModelAdapter):
|
||||
"""
|
||||
OpenAI-specific context formatting adapter.
|
||||
|
||||
GPT models work well with markdown formatting,
|
||||
so we use headers and structured markdown for clarity.
|
||||
|
||||
Features:
|
||||
- Markdown headers for each context type
|
||||
- Bulleted lists for document sources
|
||||
- Bold role labels for conversations
|
||||
- Code blocks for tool outputs
|
||||
"""
|
||||
|
||||
MODEL_PATTERNS: ClassVar[list[str]] = ["gpt", "openai", "o1", "o3"]
|
||||
|
||||
def format(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts for OpenAI models.
|
||||
|
||||
Uses markdown formatting for structured content.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to format
|
||||
**kwargs: Additional formatting options
|
||||
|
||||
Returns:
|
||||
Markdown-structured context string
|
||||
"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
by_type = self.group_by_type(contexts)
|
||||
parts: list[str] = []
|
||||
|
||||
for ct in self.get_type_order():
|
||||
if ct in by_type:
|
||||
formatted = self.format_type(by_type[ct], ct, **kwargs)
|
||||
if formatted:
|
||||
parts.append(formatted)
|
||||
|
||||
return self.get_separator().join(parts)
|
||||
|
||||
def format_type(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
context_type: ContextType,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts of a specific type for OpenAI.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts of the same type
|
||||
context_type: The type of contexts
|
||||
**kwargs: Additional formatting options
|
||||
|
||||
Returns:
|
||||
Markdown-formatted string for this context type
|
||||
"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
if context_type == ContextType.SYSTEM:
|
||||
return self._format_system(contexts)
|
||||
elif context_type == ContextType.TASK:
|
||||
return self._format_task(contexts)
|
||||
elif context_type == ContextType.KNOWLEDGE:
|
||||
return self._format_knowledge(contexts)
|
||||
elif context_type == ContextType.CONVERSATION:
|
||||
return self._format_conversation(contexts)
|
||||
elif context_type == ContextType.TOOL:
|
||||
return self._format_tool(contexts)
|
||||
|
||||
return "\n".join(c.content for c in contexts)
|
||||
|
||||
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format system contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
return content
|
||||
|
||||
def _format_task(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format task contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
return f"## Current Task\n\n{content}"
|
||||
|
||||
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
|
||||
"""
|
||||
Format knowledge contexts as structured documents.
|
||||
|
||||
Each knowledge context becomes a section with source attribution.
|
||||
"""
|
||||
parts = ["## Reference Documents\n"]
|
||||
|
||||
for ctx in contexts:
|
||||
source = ctx.source
|
||||
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||
|
||||
if score:
|
||||
parts.append(f"### Source: {source} (relevance: {score})\n")
|
||||
else:
|
||||
parts.append(f"### Source: {source}\n")
|
||||
|
||||
parts.append(ctx.content)
|
||||
parts.append("")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _format_conversation(self, contexts: list[BaseContext]) -> str:
|
||||
"""
|
||||
Format conversation contexts as message history.
|
||||
|
||||
Uses bold role labels for clear turn delineation.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
for ctx in contexts:
|
||||
role = ctx.metadata.get("role", "user").upper()
|
||||
parts.append(f"**{role}**: {ctx.content}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _format_tool(self, contexts: list[BaseContext]) -> str:
|
||||
"""
|
||||
Format tool contexts as tool results.
|
||||
|
||||
Each tool result is in a code block with the tool name.
|
||||
"""
|
||||
parts = ["## Recent Tool Results\n"]
|
||||
|
||||
for ctx in contexts:
|
||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||
status = ctx.metadata.get("status", "")
|
||||
|
||||
if status:
|
||||
parts.append(f"### Tool: {tool_name} ({status})\n")
|
||||
else:
|
||||
parts.append(f"### Tool: {tool_name}\n")
|
||||
|
||||
parts.append(f"```\n{ctx.content}\n```")
|
||||
parts.append("")
|
||||
|
||||
return "\n".join(parts)
|
||||
12
backend/app/services/context/assembly/__init__.py
Normal file
12
backend/app/services/context/assembly/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Context Assembly Module.
|
||||
|
||||
Provides the assembly pipeline and formatting.
|
||||
"""
|
||||
|
||||
from .pipeline import ContextPipeline, PipelineMetrics
|
||||
|
||||
__all__ = [
|
||||
"ContextPipeline",
|
||||
"PipelineMetrics",
|
||||
]
|
||||
362
backend/app/services/context/assembly/pipeline.py
Normal file
362
backend/app/services/context/assembly/pipeline.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
Context Assembly Pipeline.
|
||||
|
||||
Orchestrates the full context assembly workflow:
|
||||
Gather → Count → Score → Rank → Compress → Format
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..adapters import get_adapter
|
||||
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
from ..compression.truncation import ContextCompressor
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import AssemblyTimeoutError
|
||||
from ..prioritization import ContextRanker
|
||||
from ..scoring import CompositeScorer
|
||||
from ..types import AssembledContext, BaseContext, ContextType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineMetrics:
|
||||
"""Metrics from pipeline execution."""
|
||||
|
||||
start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
end_time: datetime | None = None
|
||||
total_contexts: int = 0
|
||||
selected_contexts: int = 0
|
||||
excluded_contexts: int = 0
|
||||
compressed_contexts: int = 0
|
||||
total_tokens: int = 0
|
||||
assembly_time_ms: float = 0.0
|
||||
scoring_time_ms: float = 0.0
|
||||
compression_time_ms: float = 0.0
|
||||
formatting_time_ms: float = 0.0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"start_time": self.start_time.isoformat(),
|
||||
"end_time": self.end_time.isoformat() if self.end_time else None,
|
||||
"total_contexts": self.total_contexts,
|
||||
"selected_contexts": self.selected_contexts,
|
||||
"excluded_contexts": self.excluded_contexts,
|
||||
"compressed_contexts": self.compressed_contexts,
|
||||
"total_tokens": self.total_tokens,
|
||||
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||
"scoring_time_ms": round(self.scoring_time_ms, 2),
|
||||
"compression_time_ms": round(self.compression_time_ms, 2),
|
||||
"formatting_time_ms": round(self.formatting_time_ms, 2),
|
||||
}
|
||||
|
||||
|
||||
class ContextPipeline:
|
||||
"""
|
||||
Context assembly pipeline.
|
||||
|
||||
Orchestrates the full workflow of context assembly:
|
||||
1. Validate and count tokens for all contexts
|
||||
2. Score contexts based on relevance, recency, and priority
|
||||
3. Rank and select contexts within budget
|
||||
4. Compress if needed to fit remaining budget
|
||||
5. Format for the target model
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
calculator: TokenCalculator | None = None,
|
||||
scorer: CompositeScorer | None = None,
|
||||
ranker: ContextRanker | None = None,
|
||||
compressor: ContextCompressor | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the context pipeline.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP client manager for LLM Gateway integration
|
||||
settings: Context settings
|
||||
calculator: Token calculator
|
||||
scorer: Context scorer
|
||||
ranker: Context ranker
|
||||
compressor: Context compressor
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
self._mcp = mcp_manager
|
||||
|
||||
# Initialize components
|
||||
self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager)
|
||||
self._scorer = scorer or CompositeScorer(
|
||||
mcp_manager=mcp_manager, settings=self._settings
|
||||
)
|
||||
self._ranker = ranker or ContextRanker(
|
||||
scorer=self._scorer, calculator=self._calculator
|
||||
)
|
||||
self._compressor = compressor or ContextCompressor(calculator=self._calculator)
|
||||
self._allocator = BudgetAllocator(self._settings)
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""Set MCP manager for all components."""
|
||||
self._mcp = mcp_manager
|
||||
self._calculator.set_mcp_manager(mcp_manager)
|
||||
self._scorer.set_mcp_manager(mcp_manager)
|
||||
|
||||
async def assemble(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
model: str,
|
||||
max_tokens: int | None = None,
|
||||
custom_budget: TokenBudget | None = None,
|
||||
compress: bool = True,
|
||||
format_output: bool = True,
|
||||
timeout_ms: int | None = None,
|
||||
) -> AssembledContext:
|
||||
"""
|
||||
Assemble context for an LLM request.
|
||||
|
||||
This is the main entry point for context assembly.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to assemble
|
||||
query: Query to optimize for
|
||||
model: Target model name
|
||||
max_tokens: Maximum total tokens (uses model default if None)
|
||||
custom_budget: Optional pre-configured budget
|
||||
compress: Whether to compress oversized contexts
|
||||
format_output: Whether to format the final output
|
||||
timeout_ms: Maximum assembly time in milliseconds
|
||||
|
||||
Returns:
|
||||
AssembledContext with optimized content
|
||||
|
||||
Raises:
|
||||
AssemblyTimeoutError: If assembly exceeds timeout
|
||||
"""
|
||||
timeout = timeout_ms or self._settings.max_assembly_time_ms
|
||||
start = time.perf_counter()
|
||||
metrics = PipelineMetrics(total_contexts=len(contexts))
|
||||
|
||||
try:
|
||||
# Create or use budget
|
||||
if custom_budget:
|
||||
budget = custom_budget
|
||||
elif max_tokens:
|
||||
budget = self._allocator.create_budget(max_tokens)
|
||||
else:
|
||||
budget = self._allocator.create_budget_for_model(model)
|
||||
|
||||
# 1. Count tokens for all contexts (with timeout enforcement)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._ensure_token_counts(contexts, model),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during token counting",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
|
||||
# Check timeout (handles edge case where operation finished just at limit)
|
||||
self._check_timeout(start, timeout, "token counting")
|
||||
|
||||
# 2. Score and rank contexts (with timeout enforcement)
|
||||
scoring_start = time.perf_counter()
|
||||
try:
|
||||
ranking_result = await asyncio.wait_for(
|
||||
self._ranker.rank(
|
||||
contexts=contexts,
|
||||
query=query,
|
||||
budget=budget,
|
||||
model=model,
|
||||
),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during scoring/ranking",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
|
||||
|
||||
selected_contexts = ranking_result.selected_contexts
|
||||
metrics.selected_contexts = len(selected_contexts)
|
||||
metrics.excluded_contexts = len(ranking_result.excluded)
|
||||
|
||||
# Check timeout
|
||||
self._check_timeout(start, timeout, "scoring")
|
||||
|
||||
# 3. Compress if needed and enabled (with timeout enforcement)
|
||||
if compress and self._needs_compression(selected_contexts, budget):
|
||||
compression_start = time.perf_counter()
|
||||
try:
|
||||
selected_contexts = await asyncio.wait_for(
|
||||
self._compressor.compress_contexts(
|
||||
selected_contexts, budget, model
|
||||
),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during compression",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
metrics.compression_time_ms = (
|
||||
time.perf_counter() - compression_start
|
||||
) * 1000
|
||||
metrics.compressed_contexts = sum(
|
||||
1 for c in selected_contexts if c.metadata.get("truncated", False)
|
||||
)
|
||||
|
||||
# Check timeout
|
||||
self._check_timeout(start, timeout, "compression")
|
||||
|
||||
# 4. Format output
|
||||
formatting_start = time.perf_counter()
|
||||
if format_output:
|
||||
formatted_content = self._format_contexts(selected_contexts, model)
|
||||
else:
|
||||
formatted_content = "\n\n".join(c.content for c in selected_contexts)
|
||||
metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000
|
||||
|
||||
# Calculate final metrics
|
||||
total_tokens = sum(c.token_count or 0 for c in selected_contexts)
|
||||
metrics.total_tokens = total_tokens
|
||||
metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
|
||||
metrics.end_time = datetime.now(UTC)
|
||||
|
||||
return AssembledContext(
|
||||
content=formatted_content,
|
||||
total_tokens=total_tokens,
|
||||
context_count=len(selected_contexts),
|
||||
assembly_time_ms=metrics.assembly_time_ms,
|
||||
model=model,
|
||||
contexts=selected_contexts,
|
||||
excluded_count=metrics.excluded_contexts,
|
||||
metadata={
|
||||
"metrics": metrics.to_dict(),
|
||||
"query": query,
|
||||
"budget": budget.to_dict(),
|
||||
},
|
||||
)
|
||||
|
||||
except AssemblyTimeoutError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Context assembly failed: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def _ensure_token_counts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
"""Ensure all contexts have token counts."""
|
||||
tasks = []
|
||||
for context in contexts:
|
||||
if context.token_count is None:
|
||||
tasks.append(self._count_and_set(context, model))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _count_and_set(
|
||||
self,
|
||||
context: BaseContext,
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
"""Count tokens and set on context."""
|
||||
count = await self._calculator.count_tokens(context.content, model)
|
||||
context.token_count = count
|
||||
|
||||
def _needs_compression(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
budget: TokenBudget,
|
||||
) -> bool:
|
||||
"""Check if any contexts exceed their type budget."""
|
||||
# Group by type and check totals
|
||||
by_type: dict[ContextType, int] = {}
|
||||
for context in contexts:
|
||||
ct = context.get_type()
|
||||
by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)
|
||||
|
||||
for ct, total in by_type.items():
|
||||
if total > budget.get_allocation(ct):
|
||||
return True
|
||||
|
||||
# Also check if utilization exceeds threshold
|
||||
return budget.utilization() > self._settings.compression_threshold
|
||||
|
||||
def _format_contexts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
model: str,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts for the target model.
|
||||
|
||||
Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
|
||||
to format contexts optimally for each model family.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to format
|
||||
model: Target model name
|
||||
|
||||
Returns:
|
||||
Formatted context string
|
||||
"""
|
||||
adapter = get_adapter(model)
|
||||
return adapter.format(contexts)
|
||||
|
||||
def _check_timeout(
|
||||
self,
|
||||
start: float,
|
||||
timeout_ms: int,
|
||||
phase: str,
|
||||
) -> None:
|
||||
"""Check if timeout exceeded and raise if so."""
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
if elapsed_ms >= timeout_ms:
|
||||
raise AssemblyTimeoutError(
|
||||
message=f"Context assembly timed out during {phase}",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
|
||||
def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
|
||||
"""
|
||||
Calculate remaining timeout in seconds for asyncio.wait_for.
|
||||
|
||||
Returns at least a small positive value to avoid immediate timeout
|
||||
edge cases with wait_for.
|
||||
|
||||
Args:
|
||||
start: Start time from time.perf_counter()
|
||||
timeout_ms: Total timeout in milliseconds
|
||||
|
||||
Returns:
|
||||
Remaining timeout in seconds (minimum 0.001)
|
||||
"""
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
remaining_ms = timeout_ms - elapsed_ms
|
||||
# Return at least 1ms to avoid zero/negative timeout edge cases
|
||||
return max(remaining_ms / 1000.0, 0.001)
|
||||
14
backend/app/services/context/budget/__init__.py
Normal file
14
backend/app/services/context/budget/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Token Budget Management Module.
|
||||
|
||||
Provides token counting and budget allocation.
|
||||
"""
|
||||
|
||||
from .allocator import BudgetAllocator, TokenBudget
|
||||
from .calculator import TokenCalculator
|
||||
|
||||
__all__ = [
|
||||
"BudgetAllocator",
|
||||
"TokenBudget",
|
||||
"TokenCalculator",
|
||||
]
|
||||
433
backend/app/services/context/budget/allocator.py
Normal file
433
backend/app/services/context/budget/allocator.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Token Budget Allocator for Context Management.
|
||||
|
||||
Manages token budget allocation across context types.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import BudgetExceededError
|
||||
from ..types import ContextType
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenBudget:
|
||||
"""
|
||||
Token budget allocation and tracking.
|
||||
|
||||
Tracks allocated tokens per context type and
|
||||
monitors usage to prevent overflows.
|
||||
"""
|
||||
|
||||
# Total budget
|
||||
total: int
|
||||
|
||||
# Allocated per type
|
||||
system: int = 0
|
||||
task: int = 0
|
||||
knowledge: int = 0
|
||||
conversation: int = 0
|
||||
tools: int = 0
|
||||
response_reserve: int = 0
|
||||
buffer: int = 0
|
||||
|
||||
# Usage tracking
|
||||
used: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Initialize usage tracking."""
|
||||
if not self.used:
|
||||
self.used = {ct.value: 0 for ct in ContextType}
|
||||
|
||||
def get_allocation(self, context_type: ContextType | str) -> int:
|
||||
"""
|
||||
Get allocated tokens for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to get allocation for
|
||||
|
||||
Returns:
|
||||
Allocated token count
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
allocation_map = {
|
||||
"system": self.system,
|
||||
"task": self.task,
|
||||
"knowledge": self.knowledge,
|
||||
"conversation": self.conversation,
|
||||
"tool": self.tools,
|
||||
}
|
||||
return allocation_map.get(context_type, 0)
|
||||
|
||||
def get_used(self, context_type: ContextType | str) -> int:
|
||||
"""
|
||||
Get used tokens for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to check
|
||||
|
||||
Returns:
|
||||
Used token count
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
return self.used.get(context_type, 0)
|
||||
|
||||
def remaining(self, context_type: ContextType | str) -> int:
|
||||
"""
|
||||
Get remaining tokens for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to check
|
||||
|
||||
Returns:
|
||||
Remaining token count
|
||||
"""
|
||||
allocated = self.get_allocation(context_type)
|
||||
used = self.get_used(context_type)
|
||||
return max(0, allocated - used)
|
||||
|
||||
def total_remaining(self) -> int:
|
||||
"""
|
||||
Get total remaining tokens across all types.
|
||||
|
||||
Returns:
|
||||
Total remaining tokens
|
||||
"""
|
||||
total_used = sum(self.used.values())
|
||||
usable = self.total - self.response_reserve - self.buffer
|
||||
return max(0, usable - total_used)
|
||||
|
||||
def total_used(self) -> int:
|
||||
"""
|
||||
Get total used tokens.
|
||||
|
||||
Returns:
|
||||
Total used tokens
|
||||
"""
|
||||
return sum(self.used.values())
|
||||
|
||||
def can_fit(self, context_type: ContextType | str, tokens: int) -> bool:
|
||||
"""
|
||||
Check if tokens fit within budget for a type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to check
|
||||
tokens: Number of tokens to fit
|
||||
|
||||
Returns:
|
||||
True if tokens fit within remaining budget
|
||||
"""
|
||||
return tokens <= self.remaining(context_type)
|
||||
|
||||
def allocate(
|
||||
self,
|
||||
context_type: ContextType | str,
|
||||
tokens: int,
|
||||
force: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Allocate (use) tokens from a context type's budget.
|
||||
|
||||
Args:
|
||||
context_type: Context type to allocate from
|
||||
tokens: Number of tokens to allocate
|
||||
force: If True, allow exceeding budget
|
||||
|
||||
Returns:
|
||||
True if allocation succeeded
|
||||
|
||||
Raises:
|
||||
BudgetExceededError: If tokens exceed budget and force=False
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
if not force and not self.can_fit(context_type, tokens):
|
||||
raise BudgetExceededError(
|
||||
message=f"Token budget exceeded for {context_type}",
|
||||
allocated=self.get_allocation(context_type),
|
||||
requested=self.get_used(context_type) + tokens,
|
||||
context_type=context_type,
|
||||
)
|
||||
|
||||
self.used[context_type] = self.used.get(context_type, 0) + tokens
|
||||
return True
|
||||
|
||||
def deallocate(
|
||||
self,
|
||||
context_type: ContextType | str,
|
||||
tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Deallocate (return) tokens to a context type's budget.
|
||||
|
||||
Args:
|
||||
context_type: Context type to return to
|
||||
tokens: Number of tokens to return
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
current = self.used.get(context_type, 0)
|
||||
self.used[context_type] = max(0, current - tokens)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all usage tracking."""
|
||||
self.used = {ct.value: 0 for ct in ContextType}
|
||||
|
||||
def utilization(self, context_type: ContextType | str | None = None) -> float:
|
||||
"""
|
||||
Get budget utilization percentage.
|
||||
|
||||
Args:
|
||||
context_type: Specific type or None for total
|
||||
|
||||
Returns:
|
||||
Utilization as a fraction (0.0 to 1.0+)
|
||||
"""
|
||||
if context_type is None:
|
||||
usable = self.total - self.response_reserve - self.buffer
|
||||
if usable <= 0:
|
||||
return 0.0
|
||||
return self.total_used() / usable
|
||||
|
||||
allocated = self.get_allocation(context_type)
|
||||
if allocated <= 0:
|
||||
return 0.0
|
||||
return self.get_used(context_type) / allocated
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert budget to dictionary."""
|
||||
return {
|
||||
"total": self.total,
|
||||
"allocations": {
|
||||
"system": self.system,
|
||||
"task": self.task,
|
||||
"knowledge": self.knowledge,
|
||||
"conversation": self.conversation,
|
||||
"tools": self.tools,
|
||||
"response_reserve": self.response_reserve,
|
||||
"buffer": self.buffer,
|
||||
},
|
||||
"used": dict(self.used),
|
||||
"remaining": {ct.value: self.remaining(ct) for ct in ContextType},
|
||||
"total_used": self.total_used(),
|
||||
"total_remaining": self.total_remaining(),
|
||||
"utilization": round(self.utilization(), 3),
|
||||
}
|
||||
|
||||
|
||||
class BudgetAllocator:
|
||||
"""
|
||||
Budget allocator for context management.
|
||||
|
||||
Creates token budgets based on configuration and
|
||||
model context window sizes.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: ContextSettings | None = None) -> None:
|
||||
"""
|
||||
Initialize budget allocator.
|
||||
|
||||
Args:
|
||||
settings: Context settings (uses default if None)
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
|
||||
def create_budget(
|
||||
self,
|
||||
total_tokens: int,
|
||||
custom_allocations: dict[str, float] | None = None,
|
||||
) -> TokenBudget:
|
||||
"""
|
||||
Create a token budget with allocations.
|
||||
|
||||
Args:
|
||||
total_tokens: Total available tokens
|
||||
custom_allocations: Optional custom allocation percentages
|
||||
|
||||
Returns:
|
||||
TokenBudget with allocations set
|
||||
"""
|
||||
# Use custom or default allocations
|
||||
if custom_allocations:
|
||||
alloc = custom_allocations
|
||||
else:
|
||||
alloc = self._settings.get_budget_allocation()
|
||||
|
||||
return TokenBudget(
|
||||
total=total_tokens,
|
||||
system=int(total_tokens * alloc.get("system", 0.05)),
|
||||
task=int(total_tokens * alloc.get("task", 0.10)),
|
||||
knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
|
||||
conversation=int(total_tokens * alloc.get("conversation", 0.20)),
|
||||
tools=int(total_tokens * alloc.get("tools", 0.05)),
|
||||
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
|
||||
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
|
||||
)
|
||||
|
||||
def adjust_budget(
|
||||
self,
|
||||
budget: TokenBudget,
|
||||
context_type: ContextType | str,
|
||||
adjustment: int,
|
||||
) -> TokenBudget:
|
||||
"""
|
||||
Adjust a specific allocation in a budget.
|
||||
|
||||
Takes tokens from buffer and adds to specified type.
|
||||
|
||||
Args:
|
||||
budget: Budget to adjust
|
||||
context_type: Type to adjust
|
||||
adjustment: Positive to increase, negative to decrease
|
||||
|
||||
Returns:
|
||||
Adjusted budget
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
# Calculate adjustment (limited by buffer for increases, by current allocation for decreases)
|
||||
if adjustment > 0:
|
||||
# Taking from buffer - limited by available buffer
|
||||
actual_adjustment = min(adjustment, budget.buffer)
|
||||
budget.buffer -= actual_adjustment
|
||||
else:
|
||||
# Returning to buffer - limited by current allocation of target type
|
||||
current_allocation = budget.get_allocation(context_type)
|
||||
# Can't return more than current allocation
|
||||
actual_adjustment = max(adjustment, -current_allocation)
|
||||
# Add returned tokens back to buffer (adjustment is negative, so subtract)
|
||||
budget.buffer -= actual_adjustment
|
||||
|
||||
# Apply to target type
|
||||
if context_type == "system":
|
||||
budget.system = max(0, budget.system + actual_adjustment)
|
||||
elif context_type == "task":
|
||||
budget.task = max(0, budget.task + actual_adjustment)
|
||||
elif context_type == "knowledge":
|
||||
budget.knowledge = max(0, budget.knowledge + actual_adjustment)
|
||||
elif context_type == "conversation":
|
||||
budget.conversation = max(0, budget.conversation + actual_adjustment)
|
||||
elif context_type == "tool":
|
||||
budget.tools = max(0, budget.tools + actual_adjustment)
|
||||
|
||||
return budget
|
||||
|
||||
def rebalance_budget(
|
||||
self,
|
||||
budget: TokenBudget,
|
||||
prioritize: list[ContextType] | None = None,
|
||||
) -> TokenBudget:
|
||||
"""
|
||||
Rebalance budget based on actual usage.
|
||||
|
||||
Moves unused allocations to prioritized types.
|
||||
|
||||
Args:
|
||||
budget: Budget to rebalance
|
||||
prioritize: Types to prioritize (in order)
|
||||
|
||||
Returns:
|
||||
Rebalanced budget
|
||||
"""
|
||||
if prioritize is None:
|
||||
prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]
|
||||
|
||||
# Calculate unused tokens per type
|
||||
unused: dict[str, int] = {}
|
||||
for ct in ContextType:
|
||||
remaining = budget.remaining(ct)
|
||||
if remaining > 0:
|
||||
unused[ct.value] = remaining
|
||||
|
||||
# Calculate total reclaimable (excluding prioritized types)
|
||||
prioritize_values = {ct.value for ct in prioritize}
|
||||
reclaimable = sum(
|
||||
tokens for ct, tokens in unused.items() if ct not in prioritize_values
|
||||
)
|
||||
|
||||
# Redistribute to prioritized types that are near capacity
|
||||
for ct in prioritize:
|
||||
utilization = budget.utilization(ct)
|
||||
|
||||
if utilization > 0.8: # Near capacity
|
||||
# Give more tokens from reclaimable pool
|
||||
bonus = min(reclaimable, budget.get_allocation(ct) // 2)
|
||||
self.adjust_budget(budget, ct, bonus)
|
||||
reclaimable -= bonus
|
||||
|
||||
if reclaimable <= 0:
|
||||
break
|
||||
|
||||
return budget
|
||||
|
||||
def get_model_context_size(self, model: str) -> int:
|
||||
"""
|
||||
Get context window size for a model.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Context window size in tokens
|
||||
"""
|
||||
# Common model context sizes
|
||||
context_sizes = {
|
||||
"claude-3-opus": 200000,
|
||||
"claude-3-sonnet": 200000,
|
||||
"claude-3-haiku": 200000,
|
||||
"claude-3-5-sonnet": 200000,
|
||||
"claude-3-5-haiku": 200000,
|
||||
"claude-opus-4": 200000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4o-mini": 128000,
|
||||
"gpt-3.5-turbo": 16385,
|
||||
"gemini-1.5-pro": 2000000,
|
||||
"gemini-1.5-flash": 1000000,
|
||||
"gemini-2.0-flash": 1000000,
|
||||
"qwen-plus": 32000,
|
||||
"qwen-turbo": 8000,
|
||||
"deepseek-chat": 64000,
|
||||
"deepseek-reasoner": 64000,
|
||||
}
|
||||
|
||||
# Check exact match first
|
||||
model_lower = model.lower()
|
||||
if model_lower in context_sizes:
|
||||
return context_sizes[model_lower]
|
||||
|
||||
# Check prefix match
|
||||
for model_name, size in context_sizes.items():
|
||||
if model_lower.startswith(model_name):
|
||||
return size
|
||||
|
||||
# Default fallback
|
||||
return 8192
|
||||
|
||||
def create_budget_for_model(
|
||||
self,
|
||||
model: str,
|
||||
custom_allocations: dict[str, float] | None = None,
|
||||
) -> TokenBudget:
|
||||
"""
|
||||
Create a budget based on model's context window.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
custom_allocations: Optional custom allocation percentages
|
||||
|
||||
Returns:
|
||||
TokenBudget sized for the model
|
||||
"""
|
||||
context_size = self.get_model_context_size(model)
|
||||
return self.create_budget(context_size, custom_allocations)
|
||||
285
backend/app/services/context/budget/calculator.py
Normal file
285
backend/app/services/context/budget/calculator.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Token Calculator for Context Management.
|
||||
|
||||
Provides token counting with caching and fallback estimation.
|
||||
Integrates with LLM Gateway for accurate counts.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenCounterProtocol(Protocol):
|
||||
"""Protocol for token counting implementations."""
|
||||
|
||||
async def count_tokens(
|
||||
self,
|
||||
text: str,
|
||||
model: str | None = None,
|
||||
) -> int:
|
||||
"""Count tokens in text."""
|
||||
...
|
||||
|
||||
|
||||
class TokenCalculator:
|
||||
"""
|
||||
Token calculator with LLM Gateway integration.
|
||||
|
||||
Features:
|
||||
- In-memory caching for repeated text
|
||||
- Fallback to character-based estimation
|
||||
- Model-specific counting when possible
|
||||
|
||||
The calculator uses the LLM Gateway's count_tokens tool
|
||||
for accurate counting, with a local cache to avoid
|
||||
repeated calls for the same content.
|
||||
"""
|
||||
|
||||
# Default characters per token ratio for estimation
|
||||
DEFAULT_CHARS_PER_TOKEN: ClassVar[float] = 4.0
|
||||
|
||||
# Model-specific ratios (more accurate estimation)
|
||||
MODEL_CHAR_RATIOS: ClassVar[dict[str, float]] = {
|
||||
"claude": 3.5,
|
||||
"gpt-4": 4.0,
|
||||
"gpt-3.5": 4.0,
|
||||
"gemini": 4.0,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
project_id: str = "system",
|
||||
agent_id: str = "context-engine",
|
||||
cache_enabled: bool = True,
|
||||
cache_max_size: int = 10000,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize token calculator.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP client manager for LLM Gateway calls
|
||||
project_id: Project ID for LLM Gateway calls
|
||||
agent_id: Agent ID for LLM Gateway calls
|
||||
cache_enabled: Whether to enable in-memory caching
|
||||
cache_max_size: Maximum cache entries
|
||||
"""
|
||||
self._mcp = mcp_manager
|
||||
self._project_id = project_id
|
||||
self._agent_id = agent_id
|
||||
self._cache_enabled = cache_enabled
|
||||
self._cache_max_size = cache_max_size
|
||||
|
||||
# In-memory cache: hash(model:text) -> token_count
|
||||
self._cache: dict[str, int] = {}
|
||||
self._cache_hits = 0
|
||||
self._cache_misses = 0
|
||||
|
||||
def _get_cache_key(self, text: str, model: str | None) -> str:
|
||||
"""Generate cache key from text and model."""
|
||||
# Use hash for efficient storage
|
||||
content = f"{model or 'default'}:{text}"
|
||||
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||
|
||||
def _check_cache(self, cache_key: str) -> int | None:
|
||||
"""Check cache for existing count."""
|
||||
if not self._cache_enabled:
|
||||
return None
|
||||
|
||||
if cache_key in self._cache:
|
||||
self._cache_hits += 1
|
||||
return self._cache[cache_key]
|
||||
|
||||
self._cache_misses += 1
|
||||
return None
|
||||
|
||||
def _store_cache(self, cache_key: str, count: int) -> None:
|
||||
"""Store count in cache."""
|
||||
if not self._cache_enabled:
|
||||
return
|
||||
|
||||
# Simple LRU-like eviction: remove oldest entries when full
|
||||
if len(self._cache) >= self._cache_max_size:
|
||||
# Remove first 10% of entries
|
||||
entries_to_remove = self._cache_max_size // 10
|
||||
keys_to_remove = list(self._cache.keys())[:entries_to_remove]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._cache[cache_key] = count
|
||||
|
||||
def estimate_tokens(self, text: str, model: str | None = None) -> int:
|
||||
"""
|
||||
Estimate token count based on character count.
|
||||
|
||||
This is a fast fallback when LLM Gateway is unavailable.
|
||||
|
||||
Args:
|
||||
text: Text to count
|
||||
model: Optional model for more accurate ratio
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
# Get model-specific ratio
|
||||
ratio = self.DEFAULT_CHARS_PER_TOKEN
|
||||
if model:
|
||||
model_lower = model.lower()
|
||||
for model_prefix, model_ratio in self.MODEL_CHAR_RATIOS.items():
|
||||
if model_prefix in model_lower:
|
||||
ratio = model_ratio
|
||||
break
|
||||
|
||||
return max(1, int(len(text) / ratio))
|
||||
|
||||
async def count_tokens(
|
||||
self,
|
||||
text: str,
|
||||
model: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Count tokens in text.
|
||||
|
||||
Uses LLM Gateway for accurate counts with fallback to estimation.
|
||||
|
||||
Args:
|
||||
text: Text to count
|
||||
model: Optional model for accurate counting
|
||||
|
||||
Returns:
|
||||
Token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
# Check cache first
|
||||
cache_key = self._get_cache_key(text, model)
|
||||
cached = self._check_cache(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
# Try LLM Gateway
|
||||
if self._mcp is not None:
|
||||
try:
|
||||
result = await self._mcp.call_tool(
|
||||
server="llm-gateway",
|
||||
tool="count_tokens",
|
||||
args={
|
||||
"project_id": self._project_id,
|
||||
"agent_id": self._agent_id,
|
||||
"text": text,
|
||||
"model": model,
|
||||
},
|
||||
)
|
||||
|
||||
# Parse result
|
||||
if result.success and result.data:
|
||||
count = self._parse_token_count(result.data)
|
||||
if count is not None:
|
||||
self._store_cache(cache_key, count)
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM Gateway token count failed, using estimation: {e}")
|
||||
|
||||
# Fallback to estimation
|
||||
count = self.estimate_tokens(text, model)
|
||||
self._store_cache(cache_key, count)
|
||||
return count
|
||||
|
||||
def _parse_token_count(self, data: Any) -> int | None:
|
||||
"""Parse token count from LLM Gateway response."""
|
||||
if isinstance(data, dict):
|
||||
if "token_count" in data:
|
||||
return int(data["token_count"])
|
||||
if "tokens" in data:
|
||||
return int(data["tokens"])
|
||||
if "count" in data:
|
||||
return int(data["count"])
|
||||
|
||||
if isinstance(data, int):
|
||||
return data
|
||||
|
||||
if isinstance(data, str):
|
||||
# Try to parse from text content
|
||||
try:
|
||||
# Handle {"token_count": 123} or just "123"
|
||||
import json
|
||||
|
||||
parsed = json.loads(data)
|
||||
if isinstance(parsed, dict) and "token_count" in parsed:
|
||||
return int(parsed["token_count"])
|
||||
if isinstance(parsed, int):
|
||||
return parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Try direct int conversion
|
||||
try:
|
||||
return int(data)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
async def count_tokens_batch(
|
||||
self,
|
||||
texts: list[str],
|
||||
model: str | None = None,
|
||||
) -> list[int]:
|
||||
"""
|
||||
Count tokens for multiple texts.
|
||||
|
||||
Efficient batch counting with caching and parallel execution.
|
||||
|
||||
Args:
|
||||
texts: List of texts to count
|
||||
model: Optional model for accurate counting
|
||||
|
||||
Returns:
|
||||
List of token counts (same order as input)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# Execute all token counts in parallel for better performance
|
||||
tasks = [self.count_tokens(text, model) for text in texts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the token count cache."""
|
||||
self._cache.clear()
|
||||
self._cache_hits = 0
|
||||
self._cache_misses = 0
|
||||
|
||||
def get_cache_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
total = self._cache_hits + self._cache_misses
|
||||
hit_rate = self._cache_hits / total if total > 0 else 0.0
|
||||
|
||||
return {
|
||||
"enabled": self._cache_enabled,
|
||||
"size": len(self._cache),
|
||||
"max_size": self._cache_max_size,
|
||||
"hits": self._cache_hits,
|
||||
"misses": self._cache_misses,
|
||||
"hit_rate": round(hit_rate, 3),
|
||||
}
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""
|
||||
Set the MCP manager (for lazy initialization).
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP client manager instance
|
||||
"""
|
||||
self._mcp = mcp_manager
|
||||
11
backend/app/services/context/cache/__init__.py
vendored
Normal file
11
backend/app/services/context/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Context Cache Module.
|
||||
|
||||
Provides Redis-based caching for assembled contexts.
|
||||
"""
|
||||
|
||||
from .context_cache import ContextCache
|
||||
|
||||
__all__ = [
|
||||
"ContextCache",
|
||||
]
|
||||
434
backend/app/services/context/cache/context_cache.py
vendored
Normal file
434
backend/app/services/context/cache/context_cache.py
vendored
Normal file
@@ -0,0 +1,434 @@
|
||||
"""
|
||||
Context Cache Implementation.
|
||||
|
||||
Provides Redis-based caching for context operations including
|
||||
assembled contexts, token counts, and scoring results.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import CacheError
|
||||
from ..types import AssembledContext, BaseContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextCache:
|
||||
"""
|
||||
Redis-based caching for context operations.
|
||||
|
||||
Provides caching for:
|
||||
- Assembled contexts (fingerprint-based)
|
||||
- Token counts (content hash-based)
|
||||
- Scoring results (context + query hash-based)
|
||||
|
||||
Cache keys use a hierarchical structure:
|
||||
- ctx:assembled:{fingerprint}
|
||||
- ctx:tokens:{model}:{content_hash}
|
||||
- ctx:score:{scorer}:{context_hash}:{query_hash}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the context cache.
|
||||
|
||||
Args:
|
||||
redis: Redis connection (optional for testing)
|
||||
settings: Cache settings
|
||||
"""
|
||||
self._redis = redis
|
||||
self._settings = settings or get_context_settings()
|
||||
self._prefix = self._settings.cache_prefix
|
||||
self._ttl = self._settings.cache_ttl_seconds
|
||||
|
||||
# In-memory fallback cache when Redis unavailable
|
||||
self._memory_cache: dict[str, tuple[str, float]] = {}
|
||||
self._max_memory_items = self._settings.cache_memory_max_items
|
||||
|
||||
def set_redis(self, redis: "Redis") -> None:
|
||||
"""Set Redis connection."""
|
||||
self._redis = redis
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if caching is enabled and available."""
|
||||
return self._settings.cache_enabled and self._redis is not None
|
||||
|
||||
def _cache_key(self, *parts: str) -> str:
|
||||
"""
|
||||
Build a cache key from parts.
|
||||
|
||||
Args:
|
||||
*parts: Key components
|
||||
|
||||
Returns:
|
||||
Colon-separated cache key
|
||||
"""
|
||||
return f"{self._prefix}:{':'.join(parts)}"
|
||||
|
||||
@staticmethod
|
||||
def _hash_content(content: str) -> str:
|
||||
"""
|
||||
Compute hash of content for cache key.
|
||||
|
||||
Args:
|
||||
content: Content to hash
|
||||
|
||||
Returns:
|
||||
32-character hex hash
|
||||
"""
|
||||
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||
|
||||
def compute_fingerprint(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
model: str,
|
||||
project_id: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Compute a fingerprint for a context assembly request.
|
||||
|
||||
The fingerprint is based on:
|
||||
- Project and agent IDs (for tenant isolation)
|
||||
- Context content hash and metadata (not full content for performance)
|
||||
- Query string
|
||||
- Target model
|
||||
|
||||
SECURITY: project_id and agent_id MUST be included to prevent
|
||||
cross-tenant cache pollution. Without these, one tenant could
|
||||
receive cached contexts from another tenant with the same query.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts
|
||||
query: Query string
|
||||
model: Model name
|
||||
project_id: Project ID for tenant isolation
|
||||
agent_id: Agent ID for tenant isolation
|
||||
|
||||
Returns:
|
||||
32-character hex fingerprint
|
||||
"""
|
||||
# Build a deterministic representation using content hashes for performance
|
||||
# This avoids JSON serializing potentially large content strings
|
||||
context_data = []
|
||||
for ctx in contexts:
|
||||
context_data.append(
|
||||
{
|
||||
"type": ctx.get_type().value,
|
||||
"content_hash": self._hash_content(
|
||||
ctx.content
|
||||
), # Hash instead of full content
|
||||
"source": ctx.source,
|
||||
"priority": ctx.priority, # Already an int
|
||||
}
|
||||
)
|
||||
|
||||
data = {
|
||||
# CRITICAL: Include tenant identifiers for cache isolation
|
||||
"project_id": project_id or "",
|
||||
"agent_id": agent_id or "",
|
||||
"contexts": context_data,
|
||||
"query": query,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
content = json.dumps(data, sort_keys=True)
|
||||
return self._hash_content(content)
|
||||
|
||||
async def get_assembled(
|
||||
self,
|
||||
fingerprint: str,
|
||||
) -> AssembledContext | None:
|
||||
"""
|
||||
Get cached assembled context by fingerprint.
|
||||
|
||||
Args:
|
||||
fingerprint: Assembly fingerprint
|
||||
|
||||
Returns:
|
||||
Cached AssembledContext or None if not found
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return None
|
||||
|
||||
key = self._cache_key("assembled", fingerprint)
|
||||
|
||||
try:
|
||||
data = await self._redis.get(key) # type: ignore
|
||||
if data:
|
||||
logger.debug(f"Cache hit for assembled context: {fingerprint}")
|
||||
result = AssembledContext.from_json(data)
|
||||
result.cache_hit = True
|
||||
result.cache_key = fingerprint
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache get error: {e}")
|
||||
raise CacheError(f"Failed to get assembled context: {e}") from e
|
||||
|
||||
return None
|
||||
|
||||
async def set_assembled(
|
||||
self,
|
||||
fingerprint: str,
|
||||
context: AssembledContext,
|
||||
ttl: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache an assembled context.
|
||||
|
||||
Args:
|
||||
fingerprint: Assembly fingerprint
|
||||
context: Assembled context to cache
|
||||
ttl: Optional TTL override in seconds
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return
|
||||
|
||||
key = self._cache_key("assembled", fingerprint)
|
||||
expire = ttl or self._ttl
|
||||
|
||||
try:
|
||||
await self._redis.setex(key, expire, context.to_json()) # type: ignore
|
||||
logger.debug(f"Cached assembled context: {fingerprint}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache set error: {e}")
|
||||
raise CacheError(f"Failed to cache assembled context: {e}") from e
|
||||
|
||||
async def get_token_count(
|
||||
self,
|
||||
content: str,
|
||||
model: str | None = None,
|
||||
) -> int | None:
|
||||
"""
|
||||
Get cached token count.
|
||||
|
||||
Args:
|
||||
content: Content to look up
|
||||
model: Model name for model-specific tokenization
|
||||
|
||||
Returns:
|
||||
Cached token count or None if not found
|
||||
"""
|
||||
model_key = model or "default"
|
||||
content_hash = self._hash_content(content)
|
||||
key = self._cache_key("tokens", model_key, content_hash)
|
||||
|
||||
# Try in-memory first
|
||||
if key in self._memory_cache:
|
||||
return int(self._memory_cache[key][0])
|
||||
|
||||
if not self.is_enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await self._redis.get(key) # type: ignore
|
||||
if data:
|
||||
count = int(data)
|
||||
# Store in memory for faster subsequent access
|
||||
self._set_memory(key, str(count))
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache get error for tokens: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def set_token_count(
|
||||
self,
|
||||
content: str,
|
||||
count: int,
|
||||
model: str | None = None,
|
||||
ttl: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache a token count.
|
||||
|
||||
Args:
|
||||
content: Content that was counted
|
||||
count: Token count
|
||||
model: Model name
|
||||
ttl: Optional TTL override in seconds
|
||||
"""
|
||||
model_key = model or "default"
|
||||
content_hash = self._hash_content(content)
|
||||
key = self._cache_key("tokens", model_key, content_hash)
|
||||
expire = ttl or self._ttl
|
||||
|
||||
# Always store in memory
|
||||
self._set_memory(key, str(count))
|
||||
|
||||
if not self.is_enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._redis.setex(key, expire, str(count)) # type: ignore
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache set error for tokens: {e}")
|
||||
|
||||
async def get_score(
|
||||
self,
|
||||
scorer_name: str,
|
||||
context_id: str,
|
||||
query: str,
|
||||
) -> float | None:
|
||||
"""
|
||||
Get cached score.
|
||||
|
||||
Args:
|
||||
scorer_name: Name of the scorer
|
||||
context_id: Context identifier
|
||||
query: Query string
|
||||
|
||||
Returns:
|
||||
Cached score or None if not found
|
||||
"""
|
||||
query_hash = self._hash_content(query)[:16]
|
||||
key = self._cache_key("score", scorer_name, context_id, query_hash)
|
||||
|
||||
# Try in-memory first
|
||||
if key in self._memory_cache:
|
||||
return float(self._memory_cache[key][0])
|
||||
|
||||
if not self.is_enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await self._redis.get(key) # type: ignore
|
||||
if data:
|
||||
score = float(data)
|
||||
self._set_memory(key, str(score))
|
||||
return score
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache get error for score: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def set_score(
|
||||
self,
|
||||
scorer_name: str,
|
||||
context_id: str,
|
||||
query: str,
|
||||
score: float,
|
||||
ttl: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache a score.
|
||||
|
||||
Args:
|
||||
scorer_name: Name of the scorer
|
||||
context_id: Context identifier
|
||||
query: Query string
|
||||
score: Score value
|
||||
ttl: Optional TTL override in seconds
|
||||
"""
|
||||
query_hash = self._hash_content(query)[:16]
|
||||
key = self._cache_key("score", scorer_name, context_id, query_hash)
|
||||
expire = ttl or self._ttl
|
||||
|
||||
# Always store in memory
|
||||
self._set_memory(key, str(score))
|
||||
|
||||
if not self.is_enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._redis.setex(key, expire, str(score)) # type: ignore
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache set error for score: {e}")
|
||||
|
||||
async def invalidate(self, pattern: str) -> int:
|
||||
"""
|
||||
Invalidate cache entries matching a pattern.
|
||||
|
||||
Args:
|
||||
pattern: Key pattern (supports * wildcard)
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return 0
|
||||
|
||||
full_pattern = self._cache_key(pattern)
|
||||
deleted = 0
|
||||
|
||||
try:
|
||||
async for key in self._redis.scan_iter(match=full_pattern): # type: ignore
|
||||
await self._redis.delete(key) # type: ignore
|
||||
deleted += 1
|
||||
|
||||
logger.info(f"Invalidated {deleted} cache entries matching {pattern}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache invalidation error: {e}")
|
||||
raise CacheError(f"Failed to invalidate cache: {e}") from e
|
||||
|
||||
return deleted
|
||||
|
||||
async def clear_all(self) -> int:
|
||||
"""
|
||||
Clear all context cache entries.
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
"""
|
||||
self._memory_cache.clear()
|
||||
return await self.invalidate("*")
|
||||
|
||||
def _set_memory(self, key: str, value: str) -> None:
|
||||
"""
|
||||
Set a value in the memory cache.
|
||||
|
||||
Uses LRU-style eviction when max items reached.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to store
|
||||
"""
|
||||
import time
|
||||
|
||||
if len(self._memory_cache) >= self._max_memory_items:
|
||||
# Evict oldest entries
|
||||
sorted_keys = sorted(
|
||||
self._memory_cache.keys(),
|
||||
key=lambda k: self._memory_cache[k][1],
|
||||
)
|
||||
for k in sorted_keys[: len(sorted_keys) // 2]:
|
||||
del self._memory_cache[k]
|
||||
|
||||
self._memory_cache[key] = (value, time.time())
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with cache stats
|
||||
"""
|
||||
stats = {
|
||||
"enabled": self._settings.cache_enabled,
|
||||
"redis_available": self._redis is not None,
|
||||
"memory_items": len(self._memory_cache),
|
||||
"ttl_seconds": self._ttl,
|
||||
}
|
||||
|
||||
if self.is_enabled:
|
||||
try:
|
||||
# Get Redis info
|
||||
info = await self._redis.info("memory") # type: ignore
|
||||
stats["redis_memory_used"] = info.get("used_memory_human", "unknown")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get Redis stats: {e}")
|
||||
|
||||
return stats
|
||||
13
backend/app/services/context/compression/__init__.py
Normal file
13
backend/app/services/context/compression/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Context Compression Module.
|
||||
|
||||
Provides truncation and compression strategies.
|
||||
"""
|
||||
|
||||
from .truncation import ContextCompressor, TruncationResult, TruncationStrategy
|
||||
|
||||
__all__ = [
|
||||
"ContextCompressor",
|
||||
"TruncationResult",
|
||||
"TruncationStrategy",
|
||||
]
|
||||
453
backend/app/services/context/compression/truncation.py
Normal file
453
backend/app/services/context/compression/truncation.py
Normal file
@@ -0,0 +1,453 @@
|
||||
"""
|
||||
Smart Truncation for Context Compression.
|
||||
|
||||
Provides intelligent truncation strategies to reduce context size
|
||||
while preserving the most important information.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext, ContextType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..budget import TokenBudget, TokenCalculator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _estimate_tokens(text: str, model: str | None = None) -> int:
|
||||
"""
|
||||
Estimate token count using model-specific character ratios.
|
||||
|
||||
Module-level function for reuse across classes. Uses the same ratios
|
||||
as TokenCalculator for consistency.
|
||||
|
||||
Args:
|
||||
text: Text to estimate tokens for
|
||||
model: Optional model name for model-specific ratios
|
||||
|
||||
Returns:
|
||||
Estimated token count (minimum 1)
|
||||
"""
|
||||
# Model-specific character ratios (chars per token)
|
||||
model_ratios = {
|
||||
"claude": 3.5,
|
||||
"gpt-4": 4.0,
|
||||
"gpt-3.5": 4.0,
|
||||
"gemini": 4.0,
|
||||
}
|
||||
default_ratio = 4.0
|
||||
|
||||
ratio = default_ratio
|
||||
if model:
|
||||
model_lower = model.lower()
|
||||
for model_prefix, model_ratio in model_ratios.items():
|
||||
if model_prefix in model_lower:
|
||||
ratio = model_ratio
|
||||
break
|
||||
|
||||
return max(1, int(len(text) / ratio))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TruncationResult:
|
||||
"""Result of truncation operation."""
|
||||
|
||||
original_tokens: int
|
||||
truncated_tokens: int
|
||||
content: str
|
||||
truncated: bool
|
||||
truncation_ratio: float # 0.0 = no truncation, 1.0 = completely removed
|
||||
|
||||
@property
|
||||
def tokens_saved(self) -> int:
|
||||
"""Calculate tokens saved by truncation."""
|
||||
return self.original_tokens - self.truncated_tokens
|
||||
|
||||
|
||||
class TruncationStrategy:
|
||||
"""
|
||||
Smart truncation strategies for context compression.
|
||||
|
||||
Strategies:
|
||||
1. End truncation: Cut from end (for knowledge/docs)
|
||||
2. Middle truncation: Keep start and end (for code)
|
||||
3. Sentence-aware: Truncate at sentence boundaries
|
||||
4. Semantic chunking: Keep most relevant chunks
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
calculator: "TokenCalculator | None" = None,
|
||||
preserve_ratio_start: float | None = None,
|
||||
min_content_length: int | None = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize truncation strategy.
|
||||
|
||||
Args:
|
||||
calculator: Token calculator for accurate counting
|
||||
preserve_ratio_start: Ratio of content to keep from start (overrides settings)
|
||||
min_content_length: Minimum characters to preserve (overrides settings)
|
||||
settings: Context settings (uses global if None)
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
self._calculator = calculator
|
||||
|
||||
# Use provided values or fall back to settings
|
||||
self._preserve_ratio_start = (
|
||||
preserve_ratio_start
|
||||
if preserve_ratio_start is not None
|
||||
else self._settings.truncation_preserve_ratio
|
||||
)
|
||||
self._min_content_length = (
|
||||
min_content_length
|
||||
if min_content_length is not None
|
||||
else self._settings.truncation_min_content_length
|
||||
)
|
||||
|
||||
@property
|
||||
def truncation_marker(self) -> str:
|
||||
"""Get truncation marker from settings."""
|
||||
return self._settings.truncation_marker
|
||||
|
||||
def set_calculator(self, calculator: "TokenCalculator") -> None:
|
||||
"""Set token calculator."""
|
||||
self._calculator = calculator
|
||||
|
||||
async def truncate_to_tokens(
|
||||
self,
|
||||
content: str,
|
||||
max_tokens: int,
|
||||
strategy: str = "end",
|
||||
model: str | None = None,
|
||||
) -> TruncationResult:
|
||||
"""
|
||||
Truncate content to fit within token limit.
|
||||
|
||||
Args:
|
||||
content: Content to truncate
|
||||
max_tokens: Maximum tokens allowed
|
||||
strategy: Truncation strategy ('end', 'middle', 'sentence')
|
||||
model: Model for token counting
|
||||
|
||||
Returns:
|
||||
TruncationResult with truncated content
|
||||
"""
|
||||
if not content:
|
||||
return TruncationResult(
|
||||
original_tokens=0,
|
||||
truncated_tokens=0,
|
||||
content="",
|
||||
truncated=False,
|
||||
truncation_ratio=0.0,
|
||||
)
|
||||
|
||||
# Get original token count
|
||||
original_tokens = await self._count_tokens(content, model)
|
||||
|
||||
if original_tokens <= max_tokens:
|
||||
return TruncationResult(
|
||||
original_tokens=original_tokens,
|
||||
truncated_tokens=original_tokens,
|
||||
content=content,
|
||||
truncated=False,
|
||||
truncation_ratio=0.0,
|
||||
)
|
||||
|
||||
# Apply truncation strategy
|
||||
if strategy == "middle":
|
||||
truncated = await self._truncate_middle(content, max_tokens, model)
|
||||
elif strategy == "sentence":
|
||||
truncated = await self._truncate_sentence(content, max_tokens, model)
|
||||
else: # "end"
|
||||
truncated = await self._truncate_end(content, max_tokens, model)
|
||||
|
||||
truncated_tokens = await self._count_tokens(truncated, model)
|
||||
|
||||
return TruncationResult(
|
||||
original_tokens=original_tokens,
|
||||
truncated_tokens=truncated_tokens,
|
||||
content=truncated,
|
||||
truncated=True,
|
||||
truncation_ratio=0.0
|
||||
if original_tokens == 0
|
||||
else 1 - (truncated_tokens / original_tokens),
|
||||
)
|
||||
|
||||
async def _truncate_end(
|
||||
self,
|
||||
content: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Truncate from end of content.
|
||||
|
||||
Simple but effective for most content types.
|
||||
"""
|
||||
# Binary search for optimal truncation point
|
||||
marker_tokens = await self._count_tokens(self.truncation_marker, model)
|
||||
available_tokens = max(0, max_tokens - marker_tokens)
|
||||
|
||||
# Edge case: if no tokens available for content, return just the marker
|
||||
if available_tokens <= 0:
|
||||
return self.truncation_marker
|
||||
|
||||
# Estimate characters per token (guard against division by zero)
|
||||
content_tokens = await self._count_tokens(content, model)
|
||||
if content_tokens == 0:
|
||||
return content + self.truncation_marker
|
||||
chars_per_token = len(content) / content_tokens
|
||||
|
||||
# Start with estimated position
|
||||
estimated_chars = int(available_tokens * chars_per_token)
|
||||
truncated = content[:estimated_chars]
|
||||
|
||||
# Refine with binary search
|
||||
low, high = len(truncated) // 2, len(truncated)
|
||||
best = truncated
|
||||
|
||||
for _ in range(5): # Max 5 iterations
|
||||
mid = (low + high) // 2
|
||||
candidate = content[:mid]
|
||||
tokens = await self._count_tokens(candidate, model)
|
||||
|
||||
if tokens <= available_tokens:
|
||||
best = candidate
|
||||
low = mid + 1
|
||||
else:
|
||||
high = mid - 1
|
||||
|
||||
return best + self.truncation_marker
|
||||
|
||||
async def _truncate_middle(
|
||||
self,
|
||||
content: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Truncate from middle, keeping start and end.
|
||||
|
||||
Good for code or content where context at boundaries matters.
|
||||
"""
|
||||
marker_tokens = await self._count_tokens(self.truncation_marker, model)
|
||||
available_tokens = max_tokens - marker_tokens
|
||||
|
||||
# Split between start and end
|
||||
start_tokens = int(available_tokens * self._preserve_ratio_start)
|
||||
end_tokens = available_tokens - start_tokens
|
||||
|
||||
# Get start portion
|
||||
start_content = await self._get_content_for_tokens(
|
||||
content, start_tokens, from_start=True, model=model
|
||||
)
|
||||
|
||||
# Get end portion
|
||||
end_content = await self._get_content_for_tokens(
|
||||
content, end_tokens, from_start=False, model=model
|
||||
)
|
||||
|
||||
return start_content + self.truncation_marker + end_content
|
||||
|
||||
async def _truncate_sentence(
|
||||
self,
|
||||
content: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Truncate at sentence boundaries.
|
||||
|
||||
Produces cleaner output by not cutting mid-sentence.
|
||||
"""
|
||||
# Split into sentences
|
||||
sentences = re.split(r"(?<=[.!?])\s+", content)
|
||||
|
||||
result: list[str] = []
|
||||
total_tokens = 0
|
||||
marker_tokens = await self._count_tokens(self.truncation_marker, model)
|
||||
available = max_tokens - marker_tokens
|
||||
|
||||
for sentence in sentences:
|
||||
sentence_tokens = await self._count_tokens(sentence, model)
|
||||
if total_tokens + sentence_tokens <= available:
|
||||
result.append(sentence)
|
||||
total_tokens += sentence_tokens
|
||||
else:
|
||||
break
|
||||
|
||||
if len(result) < len(sentences):
|
||||
return " ".join(result) + self.truncation_marker
|
||||
return " ".join(result)
|
||||
|
||||
async def _get_content_for_tokens(
|
||||
self,
|
||||
content: str,
|
||||
target_tokens: int,
|
||||
from_start: bool = True,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""Get portion of content fitting within token limit."""
|
||||
if target_tokens <= 0:
|
||||
return ""
|
||||
|
||||
current_tokens = await self._count_tokens(content, model)
|
||||
if current_tokens <= target_tokens:
|
||||
return content
|
||||
|
||||
# Estimate characters (guard against division by zero)
|
||||
if current_tokens == 0:
|
||||
return content
|
||||
chars_per_token = len(content) / current_tokens
|
||||
estimated_chars = int(target_tokens * chars_per_token)
|
||||
|
||||
if from_start:
|
||||
return content[:estimated_chars]
|
||||
else:
|
||||
return content[-estimated_chars:]
|
||||
|
||||
async def _count_tokens(self, text: str, model: str | None = None) -> int:
|
||||
"""Count tokens using calculator or estimation."""
|
||||
if self._calculator is not None:
|
||||
return await self._calculator.count_tokens(text, model)
|
||||
|
||||
# Fallback estimation with model-specific ratios
|
||||
return _estimate_tokens(text, model)
|
||||
|
||||
|
||||
class ContextCompressor:
|
||||
"""
|
||||
Compresses contexts to fit within budget constraints.
|
||||
|
||||
Uses truncation strategies to reduce context size while
|
||||
preserving the most important information.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
truncation: TruncationStrategy | None = None,
|
||||
calculator: "TokenCalculator | None" = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize context compressor.
|
||||
|
||||
Args:
|
||||
truncation: Truncation strategy to use
|
||||
calculator: Token calculator for counting
|
||||
"""
|
||||
self._truncation = truncation or TruncationStrategy(calculator)
|
||||
self._calculator = calculator
|
||||
|
||||
if calculator:
|
||||
self._truncation.set_calculator(calculator)
|
||||
|
||||
def set_calculator(self, calculator: "TokenCalculator") -> None:
|
||||
"""Set token calculator."""
|
||||
self._calculator = calculator
|
||||
self._truncation.set_calculator(calculator)
|
||||
|
||||
async def compress_context(
|
||||
self,
|
||||
context: BaseContext,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
) -> BaseContext:
|
||||
"""
|
||||
Compress a single context to fit token limit.
|
||||
|
||||
Args:
|
||||
context: Context to compress
|
||||
max_tokens: Maximum tokens allowed
|
||||
model: Model for token counting
|
||||
|
||||
Returns:
|
||||
Compressed context (may be same object if no compression needed)
|
||||
"""
|
||||
current_tokens = context.token_count or await self._count_tokens(
|
||||
context.content, model
|
||||
)
|
||||
|
||||
if current_tokens <= max_tokens:
|
||||
return context
|
||||
|
||||
# Choose strategy based on context type
|
||||
strategy = self._get_strategy_for_type(context.get_type())
|
||||
|
||||
result = await self._truncation.truncate_to_tokens(
|
||||
content=context.content,
|
||||
max_tokens=max_tokens,
|
||||
strategy=strategy,
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Update context with truncated content
|
||||
context.content = result.content
|
||||
context.token_count = result.truncated_tokens
|
||||
context.metadata["truncated"] = True
|
||||
context.metadata["original_tokens"] = result.original_tokens
|
||||
|
||||
return context
|
||||
|
||||
async def compress_contexts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
budget: "TokenBudget",
|
||||
model: str | None = None,
|
||||
) -> list[BaseContext]:
|
||||
"""
|
||||
Compress multiple contexts to fit within budget.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to potentially compress
|
||||
budget: Token budget constraints
|
||||
model: Model for token counting
|
||||
|
||||
Returns:
|
||||
List of contexts (compressed as needed)
|
||||
"""
|
||||
result: list[BaseContext] = []
|
||||
|
||||
for context in contexts:
|
||||
context_type = context.get_type()
|
||||
remaining = budget.remaining(context_type)
|
||||
current_tokens = context.token_count or await self._count_tokens(
|
||||
context.content, model
|
||||
)
|
||||
|
||||
if current_tokens > remaining:
|
||||
# Need to compress
|
||||
compressed = await self.compress_context(context, remaining, model)
|
||||
result.append(compressed)
|
||||
logger.debug(
|
||||
f"Compressed {context_type.value} context from "
|
||||
f"{current_tokens} to {compressed.token_count} tokens"
|
||||
)
|
||||
else:
|
||||
result.append(context)
|
||||
|
||||
return result
|
||||
|
||||
def _get_strategy_for_type(self, context_type: ContextType) -> str:
|
||||
"""Get optimal truncation strategy for context type."""
|
||||
strategies = {
|
||||
ContextType.SYSTEM: "end", # Keep instructions at start
|
||||
ContextType.TASK: "end", # Keep task description start
|
||||
ContextType.KNOWLEDGE: "sentence", # Clean sentence boundaries
|
||||
ContextType.CONVERSATION: "end", # Keep recent conversation
|
||||
ContextType.TOOL: "middle", # Keep command and result summary
|
||||
}
|
||||
return strategies.get(context_type, "end")
|
||||
|
||||
async def _count_tokens(self, text: str, model: str | None = None) -> int:
|
||||
"""Count tokens using calculator or estimation."""
|
||||
if self._calculator is not None:
|
||||
return await self._calculator.count_tokens(text, model)
|
||||
# Use model-specific estimation for consistency
|
||||
return _estimate_tokens(text, model)
|
||||
380
backend/app/services/context/config.py
Normal file
380
backend/app/services/context/config.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
Context Management Engine Configuration.
|
||||
|
||||
Provides Pydantic settings for context assembly,
|
||||
token budget allocation, and caching.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ContextSettings(BaseSettings):
|
||||
"""
|
||||
Configuration for the Context Management Engine.
|
||||
|
||||
All settings can be overridden via environment variables
|
||||
with the CTX_ prefix.
|
||||
"""
|
||||
|
||||
# Budget allocation percentages (must sum to 1.0)
|
||||
budget_system: float = Field(
|
||||
default=0.05,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage of budget for system prompts (5%)",
|
||||
)
|
||||
budget_task: float = Field(
|
||||
default=0.10,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage of budget for task context (10%)",
|
||||
)
|
||||
budget_knowledge: float = Field(
|
||||
default=0.40,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage of budget for RAG/knowledge (40%)",
|
||||
)
|
||||
budget_conversation: float = Field(
|
||||
default=0.20,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage of budget for conversation history (20%)",
|
||||
)
|
||||
budget_tools: float = Field(
|
||||
default=0.05,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage of budget for tool descriptions (5%)",
|
||||
)
|
||||
budget_response: float = Field(
|
||||
default=0.15,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage reserved for response (15%)",
|
||||
)
|
||||
budget_buffer: float = Field(
|
||||
default=0.05,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage buffer for safety margin (5%)",
|
||||
)
|
||||
|
||||
# Scoring weights
|
||||
scoring_relevance_weight: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Weight for relevance scoring",
|
||||
)
|
||||
scoring_recency_weight: float = Field(
|
||||
default=0.3,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Weight for recency scoring",
|
||||
)
|
||||
scoring_priority_weight: float = Field(
|
||||
default=0.2,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Weight for priority scoring",
|
||||
)
|
||||
|
||||
# Recency decay settings
|
||||
recency_decay_hours: float = Field(
|
||||
default=24.0,
|
||||
gt=0.0,
|
||||
description="Hours until recency score decays to 50%",
|
||||
)
|
||||
recency_max_age_hours: float = Field(
|
||||
default=168.0,
|
||||
gt=0.0,
|
||||
description="Hours until context is considered stale (7 days)",
|
||||
)
|
||||
|
||||
# Compression settings
|
||||
compression_threshold: float = Field(
|
||||
default=0.8,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Compress when budget usage exceeds this percentage",
|
||||
)
|
||||
truncation_marker: str = Field(
|
||||
default="\n\n[...content truncated...]\n\n",
|
||||
description="Marker text to insert where content was truncated",
|
||||
)
|
||||
truncation_preserve_ratio: float = Field(
|
||||
default=0.7,
|
||||
ge=0.1,
|
||||
le=0.9,
|
||||
description="Ratio of content to preserve from start in middle truncation (0.7 = 70% start, 30% end)",
|
||||
)
|
||||
truncation_min_content_length: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=1000,
|
||||
description="Minimum content length in characters before truncation applies",
|
||||
)
|
||||
summary_model_group: str = Field(
|
||||
default="fast",
|
||||
description="Model group to use for summarization",
|
||||
)
|
||||
|
||||
# Caching settings
|
||||
cache_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable Redis caching for assembled contexts",
|
||||
)
|
||||
cache_ttl_seconds: int = Field(
|
||||
default=3600,
|
||||
ge=60,
|
||||
le=86400,
|
||||
description="Cache TTL in seconds (1 hour default, max 24 hours)",
|
||||
)
|
||||
cache_prefix: str = Field(
|
||||
default="ctx",
|
||||
description="Redis key prefix for context cache",
|
||||
)
|
||||
cache_memory_max_items: int = Field(
|
||||
default=1000,
|
||||
ge=100,
|
||||
le=100000,
|
||||
description="Maximum items in memory fallback cache when Redis unavailable",
|
||||
)
|
||||
|
||||
# Performance settings
|
||||
max_assembly_time_ms: int = Field(
|
||||
default=2000,
|
||||
ge=10,
|
||||
le=30000,
|
||||
description="Maximum time for context assembly in milliseconds. "
|
||||
"Should be high enough to accommodate MCP calls for knowledge retrieval.",
|
||||
)
|
||||
parallel_scoring: bool = Field(
|
||||
default=True,
|
||||
description="Score contexts in parallel for better performance",
|
||||
)
|
||||
max_parallel_scores: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=50,
|
||||
description="Maximum number of contexts to score in parallel",
|
||||
)
|
||||
|
||||
# Knowledge retrieval settings
|
||||
knowledge_search_type: str = Field(
|
||||
default="hybrid",
|
||||
description="Default search type for knowledge retrieval",
|
||||
)
|
||||
knowledge_max_results: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=50,
|
||||
description="Maximum knowledge chunks to retrieve",
|
||||
)
|
||||
knowledge_min_score: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum relevance score for knowledge",
|
||||
)
|
||||
|
||||
# Relevance scoring settings
|
||||
relevance_keyword_fallback_weight: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Maximum score for keyword-based fallback scoring (when semantic unavailable)",
|
||||
)
|
||||
relevance_semantic_max_chars: int = Field(
|
||||
default=2000,
|
||||
ge=100,
|
||||
le=10000,
|
||||
description="Maximum content length in chars for semantic similarity computation",
|
||||
)
|
||||
|
||||
# Diversity/ranking settings
|
||||
diversity_max_per_source: int = Field(
|
||||
default=3,
|
||||
ge=1,
|
||||
le=20,
|
||||
description="Maximum contexts from the same source in diversity reranking",
|
||||
)
|
||||
|
||||
# Conversation history settings
|
||||
conversation_max_turns: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Maximum conversation turns to include",
|
||||
)
|
||||
conversation_recent_priority: bool = Field(
|
||||
default=True,
|
||||
description="Prioritize recent conversation turns",
|
||||
)
|
||||
|
||||
@field_validator("knowledge_search_type")
|
||||
@classmethod
|
||||
def validate_search_type(cls, v: str) -> str:
|
||||
"""Validate search type is valid."""
|
||||
valid_types = {"semantic", "keyword", "hybrid"}
|
||||
if v not in valid_types:
|
||||
raise ValueError(f"search_type must be one of: {valid_types}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_budget_allocation(self) -> "ContextSettings":
|
||||
"""Validate that budget percentages sum to 1.0."""
|
||||
total = (
|
||||
self.budget_system
|
||||
+ self.budget_task
|
||||
+ self.budget_knowledge
|
||||
+ self.budget_conversation
|
||||
+ self.budget_tools
|
||||
+ self.budget_response
|
||||
+ self.budget_buffer
|
||||
)
|
||||
# Allow small floating point error
|
||||
if abs(total - 1.0) > 0.001:
|
||||
raise ValueError(
|
||||
f"Budget percentages must sum to 1.0, got {total:.3f}. "
|
||||
f"Current allocation: system={self.budget_system}, task={self.budget_task}, "
|
||||
f"knowledge={self.budget_knowledge}, conversation={self.budget_conversation}, "
|
||||
f"tools={self.budget_tools}, response={self.budget_response}, buffer={self.budget_buffer}"
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_scoring_weights(self) -> "ContextSettings":
|
||||
"""Validate that scoring weights sum to 1.0."""
|
||||
total = (
|
||||
self.scoring_relevance_weight
|
||||
+ self.scoring_recency_weight
|
||||
+ self.scoring_priority_weight
|
||||
)
|
||||
# Allow small floating point error
|
||||
if abs(total - 1.0) > 0.001:
|
||||
raise ValueError(
|
||||
f"Scoring weights must sum to 1.0, got {total:.3f}. "
|
||||
f"Current weights: relevance={self.scoring_relevance_weight}, "
|
||||
f"recency={self.scoring_recency_weight}, priority={self.scoring_priority_weight}"
|
||||
)
|
||||
return self
|
||||
|
||||
def get_budget_allocation(self) -> dict[str, float]:
|
||||
"""Get budget allocation as a dictionary."""
|
||||
return {
|
||||
"system": self.budget_system,
|
||||
"task": self.budget_task,
|
||||
"knowledge": self.budget_knowledge,
|
||||
"conversation": self.budget_conversation,
|
||||
"tools": self.budget_tools,
|
||||
"response": self.budget_response,
|
||||
"buffer": self.budget_buffer,
|
||||
}
|
||||
|
||||
def get_scoring_weights(self) -> dict[str, float]:
|
||||
"""Get scoring weights as a dictionary."""
|
||||
return {
|
||||
"relevance": self.scoring_relevance_weight,
|
||||
"recency": self.scoring_recency_weight,
|
||||
"priority": self.scoring_priority_weight,
|
||||
}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert settings to dictionary for logging/debugging."""
|
||||
return {
|
||||
"budget": self.get_budget_allocation(),
|
||||
"scoring": self.get_scoring_weights(),
|
||||
"compression": {
|
||||
"threshold": self.compression_threshold,
|
||||
"summary_model_group": self.summary_model_group,
|
||||
"truncation_marker": self.truncation_marker,
|
||||
"truncation_preserve_ratio": self.truncation_preserve_ratio,
|
||||
"truncation_min_content_length": self.truncation_min_content_length,
|
||||
},
|
||||
"cache": {
|
||||
"enabled": self.cache_enabled,
|
||||
"ttl_seconds": self.cache_ttl_seconds,
|
||||
"prefix": self.cache_prefix,
|
||||
"memory_max_items": self.cache_memory_max_items,
|
||||
},
|
||||
"performance": {
|
||||
"max_assembly_time_ms": self.max_assembly_time_ms,
|
||||
"parallel_scoring": self.parallel_scoring,
|
||||
"max_parallel_scores": self.max_parallel_scores,
|
||||
},
|
||||
"knowledge": {
|
||||
"search_type": self.knowledge_search_type,
|
||||
"max_results": self.knowledge_max_results,
|
||||
"min_score": self.knowledge_min_score,
|
||||
},
|
||||
"relevance": {
|
||||
"keyword_fallback_weight": self.relevance_keyword_fallback_weight,
|
||||
"semantic_max_chars": self.relevance_semantic_max_chars,
|
||||
},
|
||||
"diversity": {
|
||||
"max_per_source": self.diversity_max_per_source,
|
||||
},
|
||||
"conversation": {
|
||||
"max_turns": self.conversation_max_turns,
|
||||
"recent_priority": self.conversation_recent_priority,
|
||||
},
|
||||
}
|
||||
|
||||
model_config = {
|
||||
"env_prefix": "CTX_",
|
||||
"env_file": "../.env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": False,
|
||||
"extra": "ignore",
|
||||
}
|
||||
|
||||
|
||||
# Thread-safe singleton pattern
|
||||
_settings: ContextSettings | None = None
|
||||
_settings_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_context_settings() -> ContextSettings:
|
||||
"""
|
||||
Get the global ContextSettings instance.
|
||||
|
||||
Thread-safe with double-checked locking pattern.
|
||||
|
||||
Returns:
|
||||
ContextSettings instance
|
||||
"""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
with _settings_lock:
|
||||
if _settings is None:
|
||||
_settings = ContextSettings()
|
||||
return _settings
|
||||
|
||||
|
||||
def reset_context_settings() -> None:
|
||||
"""
|
||||
Reset the global settings instance.
|
||||
|
||||
Primarily used for testing.
|
||||
"""
|
||||
global _settings
|
||||
with _settings_lock:
|
||||
_settings = None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_default_settings() -> ContextSettings:
|
||||
"""
|
||||
Get default settings (cached).
|
||||
|
||||
Use this for read-only access to defaults.
|
||||
For mutable access, use get_context_settings().
|
||||
"""
|
||||
return ContextSettings()
|
||||
485
backend/app/services/context/engine.py
Normal file
485
backend/app/services/context/engine.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""
|
||||
Context Management Engine.
|
||||
|
||||
Main orchestration layer for context assembly and optimization.
|
||||
Provides a high-level API for assembling optimized context for LLM requests.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .assembly import ContextPipeline
|
||||
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
from .cache import ContextCache
|
||||
from .compression import ContextCompressor
|
||||
from .config import ContextSettings, get_context_settings
|
||||
from .prioritization import ContextRanker
|
||||
from .scoring import CompositeScorer
|
||||
from .types import (
|
||||
AssembledContext,
|
||||
BaseContext,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
ToolContext,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextEngine:
|
||||
"""
|
||||
Main context management engine.
|
||||
|
||||
Provides high-level API for context assembly and optimization.
|
||||
Integrates all components: scoring, ranking, compression, formatting, and caching.
|
||||
|
||||
Usage:
|
||||
engine = ContextEngine(mcp_manager=mcp, redis=redis)
|
||||
|
||||
# Assemble context for an LLM request
|
||||
result = await engine.assemble_context(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
query="implement user authentication",
|
||||
model="claude-3-sonnet",
|
||||
system_prompt="You are an expert developer.",
|
||||
knowledge_query="authentication best practices",
|
||||
)
|
||||
|
||||
# Use the assembled context
|
||||
print(result.content)
|
||||
print(f"Tokens: {result.total_tokens}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the context engine.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
|
||||
redis: Redis connection for caching
|
||||
settings: Context settings
|
||||
"""
|
||||
self._mcp = mcp_manager
|
||||
self._settings = settings or get_context_settings()
|
||||
|
||||
# Initialize components
|
||||
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
|
||||
self._scorer = CompositeScorer(mcp_manager=mcp_manager, settings=self._settings)
|
||||
self._ranker = ContextRanker(scorer=self._scorer, calculator=self._calculator)
|
||||
self._compressor = ContextCompressor(calculator=self._calculator)
|
||||
self._allocator = BudgetAllocator(self._settings)
|
||||
self._cache = ContextCache(redis=redis, settings=self._settings)
|
||||
|
||||
# Pipeline for assembly
|
||||
self._pipeline = ContextPipeline(
|
||||
mcp_manager=mcp_manager,
|
||||
settings=self._settings,
|
||||
calculator=self._calculator,
|
||||
scorer=self._scorer,
|
||||
ranker=self._ranker,
|
||||
compressor=self._compressor,
|
||||
)
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""
|
||||
Set MCP manager for all components.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP client manager
|
||||
"""
|
||||
self._mcp = mcp_manager
|
||||
self._calculator.set_mcp_manager(mcp_manager)
|
||||
self._scorer.set_mcp_manager(mcp_manager)
|
||||
self._pipeline.set_mcp_manager(mcp_manager)
|
||||
|
||||
def set_redis(self, redis: "Redis") -> None:
|
||||
"""
|
||||
Set Redis connection for caching.
|
||||
|
||||
Args:
|
||||
redis: Redis connection
|
||||
"""
|
||||
self._cache.set_redis(redis)
|
||||
|
||||
async def assemble_context(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
model: str,
|
||||
max_tokens: int | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_description: str | None = None,
|
||||
knowledge_query: str | None = None,
|
||||
knowledge_limit: int = 10,
|
||||
conversation_history: list[dict[str, str]] | None = None,
|
||||
tool_results: list[dict[str, Any]] | None = None,
|
||||
custom_contexts: list[BaseContext] | None = None,
|
||||
custom_budget: TokenBudget | None = None,
|
||||
compress: bool = True,
|
||||
format_output: bool = True,
|
||||
use_cache: bool = True,
|
||||
) -> AssembledContext:
|
||||
"""
|
||||
Assemble optimized context for an LLM request.
|
||||
|
||||
This is the main entry point for context management.
|
||||
It gathers context from various sources, scores and ranks them,
|
||||
compresses if needed, and formats for the target model.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
agent_id: Agent identifier
|
||||
query: User's query or current request
|
||||
model: Target model name
|
||||
max_tokens: Maximum context tokens (uses model default if None)
|
||||
system_prompt: System prompt/instructions
|
||||
task_description: Current task description
|
||||
knowledge_query: Query for knowledge base search
|
||||
knowledge_limit: Max number of knowledge results
|
||||
conversation_history: List of {"role": str, "content": str}
|
||||
tool_results: List of tool results to include
|
||||
custom_contexts: Additional custom contexts
|
||||
custom_budget: Custom token budget
|
||||
compress: Whether to apply compression
|
||||
format_output: Whether to format for the model
|
||||
use_cache: Whether to use caching
|
||||
|
||||
Returns:
|
||||
AssembledContext with optimized content
|
||||
|
||||
Raises:
|
||||
AssemblyTimeoutError: If assembly exceeds timeout
|
||||
BudgetExceededError: If context exceeds budget
|
||||
"""
|
||||
# Gather all contexts
|
||||
contexts: list[BaseContext] = []
|
||||
|
||||
# 1. System context
|
||||
if system_prompt:
|
||||
contexts.append(
|
||||
SystemContext(
|
||||
content=system_prompt,
|
||||
source="system_prompt",
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Task context
|
||||
if task_description:
|
||||
contexts.append(
|
||||
TaskContext(
|
||||
content=task_description,
|
||||
source=f"task:{project_id}:{agent_id}",
|
||||
)
|
||||
)
|
||||
|
||||
# 3. Knowledge context from Knowledge Base
|
||||
if knowledge_query and self._mcp:
|
||||
knowledge_contexts = await self._fetch_knowledge(
|
||||
project_id=project_id,
|
||||
agent_id=agent_id,
|
||||
query=knowledge_query,
|
||||
limit=knowledge_limit,
|
||||
)
|
||||
contexts.extend(knowledge_contexts)
|
||||
|
||||
# 4. Conversation history
|
||||
if conversation_history:
|
||||
contexts.extend(self._convert_conversation(conversation_history))
|
||||
|
||||
# 5. Tool results
|
||||
if tool_results:
|
||||
contexts.extend(self._convert_tool_results(tool_results))
|
||||
|
||||
# 6. Custom contexts
|
||||
if custom_contexts:
|
||||
contexts.extend(custom_contexts)
|
||||
|
||||
# Check cache if enabled
|
||||
fingerprint: str | None = None
|
||||
if use_cache and self._cache.is_enabled:
|
||||
# Include project_id and agent_id for tenant isolation
|
||||
fingerprint = self._cache.compute_fingerprint(
|
||||
contexts, query, model, project_id=project_id, agent_id=agent_id
|
||||
)
|
||||
cached = await self._cache.get_assembled(fingerprint)
|
||||
if cached:
|
||||
logger.debug(f"Cache hit for context assembly: {fingerprint}")
|
||||
return cached
|
||||
|
||||
# Run assembly pipeline
|
||||
result = await self._pipeline.assemble(
|
||||
contexts=contexts,
|
||||
query=query,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
custom_budget=custom_budget,
|
||||
compress=compress,
|
||||
format_output=format_output,
|
||||
)
|
||||
|
||||
# Cache result if enabled (reuse fingerprint computed above)
|
||||
if use_cache and self._cache.is_enabled and fingerprint is not None:
|
||||
await self._cache.set_assembled(fingerprint, result)
|
||||
|
||||
return result
|
||||
|
||||
async def _fetch_knowledge(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
) -> list[KnowledgeContext]:
|
||||
"""
|
||||
Fetch relevant knowledge from Knowledge Base via MCP.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
agent_id: Agent identifier
|
||||
query: Search query
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of KnowledgeContext instances
|
||||
"""
|
||||
if not self._mcp:
|
||||
return []
|
||||
|
||||
try:
|
||||
result = await self._mcp.call_tool(
|
||||
"knowledge-base",
|
||||
"search_knowledge",
|
||||
{
|
||||
"project_id": project_id,
|
||||
"agent_id": agent_id,
|
||||
"query": query,
|
||||
"search_type": "hybrid",
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
|
||||
# Check both ToolResult.success AND response success
|
||||
if not result.success:
|
||||
logger.warning(f"Knowledge search failed: {result.error}")
|
||||
return []
|
||||
|
||||
if not isinstance(result.data, dict) or not result.data.get(
|
||||
"success", True
|
||||
):
|
||||
logger.warning("Knowledge search returned unsuccessful response")
|
||||
return []
|
||||
|
||||
contexts = []
|
||||
results = result.data.get("results", [])
|
||||
for chunk in results:
|
||||
contexts.append(
|
||||
KnowledgeContext(
|
||||
content=chunk.get("content", ""),
|
||||
source=chunk.get("source_path", "unknown"),
|
||||
relevance_score=chunk.get("score", 0.0),
|
||||
metadata={
|
||||
"chunk_id": chunk.get(
|
||||
"id"
|
||||
), # Server returns 'id' not 'chunk_id'
|
||||
"document_id": chunk.get("document_id"),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Fetched {len(contexts)} knowledge chunks for query: {query}")
|
||||
return contexts
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch knowledge: {e}")
|
||||
return []
|
||||
|
||||
def _convert_conversation(
|
||||
self,
|
||||
history: list[dict[str, str]],
|
||||
) -> list[ConversationContext]:
|
||||
"""
|
||||
Convert conversation history to ConversationContext instances.
|
||||
|
||||
Args:
|
||||
history: List of {"role": str, "content": str}
|
||||
|
||||
Returns:
|
||||
List of ConversationContext instances
|
||||
"""
|
||||
contexts = []
|
||||
for i, turn in enumerate(history):
|
||||
role_str = turn.get("role", "user").lower()
|
||||
role = (
|
||||
MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
|
||||
)
|
||||
|
||||
contexts.append(
|
||||
ConversationContext(
|
||||
content=turn.get("content", ""),
|
||||
source=f"conversation:{i}",
|
||||
role=role,
|
||||
metadata={"role": role_str, "turn": i},
|
||||
)
|
||||
)
|
||||
|
||||
return contexts
|
||||
|
||||
def _convert_tool_results(
|
||||
self,
|
||||
results: list[dict[str, Any]],
|
||||
) -> list[ToolContext]:
|
||||
"""
|
||||
Convert tool results to ToolContext instances.
|
||||
|
||||
Args:
|
||||
results: List of tool result dictionaries
|
||||
|
||||
Returns:
|
||||
List of ToolContext instances
|
||||
"""
|
||||
contexts = []
|
||||
for result in results:
|
||||
tool_name = result.get("tool_name", "unknown")
|
||||
content = result.get("content", result.get("result", ""))
|
||||
|
||||
# Handle dict content
|
||||
if isinstance(content, dict):
|
||||
import json
|
||||
|
||||
content = json.dumps(content, indent=2)
|
||||
|
||||
contexts.append(
|
||||
ToolContext(
|
||||
content=str(content),
|
||||
source=f"tool:{tool_name}",
|
||||
metadata={
|
||||
"tool_name": tool_name,
|
||||
"status": result.get("status", "success"),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return contexts
|
||||
|
||||
async def get_budget_for_model(
|
||||
self,
|
||||
model: str,
|
||||
max_tokens: int | None = None,
|
||||
) -> TokenBudget:
|
||||
"""
|
||||
Get the token budget for a specific model.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
max_tokens: Optional max tokens override
|
||||
|
||||
Returns:
|
||||
TokenBudget instance
|
||||
"""
|
||||
if max_tokens:
|
||||
return self._allocator.create_budget(max_tokens)
|
||||
return self._allocator.create_budget_for_model(model)
|
||||
|
||||
async def count_tokens(
|
||||
self,
|
||||
content: str,
|
||||
model: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Count tokens in content.
|
||||
|
||||
Args:
|
||||
content: Content to count
|
||||
model: Model for model-specific tokenization
|
||||
|
||||
Returns:
|
||||
Token count
|
||||
"""
|
||||
# Check cache first
|
||||
cached = await self._cache.get_token_count(content, model)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
count = await self._calculator.count_tokens(content, model)
|
||||
|
||||
# Cache the result
|
||||
await self._cache.set_token_count(content, count, model)
|
||||
|
||||
return count
|
||||
|
||||
async def invalidate_cache(
|
||||
self,
|
||||
project_id: str | None = None,
|
||||
pattern: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Invalidate cache entries.
|
||||
|
||||
Args:
|
||||
project_id: Invalidate all cache for a project
|
||||
pattern: Custom pattern to match
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if pattern:
|
||||
return await self._cache.invalidate(pattern)
|
||||
elif project_id:
|
||||
return await self._cache.invalidate(f"*{project_id}*")
|
||||
else:
|
||||
return await self._cache.clear_all()
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get engine statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with engine stats
|
||||
"""
|
||||
return {
|
||||
"cache": await self._cache.get_stats(),
|
||||
"settings": {
|
||||
"compression_threshold": self._settings.compression_threshold,
|
||||
"max_assembly_time_ms": self._settings.max_assembly_time_ms,
|
||||
"cache_enabled": self._settings.cache_enabled,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Convenience factory function
|
||||
def create_context_engine(
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> ContextEngine:
|
||||
"""
|
||||
Create a context engine instance.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP client manager
|
||||
redis: Redis connection
|
||||
settings: Context settings
|
||||
|
||||
Returns:
|
||||
Configured ContextEngine instance
|
||||
"""
|
||||
return ContextEngine(
|
||||
mcp_manager=mcp_manager,
|
||||
redis=redis,
|
||||
settings=settings,
|
||||
)
|
||||
354
backend/app/services/context/exceptions.py
Normal file
354
backend/app/services/context/exceptions.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
Context Management Engine Exceptions.
|
||||
|
||||
Provides a hierarchy of exceptions for context assembly,
|
||||
token budget management, and related operations.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ContextError(Exception):
|
||||
"""
|
||||
Base exception for all context management errors.
|
||||
|
||||
All context-related exceptions should inherit from this class
|
||||
to allow for catch-all handling when needed.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, details: dict[str, Any] | None = None) -> None:
|
||||
"""
|
||||
Initialize context error.
|
||||
|
||||
Args:
|
||||
message: Human-readable error message
|
||||
details: Optional dict with additional error context
|
||||
"""
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
super().__init__(message)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert exception to dictionary for logging/serialization."""
|
||||
return {
|
||||
"error_type": self.__class__.__name__,
|
||||
"message": self.message,
|
||||
"details": self.details,
|
||||
}
|
||||
|
||||
|
||||
class BudgetExceededError(ContextError):
|
||||
"""
|
||||
Raised when token budget is exceeded.
|
||||
|
||||
This occurs when the assembled context would exceed the
|
||||
allocated token budget for a specific context type or total.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Token budget exceeded",
|
||||
allocated: int = 0,
|
||||
requested: int = 0,
|
||||
context_type: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize budget exceeded error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
allocated: Tokens allocated for this context type
|
||||
requested: Tokens requested
|
||||
context_type: Type of context that exceeded budget
|
||||
"""
|
||||
details: dict[str, Any] = {
|
||||
"allocated": allocated,
|
||||
"requested": requested,
|
||||
"overage": requested - allocated,
|
||||
}
|
||||
if context_type:
|
||||
details["context_type"] = context_type
|
||||
|
||||
super().__init__(message, details)
|
||||
self.allocated = allocated
|
||||
self.requested = requested
|
||||
self.context_type = context_type
|
||||
|
||||
|
||||
class TokenCountError(ContextError):
|
||||
"""
|
||||
Raised when token counting fails.
|
||||
|
||||
This typically occurs when the LLM Gateway token counting
|
||||
service is unavailable or returns an error.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Failed to count tokens",
|
||||
model: str | None = None,
|
||||
text_length: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize token count error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
model: Model for which counting was attempted
|
||||
text_length: Length of text that failed to count
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if model:
|
||||
details["model"] = model
|
||||
if text_length is not None:
|
||||
details["text_length"] = text_length
|
||||
|
||||
super().__init__(message, details)
|
||||
self.model = model
|
||||
self.text_length = text_length
|
||||
|
||||
|
||||
class CompressionError(ContextError):
|
||||
"""
|
||||
Raised when context compression fails.
|
||||
|
||||
This can occur when summarization or truncation cannot
|
||||
reduce content to fit within the budget.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Failed to compress context",
|
||||
original_tokens: int | None = None,
|
||||
target_tokens: int | None = None,
|
||||
achieved_tokens: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize compression error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
original_tokens: Tokens before compression
|
||||
target_tokens: Target token count
|
||||
achieved_tokens: Tokens achieved after compression attempt
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if original_tokens is not None:
|
||||
details["original_tokens"] = original_tokens
|
||||
if target_tokens is not None:
|
||||
details["target_tokens"] = target_tokens
|
||||
if achieved_tokens is not None:
|
||||
details["achieved_tokens"] = achieved_tokens
|
||||
|
||||
super().__init__(message, details)
|
||||
self.original_tokens = original_tokens
|
||||
self.target_tokens = target_tokens
|
||||
self.achieved_tokens = achieved_tokens
|
||||
|
||||
|
||||
class AssemblyTimeoutError(ContextError):
|
||||
"""
|
||||
Raised when context assembly exceeds time limit.
|
||||
|
||||
Context assembly must complete within a configurable
|
||||
time limit to maintain responsiveness.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Context assembly timed out",
|
||||
timeout_ms: int = 0,
|
||||
elapsed_ms: float = 0.0,
|
||||
stage: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize assembly timeout error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
timeout_ms: Configured timeout in milliseconds
|
||||
elapsed_ms: Actual elapsed time in milliseconds
|
||||
stage: Pipeline stage where timeout occurred
|
||||
"""
|
||||
details: dict[str, Any] = {
|
||||
"timeout_ms": timeout_ms,
|
||||
"elapsed_ms": round(elapsed_ms, 2),
|
||||
}
|
||||
if stage:
|
||||
details["stage"] = stage
|
||||
|
||||
super().__init__(message, details)
|
||||
self.timeout_ms = timeout_ms
|
||||
self.elapsed_ms = elapsed_ms
|
||||
self.stage = stage
|
||||
|
||||
|
||||
class ScoringError(ContextError):
|
||||
"""
|
||||
Raised when context scoring fails.
|
||||
|
||||
This occurs when relevance, recency, or priority scoring
|
||||
encounters an error.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Failed to score context",
|
||||
scorer_type: str | None = None,
|
||||
context_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize scoring error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
scorer_type: Type of scorer that failed
|
||||
context_id: ID of context being scored
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if scorer_type:
|
||||
details["scorer_type"] = scorer_type
|
||||
if context_id:
|
||||
details["context_id"] = context_id
|
||||
|
||||
super().__init__(message, details)
|
||||
self.scorer_type = scorer_type
|
||||
self.context_id = context_id
|
||||
|
||||
|
||||
class FormattingError(ContextError):
|
||||
"""
|
||||
Raised when context formatting fails.
|
||||
|
||||
This occurs when converting assembled context to
|
||||
model-specific format fails.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Failed to format context",
|
||||
model: str | None = None,
|
||||
adapter: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize formatting error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
model: Target model
|
||||
adapter: Adapter that failed
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if model:
|
||||
details["model"] = model
|
||||
if adapter:
|
||||
details["adapter"] = adapter
|
||||
|
||||
super().__init__(message, details)
|
||||
self.model = model
|
||||
self.adapter = adapter
|
||||
|
||||
|
||||
class CacheError(ContextError):
|
||||
"""
|
||||
Raised when cache operations fail.
|
||||
|
||||
This is typically non-fatal and should be handled
|
||||
gracefully by falling back to recomputation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Cache operation failed",
|
||||
operation: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize cache error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
operation: Cache operation that failed (get, set, delete)
|
||||
cache_key: Key involved in the failed operation
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if operation:
|
||||
details["operation"] = operation
|
||||
if cache_key:
|
||||
details["cache_key"] = cache_key
|
||||
|
||||
super().__init__(message, details)
|
||||
self.operation = operation
|
||||
self.cache_key = cache_key
|
||||
|
||||
|
||||
class ContextNotFoundError(ContextError):
|
||||
"""
|
||||
Raised when expected context is not found.
|
||||
|
||||
This occurs when required context sources return
|
||||
no results or are unavailable.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Required context not found",
|
||||
source: str | None = None,
|
||||
query: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize context not found error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
source: Source that returned no results
|
||||
query: Query used to search
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if source:
|
||||
details["source"] = source
|
||||
if query:
|
||||
details["query"] = query
|
||||
|
||||
super().__init__(message, details)
|
||||
self.source = source
|
||||
self.query = query
|
||||
|
||||
|
||||
class InvalidContextError(ContextError):
|
||||
"""
|
||||
Raised when context data is invalid.
|
||||
|
||||
This occurs when context content or metadata
|
||||
fails validation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Invalid context data",
|
||||
field: str | None = None,
|
||||
value: Any | None = None,
|
||||
reason: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize invalid context error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
field: Field that is invalid
|
||||
value: Invalid value (may be redacted for security)
|
||||
reason: Reason for invalidity
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if field:
|
||||
details["field"] = field
|
||||
if value is not None:
|
||||
# Avoid logging potentially sensitive values
|
||||
details["value_type"] = type(value).__name__
|
||||
if reason:
|
||||
details["reason"] = reason
|
||||
|
||||
super().__init__(message, details)
|
||||
self.field = field
|
||||
self.value = value
|
||||
self.reason = reason
|
||||
12
backend/app/services/context/prioritization/__init__.py
Normal file
12
backend/app/services/context/prioritization/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Context Prioritization Module.
|
||||
|
||||
Provides context ranking and selection.
|
||||
"""
|
||||
|
||||
from .ranker import ContextRanker, RankingResult
|
||||
|
||||
__all__ = [
|
||||
"ContextRanker",
|
||||
"RankingResult",
|
||||
]
|
||||
374
backend/app/services/context/prioritization/ranker.py
Normal file
374
backend/app/services/context/prioritization/ranker.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""
|
||||
Context Ranker for Context Management.
|
||||
|
||||
Ranks and selects contexts based on scores and budget constraints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..budget import TokenBudget, TokenCalculator
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import BudgetExceededError
|
||||
from ..scoring.composite import CompositeScorer, ScoredContext
|
||||
from ..types import BaseContext, ContextPriority
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankingResult:
|
||||
"""Result of context ranking and selection."""
|
||||
|
||||
selected: list[ScoredContext]
|
||||
excluded: list[ScoredContext]
|
||||
total_tokens: int
|
||||
selection_stats: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def selected_contexts(self) -> list[BaseContext]:
|
||||
"""Get just the context objects (not scored wrappers)."""
|
||||
return [s.context for s in self.selected]
|
||||
|
||||
|
||||
class ContextRanker:
|
||||
"""
|
||||
Ranks and selects contexts within budget constraints.
|
||||
|
||||
Uses greedy selection to maximize total score
|
||||
while respecting token budgets per context type.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scorer: CompositeScorer | None = None,
|
||||
calculator: TokenCalculator | None = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize context ranker.
|
||||
|
||||
Args:
|
||||
scorer: Composite scorer for scoring contexts
|
||||
calculator: Token calculator for counting tokens
|
||||
settings: Context settings (uses global if None)
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
self._scorer = scorer or CompositeScorer()
|
||||
self._calculator = calculator or TokenCalculator()
|
||||
|
||||
def set_scorer(self, scorer: CompositeScorer) -> None:
|
||||
"""Set the scorer."""
|
||||
self._scorer = scorer
|
||||
|
||||
def set_calculator(self, calculator: TokenCalculator) -> None:
|
||||
"""Set the token calculator."""
|
||||
self._calculator = calculator
|
||||
|
||||
async def rank(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
budget: TokenBudget,
|
||||
model: str | None = None,
|
||||
ensure_required: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> RankingResult:
|
||||
"""
|
||||
Rank and select contexts within budget.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to rank
|
||||
query: Query to rank against
|
||||
budget: Token budget constraints
|
||||
model: Model for token counting
|
||||
ensure_required: If True, always include CRITICAL priority contexts
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
RankingResult with selected and excluded contexts
|
||||
"""
|
||||
if not contexts:
|
||||
return RankingResult(
|
||||
selected=[],
|
||||
excluded=[],
|
||||
total_tokens=0,
|
||||
selection_stats={"total_contexts": 0},
|
||||
)
|
||||
|
||||
# 1. Ensure all contexts have token counts
|
||||
await self._ensure_token_counts(contexts, model)
|
||||
|
||||
# 2. Score all contexts
|
||||
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
|
||||
|
||||
# 3. Separate required (CRITICAL priority) from optional
|
||||
required: list[ScoredContext] = []
|
||||
optional: list[ScoredContext] = []
|
||||
|
||||
if ensure_required:
|
||||
for sc in scored_contexts:
|
||||
# CRITICAL priority (150) contexts are always included
|
||||
if sc.context.priority >= ContextPriority.CRITICAL.value:
|
||||
required.append(sc)
|
||||
else:
|
||||
optional.append(sc)
|
||||
else:
|
||||
optional = list(scored_contexts)
|
||||
|
||||
# 4. Sort optional by score (highest first)
|
||||
optional.sort(reverse=True)
|
||||
|
||||
# 5. Greedy selection
|
||||
selected: list[ScoredContext] = []
|
||||
excluded: list[ScoredContext] = []
|
||||
total_tokens = 0
|
||||
|
||||
# Calculate the usable budget (total minus reserved portions)
|
||||
usable_budget = budget.total - budget.response_reserve - budget.buffer
|
||||
|
||||
# Guard against invalid budget configuration
|
||||
if usable_budget <= 0:
|
||||
raise BudgetExceededError(
|
||||
message=(
|
||||
f"Invalid budget configuration: no usable tokens available. "
|
||||
f"total={budget.total}, response_reserve={budget.response_reserve}, "
|
||||
f"buffer={budget.buffer}"
|
||||
),
|
||||
allocated=budget.total,
|
||||
requested=0,
|
||||
context_type="CONFIGURATION_ERROR",
|
||||
)
|
||||
|
||||
# First, try to fit required contexts
|
||||
for sc in required:
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
context_type = sc.context.get_type()
|
||||
|
||||
if budget.can_fit(context_type, token_count):
|
||||
budget.allocate(context_type, token_count)
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
else:
|
||||
# Force-fit CRITICAL contexts if needed, but check total budget first
|
||||
if total_tokens + token_count > usable_budget:
|
||||
# Even CRITICAL contexts cannot exceed total model context window
|
||||
raise BudgetExceededError(
|
||||
message=(
|
||||
f"CRITICAL contexts exceed total budget. "
|
||||
f"Context '{sc.context.source}' ({token_count} tokens) "
|
||||
f"would exceed usable budget of {usable_budget} tokens."
|
||||
),
|
||||
allocated=usable_budget,
|
||||
requested=total_tokens + token_count,
|
||||
context_type="CRITICAL_OVERFLOW",
|
||||
)
|
||||
|
||||
budget.allocate(context_type, token_count, force=True)
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
logger.warning(
|
||||
f"Force-fitted CRITICAL context: {sc.context.source} "
|
||||
f"({token_count} tokens)"
|
||||
)
|
||||
|
||||
# Then, greedily add optional contexts
|
||||
for sc in optional:
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
context_type = sc.context.get_type()
|
||||
|
||||
if budget.can_fit(context_type, token_count):
|
||||
budget.allocate(context_type, token_count)
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
else:
|
||||
excluded.append(sc)
|
||||
|
||||
# Build stats
|
||||
stats = {
|
||||
"total_contexts": len(contexts),
|
||||
"required_count": len(required),
|
||||
"selected_count": len(selected),
|
||||
"excluded_count": len(excluded),
|
||||
"total_tokens": total_tokens,
|
||||
"by_type": self._count_by_type(selected),
|
||||
}
|
||||
|
||||
return RankingResult(
|
||||
selected=selected,
|
||||
excluded=excluded,
|
||||
total_tokens=total_tokens,
|
||||
selection_stats=stats,
|
||||
)
|
||||
|
||||
async def rank_simple(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[BaseContext]:
|
||||
"""
|
||||
Simple ranking without budget per type.
|
||||
|
||||
Selects top contexts by score until max tokens reached.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to rank
|
||||
query: Query to rank against
|
||||
max_tokens: Maximum total tokens
|
||||
model: Model for token counting
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Selected contexts (in score order)
|
||||
"""
|
||||
if not contexts:
|
||||
return []
|
||||
|
||||
# Ensure token counts
|
||||
await self._ensure_token_counts(contexts, model)
|
||||
|
||||
# Score all contexts
|
||||
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
|
||||
|
||||
# Sort by score (highest first)
|
||||
scored_contexts.sort(reverse=True)
|
||||
|
||||
# Greedy selection
|
||||
selected: list[BaseContext] = []
|
||||
total_tokens = 0
|
||||
|
||||
for sc in scored_contexts:
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
if total_tokens + token_count <= max_tokens:
|
||||
selected.append(sc.context)
|
||||
total_tokens += token_count
|
||||
|
||||
return selected
|
||||
|
||||
def _get_valid_token_count(self, context: BaseContext) -> int:
|
||||
"""
|
||||
Get validated token count from a context.
|
||||
|
||||
Ensures token_count is set (not None) and non-negative to prevent
|
||||
budget bypass attacks where:
|
||||
- None would be treated as 0 (allowing huge contexts to slip through)
|
||||
- Negative values would corrupt budget tracking
|
||||
|
||||
Args:
|
||||
context: Context to get token count from
|
||||
|
||||
Returns:
|
||||
Valid non-negative token count
|
||||
|
||||
Raises:
|
||||
ValueError: If token_count is None or negative
|
||||
"""
|
||||
if context.token_count is None:
|
||||
raise ValueError(
|
||||
f"Context '{context.source}' has no token count. "
|
||||
"Ensure _ensure_token_counts() is called before ranking."
|
||||
)
|
||||
if context.token_count < 0:
|
||||
raise ValueError(
|
||||
f"Context '{context.source}' has invalid negative token count: "
|
||||
f"{context.token_count}"
|
||||
)
|
||||
return context.token_count
|
||||
|
||||
async def _ensure_token_counts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Ensure all contexts have token counts.
|
||||
|
||||
Counts tokens in parallel for contexts that don't have counts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to check
|
||||
model: Model for token counting
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Find contexts needing counts
|
||||
contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]
|
||||
|
||||
if not contexts_needing_counts:
|
||||
return
|
||||
|
||||
# Count all in parallel
|
||||
tasks = [
|
||||
self._calculator.count_tokens(ctx.content, model)
|
||||
for ctx in contexts_needing_counts
|
||||
]
|
||||
counts = await asyncio.gather(*tasks)
|
||||
|
||||
# Assign counts back
|
||||
for ctx, count in zip(contexts_needing_counts, counts, strict=True):
|
||||
ctx.token_count = count
|
||||
|
||||
def _count_by_type(
|
||||
self, scored_contexts: list[ScoredContext]
|
||||
) -> dict[str, dict[str, int]]:
|
||||
"""Count selected contexts by type."""
|
||||
by_type: dict[str, dict[str, int]] = {}
|
||||
|
||||
for sc in scored_contexts:
|
||||
type_name = sc.context.get_type().value
|
||||
if type_name not in by_type:
|
||||
by_type[type_name] = {"count": 0, "tokens": 0}
|
||||
by_type[type_name]["count"] += 1
|
||||
# Use validated token count (already validated during ranking)
|
||||
by_type[type_name]["tokens"] += sc.context.token_count or 0
|
||||
|
||||
return by_type
|
||||
|
||||
async def rerank_for_diversity(
|
||||
self,
|
||||
scored_contexts: list[ScoredContext],
|
||||
max_per_source: int | None = None,
|
||||
) -> list[ScoredContext]:
|
||||
"""
|
||||
Rerank to ensure source diversity.
|
||||
|
||||
Prevents too many items from the same source.
|
||||
|
||||
Args:
|
||||
scored_contexts: Already scored contexts
|
||||
max_per_source: Maximum items per source (uses settings if None)
|
||||
|
||||
Returns:
|
||||
Reranked contexts
|
||||
"""
|
||||
# Use provided value or fall back to settings
|
||||
effective_max = (
|
||||
max_per_source
|
||||
if max_per_source is not None
|
||||
else self._settings.diversity_max_per_source
|
||||
)
|
||||
|
||||
source_counts: dict[str, int] = {}
|
||||
result: list[ScoredContext] = []
|
||||
deferred: list[ScoredContext] = []
|
||||
|
||||
for sc in scored_contexts:
|
||||
source = sc.context.source
|
||||
current_count = source_counts.get(source, 0)
|
||||
|
||||
if current_count < effective_max:
|
||||
result.append(sc)
|
||||
source_counts[source] = current_count + 1
|
||||
else:
|
||||
deferred.append(sc)
|
||||
|
||||
# Add deferred items at the end
|
||||
result.extend(deferred)
|
||||
return result
|
||||
21
backend/app/services/context/scoring/__init__.py
Normal file
21
backend/app/services/context/scoring/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Context Scoring Module.
|
||||
|
||||
Provides scoring strategies for context prioritization.
|
||||
"""
|
||||
|
||||
from .base import BaseScorer, ScorerProtocol
|
||||
from .composite import CompositeScorer, ScoredContext
|
||||
from .priority import PriorityScorer
|
||||
from .recency import RecencyScorer
|
||||
from .relevance import RelevanceScorer
|
||||
|
||||
__all__ = [
|
||||
"BaseScorer",
|
||||
"CompositeScorer",
|
||||
"PriorityScorer",
|
||||
"RecencyScorer",
|
||||
"RelevanceScorer",
|
||||
"ScoredContext",
|
||||
"ScorerProtocol",
|
||||
]
|
||||
99
backend/app/services/context/scoring/base.py
Normal file
99
backend/app/services/context/scoring/base.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Base Scorer Protocol and Types.
|
||||
|
||||
Defines the interface for context scoring implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
from ..types import BaseContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ScorerProtocol(Protocol):
|
||||
"""Protocol for context scorers."""
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score a context item.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Score between 0.0 and 1.0
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class BaseScorer(ABC):
|
||||
"""
|
||||
Abstract base class for context scorers.
|
||||
|
||||
Provides common functionality and interface for
|
||||
different scoring strategies.
|
||||
"""
|
||||
|
||||
def __init__(self, weight: float = 1.0) -> None:
|
||||
"""
|
||||
Initialize scorer.
|
||||
|
||||
Args:
|
||||
weight: Weight for this scorer in composite scoring
|
||||
"""
|
||||
self._weight = weight
|
||||
|
||||
@property
|
||||
def weight(self) -> float:
|
||||
"""Get scorer weight."""
|
||||
return self._weight
|
||||
|
||||
@weight.setter
|
||||
def weight(self, value: float) -> None:
|
||||
"""Set scorer weight."""
|
||||
if not 0.0 <= value <= 1.0:
|
||||
raise ValueError("Weight must be between 0.0 and 1.0")
|
||||
self._weight = value
|
||||
|
||||
@abstractmethod
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score a context item.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Score between 0.0 and 1.0
|
||||
"""
|
||||
...
|
||||
|
||||
def normalize_score(self, score: float) -> float:
|
||||
"""
|
||||
Normalize score to [0.0, 1.0] range.
|
||||
|
||||
Args:
|
||||
score: Raw score
|
||||
|
||||
Returns:
|
||||
Normalized score
|
||||
"""
|
||||
return max(0.0, min(1.0, score))
|
||||
368
backend/app/services/context/scoring/composite.py
Normal file
368
backend/app/services/context/scoring/composite.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
Composite Scorer for Context Management.
|
||||
|
||||
Combines multiple scoring strategies with configurable weights.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext
|
||||
from .priority import PriorityScorer
|
||||
from .recency import RecencyScorer
|
||||
from .relevance import RelevanceScorer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoredContext:
|
||||
"""Context with computed scores."""
|
||||
|
||||
context: BaseContext
|
||||
composite_score: float
|
||||
relevance_score: float = 0.0
|
||||
recency_score: float = 0.0
|
||||
priority_score: float = 0.0
|
||||
|
||||
def __lt__(self, other: "ScoredContext") -> bool:
|
||||
"""Enable sorting by composite score."""
|
||||
return self.composite_score < other.composite_score
|
||||
|
||||
def __gt__(self, other: "ScoredContext") -> bool:
|
||||
"""Enable sorting by composite score."""
|
||||
return self.composite_score > other.composite_score
|
||||
|
||||
|
||||
class CompositeScorer:
|
||||
"""
|
||||
Combines multiple scoring strategies.
|
||||
|
||||
Weights:
|
||||
- relevance: How well content matches the query
|
||||
- recency: How recent the content is
|
||||
- priority: Explicit priority assignments
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
relevance_weight: float | None = None,
|
||||
recency_weight: float | None = None,
|
||||
priority_weight: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize composite scorer.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP manager for semantic scoring
|
||||
settings: Context settings (uses default if None)
|
||||
relevance_weight: Override relevance weight
|
||||
recency_weight: Override recency weight
|
||||
priority_weight: Override priority weight
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
weights = self._settings.get_scoring_weights()
|
||||
|
||||
self._relevance_weight = (
|
||||
relevance_weight if relevance_weight is not None else weights["relevance"]
|
||||
)
|
||||
self._recency_weight = (
|
||||
recency_weight if recency_weight is not None else weights["recency"]
|
||||
)
|
||||
self._priority_weight = (
|
||||
priority_weight if priority_weight is not None else weights["priority"]
|
||||
)
|
||||
|
||||
# Initialize scorers
|
||||
self._relevance_scorer = RelevanceScorer(
|
||||
mcp_manager=mcp_manager,
|
||||
weight=self._relevance_weight,
|
||||
)
|
||||
self._recency_scorer = RecencyScorer(weight=self._recency_weight)
|
||||
self._priority_scorer = PriorityScorer(weight=self._priority_weight)
|
||||
|
||||
# Per-context locks to prevent race conditions during parallel scoring
|
||||
# Uses dict with (lock, last_used_time) tuples for cleanup
|
||||
self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {}
|
||||
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
|
||||
self._max_locks = 1000 # Maximum locks to keep (prevent memory growth)
|
||||
self._lock_ttl = 60.0 # Seconds before a lock can be cleaned up
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""Set MCP manager for semantic scoring."""
|
||||
self._relevance_scorer.set_mcp_manager(mcp_manager)
|
||||
|
||||
@property
|
||||
def weights(self) -> dict[str, float]:
|
||||
"""Get current scoring weights."""
|
||||
return {
|
||||
"relevance": self._relevance_weight,
|
||||
"recency": self._recency_weight,
|
||||
"priority": self._priority_weight,
|
||||
}
|
||||
|
||||
def update_weights(
|
||||
self,
|
||||
relevance: float | None = None,
|
||||
recency: float | None = None,
|
||||
priority: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update scoring weights.
|
||||
|
||||
Args:
|
||||
relevance: New relevance weight
|
||||
recency: New recency weight
|
||||
priority: New priority weight
|
||||
"""
|
||||
if relevance is not None:
|
||||
self._relevance_weight = max(0.0, min(1.0, relevance))
|
||||
self._relevance_scorer.weight = self._relevance_weight
|
||||
|
||||
if recency is not None:
|
||||
self._recency_weight = max(0.0, min(1.0, recency))
|
||||
self._recency_scorer.weight = self._recency_weight
|
||||
|
||||
if priority is not None:
|
||||
self._priority_weight = max(0.0, min(1.0, priority))
|
||||
self._priority_scorer.weight = self._priority_weight
|
||||
|
||||
async def _get_context_lock(self, context_id: str) -> asyncio.Lock:
|
||||
"""
|
||||
Get or create a lock for a specific context.
|
||||
|
||||
Thread-safe access to per-context locks prevents race conditions
|
||||
when the same context is scored concurrently. Includes automatic
|
||||
cleanup of old locks to prevent memory growth.
|
||||
|
||||
Args:
|
||||
context_id: The context ID to get a lock for
|
||||
|
||||
Returns:
|
||||
asyncio.Lock for the context
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
# Fast path: check if lock exists without acquiring main lock
|
||||
# NOTE: We only READ here - no writes to avoid race conditions
|
||||
# with cleanup. The timestamp will be updated in the slow path
|
||||
# if the lock is still valid.
|
||||
lock_entry = self._context_locks.get(context_id)
|
||||
if lock_entry is not None:
|
||||
lock, _ = lock_entry
|
||||
# Return the lock but defer timestamp update to avoid race
|
||||
# The lock is still valid; timestamp update is best-effort
|
||||
return lock
|
||||
|
||||
# Slow path: create lock or update timestamp while holding main lock
|
||||
async with self._locks_lock:
|
||||
# Double-check after acquiring lock - entry may have been
|
||||
# created by another coroutine or deleted by cleanup
|
||||
lock_entry = self._context_locks.get(context_id)
|
||||
if lock_entry is not None:
|
||||
lock, _ = lock_entry
|
||||
# Safe to update timestamp here since we hold the lock
|
||||
self._context_locks[context_id] = (lock, now)
|
||||
return lock
|
||||
|
||||
# Cleanup old locks if we have too many
|
||||
if len(self._context_locks) >= self._max_locks:
|
||||
self._cleanup_old_locks(now)
|
||||
|
||||
# Create new lock
|
||||
new_lock = asyncio.Lock()
|
||||
self._context_locks[context_id] = (new_lock, now)
|
||||
return new_lock
|
||||
|
||||
def _cleanup_old_locks(self, now: float) -> None:
|
||||
"""
|
||||
Remove old locks that haven't been used recently.
|
||||
|
||||
Called while holding _locks_lock. Removes locks older than _lock_ttl,
|
||||
but only if they're not currently held.
|
||||
|
||||
Args:
|
||||
now: Current timestamp for age calculation
|
||||
"""
|
||||
cutoff = now - self._lock_ttl
|
||||
to_remove = []
|
||||
|
||||
for context_id, (lock, last_used) in self._context_locks.items():
|
||||
# Only remove if old AND not currently held
|
||||
if last_used < cutoff and not lock.locked():
|
||||
to_remove.append(context_id)
|
||||
|
||||
# Remove oldest 50% if still over limit after TTL filtering
|
||||
if len(self._context_locks) - len(to_remove) >= self._max_locks:
|
||||
# Sort by last used time and mark oldest for removal
|
||||
sorted_entries = sorted(
|
||||
self._context_locks.items(),
|
||||
key=lambda x: x[1][1], # Sort by last_used time
|
||||
)
|
||||
# Remove oldest 50% that aren't locked
|
||||
target_remove = len(self._context_locks) // 2
|
||||
for context_id, (lock, _) in sorted_entries:
|
||||
if len(to_remove) >= target_remove:
|
||||
break
|
||||
if context_id not in to_remove and not lock.locked():
|
||||
to_remove.append(context_id)
|
||||
|
||||
for context_id in to_remove:
|
||||
del self._context_locks[context_id]
|
||||
|
||||
if to_remove:
|
||||
logger.debug(f"Cleaned up {len(to_remove)} context locks")
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Compute composite score for a context.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Composite score between 0.0 and 1.0
|
||||
"""
|
||||
scored = await self.score_with_details(context, query, **kwargs)
|
||||
return scored.composite_score
|
||||
|
||||
async def score_with_details(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> ScoredContext:
|
||||
"""
|
||||
Compute composite score with individual scores.
|
||||
|
||||
Uses per-context locking to prevent race conditions when the same
|
||||
context is scored concurrently in parallel scoring operations.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
ScoredContext with all scores
|
||||
"""
|
||||
# Get lock for this specific context to prevent race conditions
|
||||
# within concurrent scoring operations for the same query
|
||||
context_lock = await self._get_context_lock(context.id)
|
||||
|
||||
async with context_lock:
|
||||
# Compute individual scores in parallel
|
||||
# Note: We do NOT cache scores on the context because scores are
|
||||
# query-dependent. Caching without considering the query would
|
||||
# return incorrect scores for different queries.
|
||||
relevance_task = self._relevance_scorer.score(context, query, **kwargs)
|
||||
recency_task = self._recency_scorer.score(context, query, **kwargs)
|
||||
priority_task = self._priority_scorer.score(context, query, **kwargs)
|
||||
|
||||
relevance_score, recency_score, priority_score = await asyncio.gather(
|
||||
relevance_task, recency_task, priority_task
|
||||
)
|
||||
|
||||
# Compute weighted composite
|
||||
total_weight = (
|
||||
self._relevance_weight + self._recency_weight + self._priority_weight
|
||||
)
|
||||
|
||||
if total_weight > 0:
|
||||
composite = (
|
||||
relevance_score * self._relevance_weight
|
||||
+ recency_score * self._recency_weight
|
||||
+ priority_score * self._priority_weight
|
||||
) / total_weight
|
||||
else:
|
||||
composite = 0.0
|
||||
|
||||
return ScoredContext(
|
||||
context=context,
|
||||
composite_score=composite,
|
||||
relevance_score=relevance_score,
|
||||
recency_score=recency_score,
|
||||
priority_score=priority_score,
|
||||
)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
parallel: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> list[ScoredContext]:
|
||||
"""
|
||||
Score multiple contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
query: Query to score against
|
||||
parallel: Whether to score in parallel
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
List of ScoredContext (same order as input)
|
||||
"""
|
||||
if parallel:
|
||||
tasks = [self.score_with_details(ctx, query, **kwargs) for ctx in contexts]
|
||||
return await asyncio.gather(*tasks)
|
||||
else:
|
||||
results = []
|
||||
for ctx in contexts:
|
||||
scored = await self.score_with_details(ctx, query, **kwargs)
|
||||
results.append(scored)
|
||||
return results
|
||||
|
||||
async def rank(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
limit: int | None = None,
|
||||
min_score: float = 0.0,
|
||||
**kwargs: Any,
|
||||
) -> list[ScoredContext]:
|
||||
"""
|
||||
Score and rank contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to rank
|
||||
query: Query to rank against
|
||||
limit: Maximum number of results
|
||||
min_score: Minimum score threshold
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Sorted list of ScoredContext (highest first)
|
||||
"""
|
||||
# Score all contexts
|
||||
scored = await self.score_batch(contexts, query, **kwargs)
|
||||
|
||||
# Filter by minimum score
|
||||
if min_score > 0:
|
||||
scored = [s for s in scored if s.composite_score >= min_score]
|
||||
|
||||
# Sort by score (highest first)
|
||||
scored.sort(reverse=True)
|
||||
|
||||
# Apply limit
|
||||
if limit is not None:
|
||||
scored = scored[:limit]
|
||||
|
||||
return scored
|
||||
135
backend/app/services/context/scoring/priority.py
Normal file
135
backend/app/services/context/scoring/priority.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Priority Scorer for Context Management.
|
||||
|
||||
Scores context based on assigned priority levels.
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
from .base import BaseScorer
|
||||
|
||||
|
||||
class PriorityScorer(BaseScorer):
|
||||
"""
|
||||
Scores context based on priority levels.
|
||||
|
||||
Converts priority enum values to normalized scores.
|
||||
Also applies type-based priority bonuses.
|
||||
"""
|
||||
|
||||
# Default priority bonuses by context type
|
||||
DEFAULT_TYPE_BONUSES: ClassVar[dict[ContextType, float]] = {
|
||||
ContextType.SYSTEM: 0.2, # System prompts get a boost
|
||||
ContextType.TASK: 0.15, # Current task is important
|
||||
ContextType.TOOL: 0.1, # Recent tool results matter
|
||||
ContextType.KNOWLEDGE: 0.0, # Knowledge scored by relevance
|
||||
ContextType.CONVERSATION: 0.0, # Conversation scored by recency
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: float = 1.0,
|
||||
type_bonuses: dict[ContextType, float] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize priority scorer.
|
||||
|
||||
Args:
|
||||
weight: Scorer weight for composite scoring
|
||||
type_bonuses: Optional context-type priority bonuses
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self._type_bonuses = type_bonuses or self.DEFAULT_TYPE_BONUSES.copy()
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score context based on priority.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query (not used for priority, kept for interface)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
Priority score between 0.0 and 1.0
|
||||
"""
|
||||
# Get base priority score
|
||||
priority_value = context.priority
|
||||
base_score = self._priority_to_score(priority_value)
|
||||
|
||||
# Apply type bonus
|
||||
context_type = context.get_type()
|
||||
bonus = self._type_bonuses.get(context_type, 0.0)
|
||||
|
||||
return self.normalize_score(base_score + bonus)
|
||||
|
||||
def _priority_to_score(self, priority: int) -> float:
|
||||
"""
|
||||
Convert priority value to normalized score.
|
||||
|
||||
Priority values (from ContextPriority):
|
||||
- CRITICAL (100) -> 1.0
|
||||
- HIGH (80) -> 0.8
|
||||
- NORMAL (50) -> 0.5
|
||||
- LOW (20) -> 0.2
|
||||
- MINIMAL (0) -> 0.0
|
||||
|
||||
Args:
|
||||
priority: Priority value (0-100)
|
||||
|
||||
Returns:
|
||||
Normalized score (0.0-1.0)
|
||||
"""
|
||||
# Clamp to valid range
|
||||
clamped = max(0, min(100, priority))
|
||||
return clamped / 100.0
|
||||
|
||||
def get_type_bonus(self, context_type: ContextType) -> float:
|
||||
"""
|
||||
Get priority bonus for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type
|
||||
|
||||
Returns:
|
||||
Bonus value
|
||||
"""
|
||||
return self._type_bonuses.get(context_type, 0.0)
|
||||
|
||||
def set_type_bonus(self, context_type: ContextType, bonus: float) -> None:
|
||||
"""
|
||||
Set priority bonus for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type
|
||||
bonus: Bonus value (0.0-1.0)
|
||||
"""
|
||||
if not 0.0 <= bonus <= 1.0:
|
||||
raise ValueError("Bonus must be between 0.0 and 1.0")
|
||||
self._type_bonuses[context_type] = bonus
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Score multiple contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
query: Query (not used)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
List of scores (same order as input)
|
||||
"""
|
||||
# Priority scoring is fast, no async needed
|
||||
return [await self.score(ctx, query, **kwargs) for ctx in contexts]
|
||||
141
backend/app/services/context/scoring/recency.py
Normal file
141
backend/app/services/context/scoring/recency.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Recency Scorer for Context Management.
|
||||
|
||||
Scores context based on how recent it is.
|
||||
More recent content gets higher scores.
|
||||
"""
|
||||
|
||||
import math
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
from .base import BaseScorer
|
||||
|
||||
|
||||
class RecencyScorer(BaseScorer):
|
||||
"""
|
||||
Scores context based on recency.
|
||||
|
||||
Uses exponential decay to score content based on age.
|
||||
More recent content scores higher.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: float = 1.0,
|
||||
half_life_hours: float = 24.0,
|
||||
type_half_lives: dict[ContextType, float] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize recency scorer.
|
||||
|
||||
Args:
|
||||
weight: Scorer weight for composite scoring
|
||||
half_life_hours: Default hours until score decays to 0.5
|
||||
type_half_lives: Optional context-type-specific half lives
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self._half_life_hours = half_life_hours
|
||||
self._type_half_lives = type_half_lives or {}
|
||||
|
||||
# Set sensible defaults for context types
|
||||
if ContextType.CONVERSATION not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.CONVERSATION] = 1.0 # 1 hour
|
||||
if ContextType.TOOL not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.TOOL] = 0.5 # 30 minutes
|
||||
if ContextType.KNOWLEDGE not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.KNOWLEDGE] = 168.0 # 1 week
|
||||
if ContextType.SYSTEM not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.SYSTEM] = 720.0 # 30 days
|
||||
if ContextType.TASK not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.TASK] = 24.0 # 1 day
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score context based on recency.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query (not used for recency, kept for interface)
|
||||
**kwargs: Additional parameters
|
||||
- reference_time: Time to measure recency from (default: now)
|
||||
|
||||
Returns:
|
||||
Recency score between 0.0 and 1.0
|
||||
"""
|
||||
reference_time = kwargs.get("reference_time")
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(UTC)
|
||||
elif reference_time.tzinfo is None:
|
||||
reference_time = reference_time.replace(tzinfo=UTC)
|
||||
|
||||
# Ensure context timestamp is timezone-aware
|
||||
context_time = context.timestamp
|
||||
if context_time.tzinfo is None:
|
||||
context_time = context_time.replace(tzinfo=UTC)
|
||||
|
||||
# Calculate age in hours
|
||||
age = reference_time - context_time
|
||||
age_hours = max(0, age.total_seconds() / 3600)
|
||||
|
||||
# Get half-life for this context type
|
||||
context_type = context.get_type()
|
||||
half_life = self._type_half_lives.get(context_type, self._half_life_hours)
|
||||
|
||||
# Exponential decay
|
||||
decay_factor = math.exp(-math.log(2) * age_hours / half_life)
|
||||
|
||||
return self.normalize_score(decay_factor)
|
||||
|
||||
def get_half_life(self, context_type: ContextType) -> float:
|
||||
"""
|
||||
Get half-life for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to get half-life for
|
||||
|
||||
Returns:
|
||||
Half-life in hours
|
||||
"""
|
||||
return self._type_half_lives.get(context_type, self._half_life_hours)
|
||||
|
||||
def set_half_life(self, context_type: ContextType, hours: float) -> None:
|
||||
"""
|
||||
Set half-life for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to set half-life for
|
||||
hours: Half-life in hours
|
||||
"""
|
||||
if hours <= 0:
|
||||
raise ValueError("Half-life must be positive")
|
||||
self._type_half_lives[context_type] = hours
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Score multiple contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
query: Query (not used)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
List of scores (same order as input)
|
||||
"""
|
||||
scores = []
|
||||
for context in contexts:
|
||||
score = await self.score(context, query, **kwargs)
|
||||
scores.append(score)
|
||||
return scores
|
||||
220
backend/app/services/context/scoring/relevance.py
Normal file
220
backend/app/services/context/scoring/relevance.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Relevance Scorer for Context Management.
|
||||
|
||||
Scores context based on semantic similarity to the query.
|
||||
Uses Knowledge Base embeddings when available.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext, KnowledgeContext
|
||||
from .base import BaseScorer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RelevanceScorer(BaseScorer):
|
||||
"""
|
||||
Scores context based on relevance to query.
|
||||
|
||||
Uses multiple strategies:
|
||||
1. Pre-computed scores (from RAG results)
|
||||
2. MCP-based semantic similarity (via Knowledge Base)
|
||||
3. Keyword matching fallback
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
weight: float = 1.0,
|
||||
keyword_fallback_weight: float | None = None,
|
||||
semantic_max_chars: int | None = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize relevance scorer.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP manager for Knowledge Base calls
|
||||
weight: Scorer weight for composite scoring
|
||||
keyword_fallback_weight: Max score for keyword-based fallback (overrides settings)
|
||||
semantic_max_chars: Max content length for semantic similarity (overrides settings)
|
||||
settings: Context settings (uses global if None)
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self._settings = settings or get_context_settings()
|
||||
self._mcp = mcp_manager
|
||||
|
||||
# Use provided values or fall back to settings
|
||||
self._keyword_fallback_weight = (
|
||||
keyword_fallback_weight
|
||||
if keyword_fallback_weight is not None
|
||||
else self._settings.relevance_keyword_fallback_weight
|
||||
)
|
||||
self._semantic_max_chars = (
|
||||
semantic_max_chars
|
||||
if semantic_max_chars is not None
|
||||
else self._settings.relevance_semantic_max_chars
|
||||
)
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""Set MCP manager for semantic scoring."""
|
||||
self._mcp = mcp_manager
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score context relevance to query.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
Relevance score between 0.0 and 1.0
|
||||
"""
|
||||
# 1. Check for pre-computed relevance score
|
||||
if (
|
||||
isinstance(context, KnowledgeContext)
|
||||
and context.relevance_score is not None
|
||||
):
|
||||
return self.normalize_score(context.relevance_score)
|
||||
|
||||
# 2. Check metadata for score
|
||||
if "relevance_score" in context.metadata:
|
||||
return self.normalize_score(context.metadata["relevance_score"])
|
||||
|
||||
if "score" in context.metadata:
|
||||
return self.normalize_score(context.metadata["score"])
|
||||
|
||||
# 3. Try MCP-based semantic similarity (if compute_similarity tool is available)
|
||||
# Note: This requires the knowledge-base MCP server to implement compute_similarity
|
||||
if self._mcp is not None:
|
||||
try:
|
||||
score = await self._compute_semantic_similarity(context, query)
|
||||
if score is not None:
|
||||
return score
|
||||
except Exception as e:
|
||||
# Log at debug level since this is expected if compute_similarity
|
||||
# tool is not implemented in the Knowledge Base server
|
||||
logger.debug(
|
||||
f"Semantic scoring unavailable, using keyword fallback: {e}"
|
||||
)
|
||||
|
||||
# 4. Fall back to keyword matching
|
||||
return self._compute_keyword_score(context, query)
|
||||
|
||||
async def _compute_semantic_similarity(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
) -> float | None:
|
||||
"""
|
||||
Compute semantic similarity using Knowledge Base embeddings.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to compare
|
||||
|
||||
Returns:
|
||||
Similarity score or None if unavailable
|
||||
"""
|
||||
if self._mcp is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Use Knowledge Base's search capability to compute similarity
|
||||
result = await self._mcp.call_tool(
|
||||
server="knowledge-base",
|
||||
tool="compute_similarity",
|
||||
args={
|
||||
"text1": query,
|
||||
"text2": context.content[
|
||||
: self._semantic_max_chars
|
||||
], # Limit content length
|
||||
},
|
||||
)
|
||||
|
||||
if result.success and isinstance(result.data, dict):
|
||||
similarity = result.data.get("similarity")
|
||||
if similarity is not None:
|
||||
return self.normalize_score(float(similarity))
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Semantic similarity computation failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _compute_keyword_score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
) -> float:
|
||||
"""
|
||||
Compute relevance score based on keyword matching.
|
||||
|
||||
Simple but fast fallback when semantic search is unavailable.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to match
|
||||
|
||||
Returns:
|
||||
Keyword-based relevance score
|
||||
"""
|
||||
if not query or not context.content:
|
||||
return 0.0
|
||||
|
||||
# Extract keywords from query
|
||||
query_lower = query.lower()
|
||||
content_lower = context.content.lower()
|
||||
|
||||
# Simple word tokenization
|
||||
query_words = set(re.findall(r"\b\w{3,}\b", query_lower))
|
||||
content_words = set(re.findall(r"\b\w{3,}\b", content_lower))
|
||||
|
||||
if not query_words:
|
||||
return 0.0
|
||||
|
||||
# Calculate overlap
|
||||
common_words = query_words & content_words
|
||||
overlap_ratio = len(common_words) / len(query_words)
|
||||
|
||||
# Apply fallback weight ceiling
|
||||
return self.normalize_score(overlap_ratio * self._keyword_fallback_weight)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Score multiple contexts in parallel.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
List of scores (same order as input)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
if not contexts:
|
||||
return []
|
||||
|
||||
tasks = [self.score(context, query, **kwargs) for context in contexts]
|
||||
return await asyncio.gather(*tasks)
|
||||
43
backend/app/services/context/types/__init__.py
Normal file
43
backend/app/services/context/types/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Context Types Module.
|
||||
|
||||
Provides all context types used in the Context Management Engine.
|
||||
"""
|
||||
|
||||
from .base import (
|
||||
AssembledContext,
|
||||
BaseContext,
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
)
|
||||
from .conversation import (
|
||||
ConversationContext,
|
||||
MessageRole,
|
||||
)
|
||||
from .knowledge import KnowledgeContext
|
||||
from .system import SystemContext
|
||||
from .task import (
|
||||
TaskComplexity,
|
||||
TaskContext,
|
||||
TaskStatus,
|
||||
)
|
||||
from .tool import (
|
||||
ToolContext,
|
||||
ToolResultStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AssembledContext",
|
||||
"BaseContext",
|
||||
"ContextPriority",
|
||||
"ContextType",
|
||||
"ConversationContext",
|
||||
"KnowledgeContext",
|
||||
"MessageRole",
|
||||
"SystemContext",
|
||||
"TaskComplexity",
|
||||
"TaskContext",
|
||||
"TaskStatus",
|
||||
"ToolContext",
|
||||
"ToolResultStatus",
|
||||
]
|
||||
347
backend/app/services/context/types/base.py
Normal file
347
backend/app/services/context/types/base.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
Base Context Types and Enums.
|
||||
|
||||
Provides the foundation for all context types used in
|
||||
the Context Management Engine.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class ContextType(str, Enum):
|
||||
"""
|
||||
Types of context that can be assembled.
|
||||
|
||||
Each type has specific handling, formatting, and
|
||||
budget allocation rules.
|
||||
"""
|
||||
|
||||
SYSTEM = "system"
|
||||
TASK = "task"
|
||||
KNOWLEDGE = "knowledge"
|
||||
CONVERSATION = "conversation"
|
||||
TOOL = "tool"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str) -> "ContextType":
|
||||
"""
|
||||
Convert string to ContextType.
|
||||
|
||||
Args:
|
||||
value: String value
|
||||
|
||||
Returns:
|
||||
ContextType enum value
|
||||
|
||||
Raises:
|
||||
ValueError: If value is not a valid context type
|
||||
"""
|
||||
try:
|
||||
return cls(value.lower())
|
||||
except ValueError:
|
||||
valid = ", ".join(t.value for t in cls)
|
||||
raise ValueError(f"Invalid context type '{value}'. Valid types: {valid}")
|
||||
|
||||
|
||||
class ContextPriority(int, Enum):
|
||||
"""
|
||||
Priority levels for context ordering.
|
||||
|
||||
Higher values indicate higher priority.
|
||||
"""
|
||||
|
||||
LOWEST = 0
|
||||
LOW = 25
|
||||
NORMAL = 50
|
||||
HIGH = 75
|
||||
HIGHEST = 100
|
||||
CRITICAL = 150 # Never omit
|
||||
|
||||
@classmethod
|
||||
def from_int(cls, value: int) -> "ContextPriority":
|
||||
"""
|
||||
Get closest priority level for an integer.
|
||||
|
||||
Args:
|
||||
value: Integer priority value
|
||||
|
||||
Returns:
|
||||
Closest ContextPriority enum value
|
||||
"""
|
||||
priorities = sorted(cls, key=lambda p: p.value)
|
||||
for priority in reversed(priorities):
|
||||
if value >= priority.value:
|
||||
return priority
|
||||
return cls.LOWEST
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class BaseContext(ABC):
|
||||
"""
|
||||
Abstract base class for all context types.
|
||||
|
||||
Provides common fields and methods for context handling,
|
||||
scoring, and serialization.
|
||||
"""
|
||||
|
||||
# Required fields
|
||||
content: str
|
||||
source: str
|
||||
|
||||
# Optional fields with defaults
|
||||
id: str = field(default_factory=lambda: str(uuid4()))
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
priority: int = field(default=ContextPriority.NORMAL.value)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Computed/cached fields
|
||||
_token_count: int | None = field(default=None, repr=False)
|
||||
_score: float | None = field(default=None, repr=False)
|
||||
|
||||
@property
|
||||
def token_count(self) -> int | None:
|
||||
"""Get cached token count (None if not counted yet)."""
|
||||
return self._token_count
|
||||
|
||||
@token_count.setter
|
||||
def token_count(self, value: int) -> None:
|
||||
"""Set token count."""
|
||||
self._token_count = value
|
||||
|
||||
@property
|
||||
def score(self) -> float | None:
|
||||
"""Get cached score (None if not scored yet)."""
|
||||
return self._score
|
||||
|
||||
@score.setter
|
||||
def score(self, value: float) -> None:
|
||||
"""Set score (clamped to 0.0-1.0)."""
|
||||
self._score = max(0.0, min(1.0, value))
|
||||
|
||||
@abstractmethod
|
||||
def get_type(self) -> ContextType:
|
||||
"""
|
||||
Get the type of this context.
|
||||
|
||||
Returns:
|
||||
ContextType enum value
|
||||
"""
|
||||
...
|
||||
|
||||
def get_age_seconds(self) -> float:
|
||||
"""
|
||||
Get age of context in seconds.
|
||||
|
||||
Returns:
|
||||
Age in seconds since creation
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
delta = now - self.timestamp
|
||||
return delta.total_seconds()
|
||||
|
||||
def get_age_hours(self) -> float:
|
||||
"""
|
||||
Get age of context in hours.
|
||||
|
||||
Returns:
|
||||
Age in hours since creation
|
||||
"""
|
||||
return self.get_age_seconds() / 3600
|
||||
|
||||
def is_stale(self, max_age_hours: float = 168.0) -> bool:
|
||||
"""
|
||||
Check if context is stale.
|
||||
|
||||
Args:
|
||||
max_age_hours: Maximum age before considered stale (default 7 days)
|
||||
|
||||
Returns:
|
||||
True if context is older than max_age_hours
|
||||
"""
|
||||
return self.get_age_hours() > max_age_hours
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert context to dictionary for serialization.
|
||||
|
||||
Returns:
|
||||
Dictionary representation
|
||||
"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"type": self.get_type().value,
|
||||
"content": self.content,
|
||||
"source": self.source,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"priority": self.priority,
|
||||
"metadata": self.metadata,
|
||||
"token_count": self._token_count,
|
||||
"score": self._score,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "BaseContext":
|
||||
"""
|
||||
Create context from dictionary.
|
||||
|
||||
Note: Subclasses should override this to return correct type.
|
||||
|
||||
Args:
|
||||
data: Dictionary with context data
|
||||
|
||||
Returns:
|
||||
Context instance
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement from_dict")
|
||||
|
||||
def truncate(self, max_tokens: int, suffix: str = "... [truncated]") -> str:
|
||||
"""
|
||||
Truncate content to fit within token limit.
|
||||
|
||||
This is a rough estimation based on characters.
|
||||
For accurate truncation, use the TokenCalculator.
|
||||
|
||||
Args:
|
||||
max_tokens: Maximum tokens allowed
|
||||
suffix: Suffix to append when truncated
|
||||
|
||||
Returns:
|
||||
Truncated content
|
||||
"""
|
||||
if self._token_count is None or self._token_count <= max_tokens:
|
||||
return self.content
|
||||
|
||||
# Rough estimation: 4 chars per token on average
|
||||
estimated_chars = max_tokens * 4
|
||||
suffix_chars = len(suffix)
|
||||
|
||||
if len(self.content) <= estimated_chars:
|
||||
return self.content
|
||||
|
||||
truncated = self.content[: estimated_chars - suffix_chars]
|
||||
# Try to break at word boundary
|
||||
last_space = truncated.rfind(" ")
|
||||
if last_space > estimated_chars * 0.8:
|
||||
truncated = truncated[:last_space]
|
||||
|
||||
return truncated + suffix
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Hash based on ID for set/dict usage."""
|
||||
return hash(self.id)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Equality based on ID."""
|
||||
if not isinstance(other, BaseContext):
|
||||
return False
|
||||
return self.id == other.id
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssembledContext:
|
||||
"""
|
||||
Result of context assembly.
|
||||
|
||||
Contains the final formatted context ready for LLM consumption,
|
||||
along with metadata about the assembly process.
|
||||
"""
|
||||
|
||||
# Main content
|
||||
content: str
|
||||
total_tokens: int
|
||||
|
||||
# Assembly metadata
|
||||
context_count: int
|
||||
excluded_count: int = 0
|
||||
assembly_time_ms: float = 0.0
|
||||
model: str = ""
|
||||
|
||||
# Included contexts (optional - for inspection)
|
||||
contexts: list["BaseContext"] = field(default_factory=list)
|
||||
|
||||
# Additional metadata from assembly
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Budget tracking
|
||||
budget_total: int = 0
|
||||
budget_used: int = 0
|
||||
|
||||
# Context breakdown
|
||||
by_type: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
# Cache info
|
||||
cache_hit: bool = False
|
||||
cache_key: str | None = None
|
||||
|
||||
# Aliases for backward compatibility
|
||||
@property
|
||||
def token_count(self) -> int:
|
||||
"""Alias for total_tokens."""
|
||||
return self.total_tokens
|
||||
|
||||
@property
|
||||
def contexts_included(self) -> int:
|
||||
"""Alias for context_count."""
|
||||
return self.context_count
|
||||
|
||||
@property
|
||||
def contexts_excluded(self) -> int:
|
||||
"""Alias for excluded_count."""
|
||||
return self.excluded_count
|
||||
|
||||
@property
|
||||
def budget_utilization(self) -> float:
|
||||
"""Get budget utilization percentage."""
|
||||
if self.budget_total == 0:
|
||||
return 0.0
|
||||
return self.budget_used / self.budget_total
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"content": self.content,
|
||||
"total_tokens": self.total_tokens,
|
||||
"context_count": self.context_count,
|
||||
"excluded_count": self.excluded_count,
|
||||
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||
"model": self.model,
|
||||
"metadata": self.metadata,
|
||||
"budget_total": self.budget_total,
|
||||
"budget_used": self.budget_used,
|
||||
"budget_utilization": round(self.budget_utilization, 3),
|
||||
"by_type": self.by_type,
|
||||
"cache_hit": self.cache_hit,
|
||||
"cache_key": self.cache_key,
|
||||
}
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""Convert to JSON string."""
|
||||
import json
|
||||
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str) -> "AssembledContext":
|
||||
"""Create from JSON string."""
|
||||
import json
|
||||
|
||||
data = json.loads(json_str)
|
||||
return cls(
|
||||
content=data["content"],
|
||||
total_tokens=data["total_tokens"],
|
||||
context_count=data["context_count"],
|
||||
excluded_count=data.get("excluded_count", 0),
|
||||
assembly_time_ms=data.get("assembly_time_ms", 0.0),
|
||||
model=data.get("model", ""),
|
||||
metadata=data.get("metadata", {}),
|
||||
budget_total=data.get("budget_total", 0),
|
||||
budget_used=data.get("budget_used", 0),
|
||||
by_type=data.get("by_type", {}),
|
||||
cache_hit=data.get("cache_hit", False),
|
||||
cache_key=data.get("cache_key"),
|
||||
)
|
||||
182
backend/app/services/context/types/conversation.py
Normal file
182
backend/app/services/context/types/conversation.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Conversation Context Type.
|
||||
|
||||
Represents conversation history for context continuity.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
"""Roles for conversation messages."""
|
||||
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
TOOL = "tool"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str) -> "MessageRole":
|
||||
"""Convert string to MessageRole."""
|
||||
try:
|
||||
return cls(value.lower())
|
||||
except ValueError:
|
||||
# Default to user for unknown roles
|
||||
return cls.USER
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ConversationContext(BaseContext):
|
||||
"""
|
||||
Context from conversation history.
|
||||
|
||||
Represents a single turn in the conversation,
|
||||
including user messages, assistant responses,
|
||||
and tool results.
|
||||
"""
|
||||
|
||||
# Conversation-specific fields
|
||||
role: MessageRole = field(default=MessageRole.USER)
|
||||
turn_index: int = field(default=0)
|
||||
session_id: str | None = field(default=None)
|
||||
parent_message_id: str | None = field(default=None)
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return CONVERSATION context type."""
|
||||
return ContextType.CONVERSATION
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with conversation-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"role": self.role.value,
|
||||
"turn_index": self.turn_index,
|
||||
"session_id": self.session_id,
|
||||
"parent_message_id": self.parent_message_id,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ConversationContext":
|
||||
"""Create ConversationContext from dictionary."""
|
||||
role = data.get("role", "user")
|
||||
if isinstance(role, str):
|
||||
role = MessageRole.from_string(role)
|
||||
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data.get("source", "conversation"),
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
role=role,
|
||||
turn_index=data.get("turn_index", 0),
|
||||
session_id=data.get("session_id"),
|
||||
parent_message_id=data.get("parent_message_id"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_message(
|
||||
cls,
|
||||
content: str,
|
||||
role: str | MessageRole,
|
||||
turn_index: int = 0,
|
||||
session_id: str | None = None,
|
||||
timestamp: datetime | None = None,
|
||||
) -> "ConversationContext":
|
||||
"""
|
||||
Create ConversationContext from a message.
|
||||
|
||||
Args:
|
||||
content: Message content
|
||||
role: Message role (user, assistant, system, tool)
|
||||
turn_index: Position in conversation
|
||||
session_id: Session identifier
|
||||
timestamp: Message timestamp
|
||||
|
||||
Returns:
|
||||
ConversationContext instance
|
||||
"""
|
||||
if isinstance(role, str):
|
||||
role = MessageRole.from_string(role)
|
||||
|
||||
# Recent messages have higher priority
|
||||
priority = ContextPriority.NORMAL.value
|
||||
|
||||
return cls(
|
||||
content=content,
|
||||
source="conversation",
|
||||
role=role,
|
||||
turn_index=turn_index,
|
||||
session_id=session_id,
|
||||
timestamp=timestamp or datetime.now(UTC),
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_history(
|
||||
cls,
|
||||
messages: list[dict[str, Any]],
|
||||
session_id: str | None = None,
|
||||
) -> list["ConversationContext"]:
|
||||
"""
|
||||
Create multiple ConversationContexts from message history.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
List of ConversationContext instances
|
||||
"""
|
||||
contexts = []
|
||||
for i, msg in enumerate(messages):
|
||||
ctx = cls.from_message(
|
||||
content=msg.get("content", ""),
|
||||
role=msg.get("role", "user"),
|
||||
turn_index=i,
|
||||
session_id=session_id,
|
||||
timestamp=datetime.fromisoformat(msg["timestamp"])
|
||||
if "timestamp" in msg
|
||||
else None,
|
||||
)
|
||||
contexts.append(ctx)
|
||||
return contexts
|
||||
|
||||
def is_user_message(self) -> bool:
|
||||
"""Check if this is a user message."""
|
||||
return self.role == MessageRole.USER
|
||||
|
||||
def is_assistant_message(self) -> bool:
|
||||
"""Check if this is an assistant message."""
|
||||
return self.role == MessageRole.ASSISTANT
|
||||
|
||||
def is_tool_result(self) -> bool:
|
||||
"""Check if this is a tool result."""
|
||||
return self.role == MessageRole.TOOL
|
||||
|
||||
def format_for_prompt(self) -> str:
|
||||
"""
|
||||
Format message for inclusion in prompt.
|
||||
|
||||
Returns:
|
||||
Formatted message string
|
||||
"""
|
||||
role_labels = {
|
||||
MessageRole.USER: "User",
|
||||
MessageRole.ASSISTANT: "Assistant",
|
||||
MessageRole.SYSTEM: "System",
|
||||
MessageRole.TOOL: "Tool Result",
|
||||
}
|
||||
label = role_labels.get(self.role, "Unknown")
|
||||
return f"{label}: {self.content}"
|
||||
152
backend/app/services/context/types/knowledge.py
Normal file
152
backend/app/services/context/types/knowledge.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Knowledge Context Type.
|
||||
|
||||
Represents RAG results from the Knowledge Base MCP server.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class KnowledgeContext(BaseContext):
|
||||
"""
|
||||
Context from knowledge base / RAG retrieval.
|
||||
|
||||
Knowledge context represents chunks retrieved from the
|
||||
Knowledge Base MCP server, including:
|
||||
- Code snippets
|
||||
- Documentation
|
||||
- Previous conversations
|
||||
- External knowledge
|
||||
|
||||
Each chunk includes relevance scoring from the search.
|
||||
"""
|
||||
|
||||
# Knowledge-specific fields
|
||||
collection: str = field(default="default")
|
||||
file_type: str | None = field(default=None)
|
||||
chunk_index: int = field(default=0)
|
||||
relevance_score: float = field(default=0.0)
|
||||
search_query: str = field(default="")
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return KNOWLEDGE context type."""
|
||||
return ContextType.KNOWLEDGE
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with knowledge-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"collection": self.collection,
|
||||
"file_type": self.file_type,
|
||||
"chunk_index": self.chunk_index,
|
||||
"relevance_score": self.relevance_score,
|
||||
"search_query": self.search_query,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "KnowledgeContext":
|
||||
"""Create KnowledgeContext from dictionary."""
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data["source"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
collection=data.get("collection", "default"),
|
||||
file_type=data.get("file_type"),
|
||||
chunk_index=data.get("chunk_index", 0),
|
||||
relevance_score=data.get("relevance_score", 0.0),
|
||||
search_query=data.get("search_query", ""),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_search_result(
|
||||
cls,
|
||||
result: dict[str, Any],
|
||||
query: str,
|
||||
) -> "KnowledgeContext":
|
||||
"""
|
||||
Create KnowledgeContext from a Knowledge Base search result.
|
||||
|
||||
Args:
|
||||
result: Search result from Knowledge Base MCP
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
KnowledgeContext instance
|
||||
"""
|
||||
return cls(
|
||||
content=result.get("content", ""),
|
||||
source=result.get("source_path", "unknown"),
|
||||
collection=result.get("collection", "default"),
|
||||
file_type=result.get("file_type"),
|
||||
chunk_index=result.get("chunk_index", 0),
|
||||
relevance_score=result.get("score", 0.0),
|
||||
search_query=query,
|
||||
metadata={
|
||||
"chunk_id": result.get("id"),
|
||||
"content_hash": result.get("content_hash"),
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_search_results(
|
||||
cls,
|
||||
results: list[dict[str, Any]],
|
||||
query: str,
|
||||
) -> list["KnowledgeContext"]:
|
||||
"""
|
||||
Create multiple KnowledgeContexts from search results.
|
||||
|
||||
Args:
|
||||
results: List of search results
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
List of KnowledgeContext instances
|
||||
"""
|
||||
return [cls.from_search_result(r, query) for r in results]
|
||||
|
||||
def is_code(self) -> bool:
|
||||
"""Check if this is code content."""
|
||||
code_types = {
|
||||
"python",
|
||||
"javascript",
|
||||
"typescript",
|
||||
"go",
|
||||
"rust",
|
||||
"java",
|
||||
"c",
|
||||
"cpp",
|
||||
}
|
||||
return self.file_type is not None and self.file_type.lower() in code_types
|
||||
|
||||
def is_documentation(self) -> bool:
|
||||
"""Check if this is documentation content."""
|
||||
doc_types = {"markdown", "rst", "txt", "md"}
|
||||
return self.file_type is not None and self.file_type.lower() in doc_types
|
||||
|
||||
def get_formatted_source(self) -> str:
|
||||
"""
|
||||
Get a formatted source string for display.
|
||||
|
||||
Returns:
|
||||
Formatted source string
|
||||
"""
|
||||
parts = [self.source]
|
||||
if self.file_type:
|
||||
parts.append(f"({self.file_type})")
|
||||
if self.collection != "default":
|
||||
parts.insert(0, f"[{self.collection}]")
|
||||
return " ".join(parts)
|
||||
138
backend/app/services/context/types/system.py
Normal file
138
backend/app/services/context/types/system.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
System Context Type.
|
||||
|
||||
Represents system prompts, instructions, and agent personas.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class SystemContext(BaseContext):
|
||||
"""
|
||||
Context for system prompts and instructions.
|
||||
|
||||
System context typically includes:
|
||||
- Agent persona and role definitions
|
||||
- Behavioral instructions
|
||||
- Safety guidelines
|
||||
- Output format requirements
|
||||
|
||||
System context is usually high priority and should
|
||||
rarely be truncated or omitted.
|
||||
"""
|
||||
|
||||
# System context specific fields
|
||||
role: str = field(default="assistant")
|
||||
instructions_type: str = field(default="general")
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set high priority for system context."""
|
||||
# System context defaults to high priority
|
||||
if self.priority == ContextPriority.NORMAL.value:
|
||||
self.priority = ContextPriority.HIGH.value
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return SYSTEM context type."""
|
||||
return ContextType.SYSTEM
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with system-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"role": self.role,
|
||||
"instructions_type": self.instructions_type,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SystemContext":
|
||||
"""Create SystemContext from dictionary."""
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data["source"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.HIGH.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
role=data.get("role", "assistant"),
|
||||
instructions_type=data.get("instructions_type", "general"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_persona(
|
||||
cls,
|
||||
name: str,
|
||||
description: str,
|
||||
capabilities: list[str] | None = None,
|
||||
constraints: list[str] | None = None,
|
||||
) -> "SystemContext":
|
||||
"""
|
||||
Create a persona system context.
|
||||
|
||||
Args:
|
||||
name: Agent name/role
|
||||
description: Role description
|
||||
capabilities: List of things the agent can do
|
||||
constraints: List of limitations
|
||||
|
||||
Returns:
|
||||
SystemContext with formatted persona
|
||||
"""
|
||||
parts = [f"You are {name}.", "", description]
|
||||
|
||||
if capabilities:
|
||||
parts.append("")
|
||||
parts.append("You can:")
|
||||
for cap in capabilities:
|
||||
parts.append(f"- {cap}")
|
||||
|
||||
if constraints:
|
||||
parts.append("")
|
||||
parts.append("You must not:")
|
||||
for constraint in constraints:
|
||||
parts.append(f"- {constraint}")
|
||||
|
||||
return cls(
|
||||
content="\n".join(parts),
|
||||
source="persona_builder",
|
||||
role=name.lower().replace(" ", "_"),
|
||||
instructions_type="persona",
|
||||
priority=ContextPriority.HIGHEST.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_instructions(
|
||||
cls,
|
||||
instructions: str | list[str],
|
||||
source: str = "instructions",
|
||||
) -> "SystemContext":
|
||||
"""
|
||||
Create an instructions system context.
|
||||
|
||||
Args:
|
||||
instructions: Instructions string or list of instruction strings
|
||||
source: Source identifier
|
||||
|
||||
Returns:
|
||||
SystemContext with instructions
|
||||
"""
|
||||
if isinstance(instructions, list):
|
||||
content = "\n".join(f"- {inst}" for inst in instructions)
|
||||
else:
|
||||
content = instructions
|
||||
|
||||
return cls(
|
||||
content=content,
|
||||
source=source,
|
||||
instructions_type="instructions",
|
||||
priority=ContextPriority.HIGH.value,
|
||||
)
|
||||
193
backend/app/services/context/types/task.py
Normal file
193
backend/app/services/context/types/task.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Task Context Type.
|
||||
|
||||
Represents the current task or objective for the agent.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""Status of a task."""
|
||||
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
BLOCKED = "blocked"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class TaskComplexity(str, Enum):
|
||||
"""Complexity level of a task."""
|
||||
|
||||
TRIVIAL = "trivial"
|
||||
SIMPLE = "simple"
|
||||
MODERATE = "moderate"
|
||||
COMPLEX = "complex"
|
||||
VERY_COMPLEX = "very_complex"
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TaskContext(BaseContext):
|
||||
"""
|
||||
Context for the current task or objective.
|
||||
|
||||
Task context provides information about what the agent
|
||||
should accomplish, including:
|
||||
- Task description and goals
|
||||
- Acceptance criteria
|
||||
- Constraints and requirements
|
||||
- Related issue/ticket information
|
||||
"""
|
||||
|
||||
# Task-specific fields
|
||||
title: str = field(default="")
|
||||
status: TaskStatus = field(default=TaskStatus.PENDING)
|
||||
complexity: TaskComplexity = field(default=TaskComplexity.MODERATE)
|
||||
issue_id: str | None = field(default=None)
|
||||
project_id: str | None = field(default=None)
|
||||
acceptance_criteria: list[str] = field(default_factory=list)
|
||||
constraints: list[str] = field(default_factory=list)
|
||||
parent_task_id: str | None = field(default=None)
|
||||
|
||||
# Note: TaskContext should typically have HIGH priority,
|
||||
# but we don't auto-promote to allow explicit priority setting.
|
||||
# Use TaskContext.create() for default HIGH priority behavior.
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return TASK context type."""
|
||||
return ContextType.TASK
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with task-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"title": self.title,
|
||||
"status": self.status.value,
|
||||
"complexity": self.complexity.value,
|
||||
"issue_id": self.issue_id,
|
||||
"project_id": self.project_id,
|
||||
"acceptance_criteria": self.acceptance_criteria,
|
||||
"constraints": self.constraints,
|
||||
"parent_task_id": self.parent_task_id,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "TaskContext":
|
||||
"""Create TaskContext from dictionary."""
|
||||
status = data.get("status", "pending")
|
||||
if isinstance(status, str):
|
||||
status = TaskStatus(status)
|
||||
|
||||
complexity = data.get("complexity", "moderate")
|
||||
if isinstance(complexity, str):
|
||||
complexity = TaskComplexity(complexity)
|
||||
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data.get("source", "task"),
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.HIGH.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
title=data.get("title", ""),
|
||||
status=status,
|
||||
complexity=complexity,
|
||||
issue_id=data.get("issue_id"),
|
||||
project_id=data.get("project_id"),
|
||||
acceptance_criteria=data.get("acceptance_criteria", []),
|
||||
constraints=data.get("constraints", []),
|
||||
parent_task_id=data.get("parent_task_id"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
title: str,
|
||||
description: str,
|
||||
acceptance_criteria: list[str] | None = None,
|
||||
constraints: list[str] | None = None,
|
||||
issue_id: str | None = None,
|
||||
project_id: str | None = None,
|
||||
complexity: TaskComplexity | str = TaskComplexity.MODERATE,
|
||||
) -> "TaskContext":
|
||||
"""
|
||||
Create a task context.
|
||||
|
||||
Args:
|
||||
title: Task title
|
||||
description: Task description
|
||||
acceptance_criteria: List of acceptance criteria
|
||||
constraints: List of constraints
|
||||
issue_id: Related issue ID
|
||||
project_id: Project ID
|
||||
complexity: Task complexity
|
||||
|
||||
Returns:
|
||||
TaskContext instance
|
||||
"""
|
||||
if isinstance(complexity, str):
|
||||
complexity = TaskComplexity(complexity)
|
||||
|
||||
return cls(
|
||||
content=description,
|
||||
source=f"task:{issue_id}" if issue_id else "task",
|
||||
title=title,
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
complexity=complexity,
|
||||
issue_id=issue_id,
|
||||
project_id=project_id,
|
||||
acceptance_criteria=acceptance_criteria or [],
|
||||
constraints=constraints or [],
|
||||
)
|
||||
|
||||
def format_for_prompt(self) -> str:
|
||||
"""
|
||||
Format task for inclusion in prompt.
|
||||
|
||||
Returns:
|
||||
Formatted task string
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if self.title:
|
||||
parts.append(f"Task: {self.title}")
|
||||
parts.append("")
|
||||
|
||||
parts.append(self.content)
|
||||
|
||||
if self.acceptance_criteria:
|
||||
parts.append("")
|
||||
parts.append("Acceptance Criteria:")
|
||||
for criterion in self.acceptance_criteria:
|
||||
parts.append(f"- {criterion}")
|
||||
|
||||
if self.constraints:
|
||||
parts.append("")
|
||||
parts.append("Constraints:")
|
||||
for constraint in self.constraints:
|
||||
parts.append(f"- {constraint}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Check if task is currently active."""
|
||||
return self.status in (TaskStatus.PENDING, TaskStatus.IN_PROGRESS)
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if task is complete."""
|
||||
return self.status == TaskStatus.COMPLETED
|
||||
|
||||
def is_blocked(self) -> bool:
|
||||
"""Check if task is blocked."""
|
||||
return self.status == TaskStatus.BLOCKED
|
||||
211
backend/app/services/context/types/tool.py
Normal file
211
backend/app/services/context/types/tool.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Tool Context Type.
|
||||
|
||||
Represents available tools and recent tool execution results.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class ToolResultStatus(str, Enum):
|
||||
"""Status of a tool execution result."""
|
||||
|
||||
SUCCESS = "success"
|
||||
ERROR = "error"
|
||||
TIMEOUT = "timeout"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ToolContext(BaseContext):
|
||||
"""
|
||||
Context for tools and tool execution results.
|
||||
|
||||
Tool context includes:
|
||||
- Tool descriptions and parameters
|
||||
- Recent tool execution results
|
||||
- Tool availability information
|
||||
|
||||
This helps the LLM understand what tools are available
|
||||
and what results previous tool calls produced.
|
||||
"""
|
||||
|
||||
# Tool-specific fields
|
||||
tool_name: str = field(default="")
|
||||
tool_description: str = field(default="")
|
||||
is_result: bool = field(default=False)
|
||||
result_status: ToolResultStatus | None = field(default=None)
|
||||
execution_time_ms: float | None = field(default=None)
|
||||
parameters: dict[str, Any] = field(default_factory=dict)
|
||||
server_name: str | None = field(default=None)
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return TOOL context type."""
|
||||
return ContextType.TOOL
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with tool-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"tool_name": self.tool_name,
|
||||
"tool_description": self.tool_description,
|
||||
"is_result": self.is_result,
|
||||
"result_status": self.result_status.value
|
||||
if self.result_status
|
||||
else None,
|
||||
"execution_time_ms": self.execution_time_ms,
|
||||
"parameters": self.parameters,
|
||||
"server_name": self.server_name,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ToolContext":
|
||||
"""Create ToolContext from dictionary."""
|
||||
result_status = data.get("result_status")
|
||||
if isinstance(result_status, str):
|
||||
result_status = ToolResultStatus(result_status)
|
||||
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data.get("source", "tool"),
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
tool_name=data.get("tool_name", ""),
|
||||
tool_description=data.get("tool_description", ""),
|
||||
is_result=data.get("is_result", False),
|
||||
result_status=result_status,
|
||||
execution_time_ms=data.get("execution_time_ms"),
|
||||
parameters=data.get("parameters", {}),
|
||||
server_name=data.get("server_name"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_tool_definition(
|
||||
cls,
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
server_name: str | None = None,
|
||||
) -> "ToolContext":
|
||||
"""
|
||||
Create a ToolContext from a tool definition.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
description: Tool description
|
||||
parameters: Tool parameter schema
|
||||
server_name: MCP server name
|
||||
|
||||
Returns:
|
||||
ToolContext instance
|
||||
"""
|
||||
# Format content as tool documentation
|
||||
content_parts = [f"Tool: {name}", "", description]
|
||||
|
||||
if parameters:
|
||||
content_parts.append("")
|
||||
content_parts.append("Parameters:")
|
||||
for param_name, param_info in parameters.items():
|
||||
param_type = param_info.get("type", "any")
|
||||
param_desc = param_info.get("description", "")
|
||||
required = param_info.get("required", False)
|
||||
req_marker = " (required)" if required else ""
|
||||
content_parts.append(f" - {param_name}: {param_type}{req_marker}")
|
||||
if param_desc:
|
||||
content_parts.append(f" {param_desc}")
|
||||
|
||||
return cls(
|
||||
content="\n".join(content_parts),
|
||||
source=f"tool:{server_name}:{name}" if server_name else f"tool:{name}",
|
||||
tool_name=name,
|
||||
tool_description=description,
|
||||
is_result=False,
|
||||
parameters=parameters or {},
|
||||
server_name=server_name,
|
||||
priority=ContextPriority.LOW.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_tool_result(
|
||||
cls,
|
||||
tool_name: str,
|
||||
result: Any,
|
||||
status: ToolResultStatus = ToolResultStatus.SUCCESS,
|
||||
execution_time_ms: float | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
server_name: str | None = None,
|
||||
) -> "ToolContext":
|
||||
"""
|
||||
Create a ToolContext from a tool execution result.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool that was executed
|
||||
result: Result content (will be converted to string)
|
||||
status: Execution status
|
||||
execution_time_ms: Execution time in milliseconds
|
||||
parameters: Parameters that were passed to the tool
|
||||
server_name: MCP server name
|
||||
|
||||
Returns:
|
||||
ToolContext instance
|
||||
"""
|
||||
# Convert result to string content
|
||||
if isinstance(result, str):
|
||||
content = result
|
||||
elif isinstance(result, dict):
|
||||
import json
|
||||
|
||||
try:
|
||||
content = json.dumps(result, indent=2)
|
||||
except (TypeError, ValueError):
|
||||
content = str(result)
|
||||
else:
|
||||
content = str(result)
|
||||
|
||||
return cls(
|
||||
content=content,
|
||||
source=f"tool_result:{server_name}:{tool_name}"
|
||||
if server_name
|
||||
else f"tool_result:{tool_name}",
|
||||
tool_name=tool_name,
|
||||
is_result=True,
|
||||
result_status=status,
|
||||
execution_time_ms=execution_time_ms,
|
||||
parameters=parameters or {},
|
||||
server_name=server_name,
|
||||
priority=ContextPriority.HIGH.value, # Recent results are high priority
|
||||
)
|
||||
|
||||
def is_successful(self) -> bool:
|
||||
"""Check if this is a successful tool result."""
|
||||
return self.is_result and self.result_status == ToolResultStatus.SUCCESS
|
||||
|
||||
def is_error(self) -> bool:
|
||||
"""Check if this is an error result."""
|
||||
return self.is_result and self.result_status == ToolResultStatus.ERROR
|
||||
|
||||
def format_for_prompt(self) -> str:
|
||||
"""
|
||||
Format tool context for inclusion in prompt.
|
||||
|
||||
Returns:
|
||||
Formatted tool string
|
||||
"""
|
||||
if self.is_result:
|
||||
status_str = self.result_status.value if self.result_status else "unknown"
|
||||
header = f"Tool Result ({self.tool_name}, {status_str}):"
|
||||
return f"{header}\n{self.content}"
|
||||
else:
|
||||
return self.content
|
||||
611
backend/app/services/event_bus.py
Normal file
611
backend/app/services/event_bus.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""
|
||||
EventBus service for Redis Pub/Sub communication.
|
||||
|
||||
This module provides a centralized event bus for publishing and subscribing to
|
||||
events across the Syndarix platform. It uses Redis Pub/Sub for real-time
|
||||
message delivery between services, agents, and the frontend.
|
||||
|
||||
Architecture:
|
||||
- Publishers emit events to project/agent-specific Redis channels
|
||||
- SSE endpoints subscribe to channels and stream events to clients
|
||||
- Events include metadata for reconnection support (Last-Event-ID)
|
||||
- Events are typed with the EventType enum for consistency
|
||||
|
||||
Usage:
|
||||
# Publishing events
|
||||
event_bus = EventBus()
|
||||
await event_bus.connect()
|
||||
|
||||
event = event_bus.create_event(
|
||||
event_type=EventType.AGENT_MESSAGE,
|
||||
project_id=project_id,
|
||||
actor_type="agent",
|
||||
payload={"message": "Processing..."}
|
||||
)
|
||||
await event_bus.publish(event_bus.get_project_channel(project_id), event)
|
||||
|
||||
# Subscribing to events
|
||||
async for event in event_bus.subscribe(["project:123", "agent:456"]):
|
||||
handle_event(event)
|
||||
|
||||
# Cleanup
|
||||
await event_bus.disconnect()
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import redis.asyncio as redis
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.events import ActorType, Event, EventType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventBusError(Exception):
|
||||
"""Base exception for EventBus errors."""
|
||||
|
||||
|
||||
class EventBusConnectionError(EventBusError):
|
||||
"""Raised when connection to Redis fails."""
|
||||
|
||||
|
||||
class EventBusPublishError(EventBusError):
|
||||
"""Raised when publishing an event fails."""
|
||||
|
||||
|
||||
class EventBusSubscriptionError(EventBusError):
|
||||
"""Raised when subscribing to channels fails."""
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""
|
||||
EventBus for Redis Pub/Sub communication.
|
||||
|
||||
Provides methods to publish events to channels and subscribe to events
|
||||
from multiple channels. Handles connection management, serialization,
|
||||
and error recovery.
|
||||
|
||||
This class provides:
|
||||
- Event publishing to project/agent-specific channels
|
||||
- Subscription management for SSE endpoints
|
||||
- Reconnection support via event IDs (Last-Event-ID)
|
||||
- Keepalive messages for connection health
|
||||
- Type-safe event creation with the Event schema
|
||||
|
||||
Attributes:
|
||||
redis_url: Redis connection URL
|
||||
redis_client: Async Redis client instance
|
||||
pubsub: Redis PubSub instance for subscriptions
|
||||
"""
|
||||
|
||||
# Channel prefixes for different entity types
|
||||
PROJECT_CHANNEL_PREFIX = "project"
|
||||
AGENT_CHANNEL_PREFIX = "agent"
|
||||
USER_CHANNEL_PREFIX = "user"
|
||||
GLOBAL_CHANNEL = "syndarix:global"
|
||||
|
||||
def __init__(self, redis_url: str | None = None) -> None:
|
||||
"""
|
||||
Initialize the EventBus.
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL. Defaults to settings.REDIS_URL.
|
||||
"""
|
||||
self.redis_url = redis_url or settings.REDIS_URL
|
||||
self._redis_client: redis.Redis | None = None
|
||||
self._pubsub: redis.client.PubSub | None = None
|
||||
self._connected = False
|
||||
|
||||
@property
|
||||
def redis_client(self) -> redis.Redis:
|
||||
"""Get the Redis client, raising if not connected."""
|
||||
if self._redis_client is None:
|
||||
raise EventBusConnectionError(
|
||||
"EventBus not connected. Call connect() first."
|
||||
)
|
||||
return self._redis_client
|
||||
|
||||
@property
|
||||
def pubsub(self) -> redis.client.PubSub:
|
||||
"""Get the PubSub instance, raising if not connected."""
|
||||
if self._pubsub is None:
|
||||
raise EventBusConnectionError(
|
||||
"EventBus not connected. Call connect() first."
|
||||
)
|
||||
return self._pubsub
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the EventBus is connected to Redis."""
|
||||
return self._connected and self._redis_client is not None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
Connect to Redis and initialize the PubSub client.
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If connection to Redis fails.
|
||||
"""
|
||||
if self._connected:
|
||||
logger.debug("EventBus already connected")
|
||||
return
|
||||
|
||||
try:
|
||||
self._redis_client = redis.from_url(
|
||||
self.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
# Test connection - ping() returns a coroutine for async Redis
|
||||
ping_result = self._redis_client.ping()
|
||||
if hasattr(ping_result, "__await__"):
|
||||
await ping_result
|
||||
self._pubsub = self._redis_client.pubsub()
|
||||
self._connected = True
|
||||
logger.info("EventBus connected to Redis")
|
||||
except redis.ConnectionError as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}", exc_info=True)
|
||||
raise EventBusConnectionError(f"Failed to connect to Redis: {e}") from e
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Redis error during connection: {e}", exc_info=True)
|
||||
raise EventBusConnectionError(f"Redis error: {e}") from e
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnect from Redis and cleanup resources.
|
||||
"""
|
||||
if self._pubsub:
|
||||
try:
|
||||
await self._pubsub.unsubscribe()
|
||||
await self._pubsub.close()
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Error closing PubSub: {e}")
|
||||
finally:
|
||||
self._pubsub = None
|
||||
|
||||
if self._redis_client:
|
||||
try:
|
||||
await self._redis_client.aclose()
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Error closing Redis client: {e}")
|
||||
finally:
|
||||
self._redis_client = None
|
||||
|
||||
self._connected = False
|
||||
logger.info("EventBus disconnected from Redis")
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self) -> AsyncIterator["EventBus"]:
|
||||
"""
|
||||
Context manager for automatic connection handling.
|
||||
|
||||
Usage:
|
||||
async with event_bus.connection() as bus:
|
||||
await bus.publish(channel, event)
|
||||
"""
|
||||
await self.connect()
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
await self.disconnect()
|
||||
|
||||
def get_project_channel(self, project_id: UUID | str) -> str:
|
||||
"""
|
||||
Get the channel name for a project.
|
||||
|
||||
Args:
|
||||
project_id: The project UUID or string
|
||||
|
||||
Returns:
|
||||
Channel name string in format "project:{uuid}"
|
||||
"""
|
||||
return f"{self.PROJECT_CHANNEL_PREFIX}:{project_id}"
|
||||
|
||||
def get_agent_channel(self, agent_id: UUID | str) -> str:
|
||||
"""
|
||||
Get the channel name for an agent instance.
|
||||
|
||||
Args:
|
||||
agent_id: The agent instance UUID or string
|
||||
|
||||
Returns:
|
||||
Channel name string in format "agent:{uuid}"
|
||||
"""
|
||||
return f"{self.AGENT_CHANNEL_PREFIX}:{agent_id}"
|
||||
|
||||
def get_user_channel(self, user_id: UUID | str) -> str:
|
||||
"""
|
||||
Get the channel name for a user (personal notifications).
|
||||
|
||||
Args:
|
||||
user_id: The user UUID or string
|
||||
|
||||
Returns:
|
||||
Channel name string in format "user:{uuid}"
|
||||
"""
|
||||
return f"{self.USER_CHANNEL_PREFIX}:{user_id}"
|
||||
|
||||
@staticmethod
|
||||
def create_event(
|
||||
event_type: EventType,
|
||||
project_id: UUID,
|
||||
actor_type: ActorType,
|
||||
payload: dict | None = None,
|
||||
actor_id: UUID | None = None,
|
||||
event_id: str | None = None,
|
||||
timestamp: datetime | None = None,
|
||||
) -> Event:
|
||||
"""
|
||||
Factory method to create a new Event.
|
||||
|
||||
Args:
|
||||
event_type: The type of event
|
||||
project_id: The project this event belongs to
|
||||
actor_type: Type of actor ('agent', 'user', or 'system')
|
||||
payload: Event-specific payload data
|
||||
actor_id: ID of the agent or user who triggered the event
|
||||
event_id: Optional custom event ID (UUID string)
|
||||
timestamp: Optional custom timestamp (defaults to now UTC)
|
||||
|
||||
Returns:
|
||||
A new Event instance
|
||||
"""
|
||||
return Event(
|
||||
id=event_id or str(uuid4()),
|
||||
type=event_type,
|
||||
timestamp=timestamp or datetime.now(UTC),
|
||||
project_id=project_id,
|
||||
actor_id=actor_id,
|
||||
actor_type=actor_type,
|
||||
payload=payload or {},
|
||||
)
|
||||
|
||||
def _serialize_event(self, event: Event) -> str:
|
||||
"""
|
||||
Serialize an event to JSON string.
|
||||
|
||||
Args:
|
||||
event: The Event to serialize
|
||||
|
||||
Returns:
|
||||
JSON string representation of the event
|
||||
"""
|
||||
return event.model_dump_json()
|
||||
|
||||
def _deserialize_event(self, data: str) -> Event:
|
||||
"""
|
||||
Deserialize a JSON string to an Event.
|
||||
|
||||
Args:
|
||||
data: JSON string to deserialize
|
||||
|
||||
Returns:
|
||||
Deserialized Event instance
|
||||
|
||||
Raises:
|
||||
ValidationError: If the data doesn't match the Event schema
|
||||
"""
|
||||
return Event.model_validate_json(data)
|
||||
|
||||
async def publish(self, channel: str, event: Event) -> int:
|
||||
"""
|
||||
Publish an event to a channel.
|
||||
|
||||
Args:
|
||||
channel: The channel name to publish to
|
||||
event: The Event to publish
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If not connected to Redis
|
||||
EventBusPublishError: If publishing fails
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise EventBusConnectionError("EventBus not connected")
|
||||
|
||||
try:
|
||||
message = self._serialize_event(event)
|
||||
subscriber_count = await self.redis_client.publish(channel, message)
|
||||
logger.debug(
|
||||
f"Published event {event.type} to {channel} "
|
||||
f"(received by {subscriber_count} subscribers)"
|
||||
)
|
||||
return subscriber_count
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Failed to publish event to {channel}: {e}", exc_info=True)
|
||||
raise EventBusPublishError(f"Failed to publish event: {e}") from e
|
||||
|
||||
async def publish_to_project(self, event: Event) -> int:
|
||||
"""
|
||||
Publish an event to the project's channel.
|
||||
|
||||
Convenience method that publishes to the project channel based on
|
||||
the event's project_id.
|
||||
|
||||
Args:
|
||||
event: The Event to publish (must have project_id set)
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message
|
||||
"""
|
||||
channel = self.get_project_channel(event.project_id)
|
||||
return await self.publish(channel, event)
|
||||
|
||||
async def publish_multi(self, channels: list[str], event: Event) -> dict[str, int]:
|
||||
"""
|
||||
Publish an event to multiple channels.
|
||||
|
||||
Args:
|
||||
channels: List of channel names to publish to
|
||||
event: The Event to publish
|
||||
|
||||
Returns:
|
||||
Dictionary mapping channel names to subscriber counts
|
||||
"""
|
||||
results = {}
|
||||
for channel in channels:
|
||||
try:
|
||||
results[channel] = await self.publish(channel, event)
|
||||
except EventBusPublishError as e:
|
||||
logger.warning(f"Failed to publish to {channel}: {e}")
|
||||
results[channel] = 0
|
||||
return results
|
||||
|
||||
async def subscribe(
|
||||
self, channels: list[str], *, max_wait: float | None = None
|
||||
) -> AsyncIterator[Event]:
|
||||
"""
|
||||
Subscribe to one or more channels and yield events.
|
||||
|
||||
This is an async generator that yields Event objects as they arrive.
|
||||
Use max_wait to limit how long to wait for messages.
|
||||
|
||||
Args:
|
||||
channels: List of channel names to subscribe to
|
||||
max_wait: Optional maximum wait time in seconds for each message.
|
||||
If None, waits indefinitely.
|
||||
|
||||
Yields:
|
||||
Event objects received from subscribed channels
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If not connected to Redis
|
||||
EventBusSubscriptionError: If subscription fails
|
||||
|
||||
Example:
|
||||
async for event in event_bus.subscribe(["project:123"], max_wait=30):
|
||||
print(f"Received: {event.type}")
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise EventBusConnectionError("EventBus not connected")
|
||||
|
||||
# Create a new pubsub for this subscription
|
||||
subscription_pubsub = self.redis_client.pubsub()
|
||||
|
||||
try:
|
||||
await subscription_pubsub.subscribe(*channels)
|
||||
logger.info(f"Subscribed to channels: {channels}")
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Failed to subscribe to channels: {e}", exc_info=True)
|
||||
await subscription_pubsub.close()
|
||||
raise EventBusSubscriptionError(f"Failed to subscribe: {e}") from e
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if max_wait is not None:
|
||||
async with asyncio.timeout(max_wait):
|
||||
message = await subscription_pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
else:
|
||||
message = await subscription_pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
except TimeoutError:
|
||||
# Timeout reached, stop iteration
|
||||
return
|
||||
|
||||
if message is None:
|
||||
continue
|
||||
|
||||
if message["type"] == "message":
|
||||
try:
|
||||
event = self._deserialize_event(message["data"])
|
||||
yield event
|
||||
except ValidationError as e:
|
||||
logger.warning(
|
||||
f"Invalid event data received: {e}",
|
||||
extra={"channel": message.get("channel")},
|
||||
)
|
||||
continue
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"Failed to decode event JSON: {e}",
|
||||
extra={"channel": message.get("channel")},
|
||||
)
|
||||
continue
|
||||
finally:
|
||||
try:
|
||||
await subscription_pubsub.unsubscribe(*channels)
|
||||
await subscription_pubsub.close()
|
||||
logger.debug(f"Unsubscribed from channels: {channels}")
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Error unsubscribing from channels: {e}")
|
||||
|
||||
async def subscribe_sse(
|
||||
self,
|
||||
project_id: str | UUID,
|
||||
last_event_id: str | None = None,
|
||||
keepalive_interval: int = 30,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Subscribe to events for a project in SSE format.
|
||||
|
||||
This is an async generator that yields SSE-formatted event strings.
|
||||
It includes keepalive messages at the specified interval.
|
||||
|
||||
Args:
|
||||
project_id: The project to subscribe to
|
||||
last_event_id: Optional last received event ID for reconnection
|
||||
keepalive_interval: Seconds between keepalive messages (default 30)
|
||||
|
||||
Yields:
|
||||
SSE-formatted event strings (ready to send to client)
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise EventBusConnectionError("EventBus not connected")
|
||||
|
||||
project_id_str = str(project_id)
|
||||
channel = self.get_project_channel(project_id_str)
|
||||
|
||||
subscription_pubsub = self.redis_client.pubsub()
|
||||
await subscription_pubsub.subscribe(channel)
|
||||
|
||||
logger.info(
|
||||
f"Subscribed to SSE events for project {project_id_str} "
|
||||
f"(last_event_id={last_event_id})"
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for messages with a timeout for keepalive
|
||||
message = await asyncio.wait_for(
|
||||
subscription_pubsub.get_message(ignore_subscribe_messages=True),
|
||||
timeout=keepalive_interval,
|
||||
)
|
||||
|
||||
if message is not None and message["type"] == "message":
|
||||
event_data = message["data"]
|
||||
|
||||
# If reconnecting, check if we should skip this event
|
||||
if last_event_id:
|
||||
try:
|
||||
event_dict = json.loads(event_data)
|
||||
if event_dict.get("id") == last_event_id:
|
||||
# Found the last event, start yielding from next
|
||||
last_event_id = None
|
||||
continue
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
yield event_data
|
||||
|
||||
except TimeoutError:
|
||||
# Send keepalive comment
|
||||
yield "" # Empty string signals keepalive
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"SSE subscription cancelled for project {project_id_str}")
|
||||
raise
|
||||
finally:
|
||||
await subscription_pubsub.unsubscribe(channel)
|
||||
await subscription_pubsub.close()
|
||||
logger.info(f"Unsubscribed SSE from project {project_id_str}")
|
||||
|
||||
async def subscribe_with_callback(
|
||||
self,
|
||||
channels: list[str],
|
||||
callback: Any, # Callable[[Event], Awaitable[None]]
|
||||
stop_event: asyncio.Event | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Subscribe to channels and process events with a callback.
|
||||
|
||||
This method runs until stop_event is set or an unrecoverable error occurs.
|
||||
|
||||
Args:
|
||||
channels: List of channel names to subscribe to
|
||||
callback: Async function to call for each event
|
||||
stop_event: Optional asyncio.Event to signal stop
|
||||
|
||||
Example:
|
||||
async def handle_event(event: Event):
|
||||
print(f"Handling: {event.type}")
|
||||
|
||||
stop = asyncio.Event()
|
||||
asyncio.create_task(
|
||||
event_bus.subscribe_with_callback(["project:123"], handle_event, stop)
|
||||
)
|
||||
# Later...
|
||||
stop.set()
|
||||
"""
|
||||
if stop_event is None:
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
try:
|
||||
async for event in self.subscribe(channels):
|
||||
if stop_event.is_set():
|
||||
break
|
||||
try:
|
||||
await callback(event)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in event callback: {e}", exc_info=True)
|
||||
except EventBusSubscriptionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in subscription loop: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance for application-wide use
|
||||
_event_bus: EventBus | None = None
|
||||
|
||||
|
||||
def get_event_bus() -> EventBus:
|
||||
"""
|
||||
Get the singleton EventBus instance.
|
||||
|
||||
Creates a new instance if one doesn't exist. Note that you still need
|
||||
to call connect() before using the EventBus.
|
||||
|
||||
Returns:
|
||||
The singleton EventBus instance
|
||||
"""
|
||||
global _event_bus
|
||||
if _event_bus is None:
|
||||
_event_bus = EventBus()
|
||||
return _event_bus
|
||||
|
||||
|
||||
async def get_connected_event_bus() -> EventBus:
|
||||
"""
|
||||
Get a connected EventBus instance.
|
||||
|
||||
Ensures the EventBus is connected before returning. For use in
|
||||
FastAPI dependency injection.
|
||||
|
||||
Returns:
|
||||
A connected EventBus instance
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If connection fails
|
||||
"""
|
||||
event_bus = get_event_bus()
|
||||
if not event_bus.is_connected:
|
||||
await event_bus.connect()
|
||||
return event_bus
|
||||
|
||||
|
||||
async def close_event_bus() -> None:
|
||||
"""
|
||||
Close the global EventBus instance.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
global _event_bus
|
||||
if _event_bus is not None:
|
||||
await _event_bus.disconnect()
|
||||
_event_bus = None
|
||||
85
backend/app/services/mcp/__init__.py
Normal file
85
backend/app/services/mcp/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
MCP Client Service Package
|
||||
|
||||
Provides infrastructure for communicating with MCP (Model Context Protocol)
|
||||
servers. This is the foundation for AI agent tool integration.
|
||||
|
||||
Usage:
|
||||
from app.services.mcp import get_mcp_client, MCPClientManager
|
||||
|
||||
# In FastAPI route
|
||||
async def my_route(mcp: MCPClientManager = Depends(get_mcp_client)):
|
||||
result = await mcp.call_tool("llm-gateway", "chat", {"prompt": "Hello"})
|
||||
|
||||
# Direct usage
|
||||
manager = MCPClientManager()
|
||||
await manager.initialize()
|
||||
result = await manager.call_tool("issues", "create_issue", {...})
|
||||
await manager.shutdown()
|
||||
"""
|
||||
|
||||
from .client_manager import (
|
||||
MCPClientManager,
|
||||
ServerHealth,
|
||||
get_mcp_client,
|
||||
reset_mcp_client,
|
||||
shutdown_mcp_client,
|
||||
)
|
||||
from .config import (
|
||||
MCPConfig,
|
||||
MCPServerConfig,
|
||||
TransportType,
|
||||
create_default_config,
|
||||
load_mcp_config,
|
||||
)
|
||||
from .connection import ConnectionPool, ConnectionState, MCPConnection
|
||||
from .exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPConnectionError,
|
||||
MCPError,
|
||||
MCPServerNotFoundError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
MCPValidationError,
|
||||
)
|
||||
from .registry import MCPServerRegistry, ServerCapabilities, get_registry
|
||||
from .routing import AsyncCircuitBreaker, CircuitState, ToolInfo, ToolResult, ToolRouter
|
||||
|
||||
__all__ = [
|
||||
# Main facade
|
||||
"MCPClientManager",
|
||||
"get_mcp_client",
|
||||
"shutdown_mcp_client",
|
||||
"reset_mcp_client",
|
||||
"ServerHealth",
|
||||
# Configuration
|
||||
"MCPConfig",
|
||||
"MCPServerConfig",
|
||||
"TransportType",
|
||||
"load_mcp_config",
|
||||
"create_default_config",
|
||||
# Registry
|
||||
"MCPServerRegistry",
|
||||
"ServerCapabilities",
|
||||
"get_registry",
|
||||
# Connection
|
||||
"ConnectionPool",
|
||||
"ConnectionState",
|
||||
"MCPConnection",
|
||||
# Routing
|
||||
"ToolRouter",
|
||||
"ToolInfo",
|
||||
"ToolResult",
|
||||
"AsyncCircuitBreaker",
|
||||
"CircuitState",
|
||||
# Exceptions
|
||||
"MCPError",
|
||||
"MCPConnectionError",
|
||||
"MCPTimeoutError",
|
||||
"MCPToolError",
|
||||
"MCPServerNotFoundError",
|
||||
"MCPToolNotFoundError",
|
||||
"MCPCircuitOpenError",
|
||||
"MCPValidationError",
|
||||
]
|
||||
430
backend/app/services/mcp/client_manager.py
Normal file
430
backend/app/services/mcp/client_manager.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""
|
||||
MCP Client Manager
|
||||
|
||||
Main facade for all MCP operations. Manages server connections,
|
||||
tool discovery, and provides a unified interface for tool calls.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
||||
from .connection import ConnectionPool, ConnectionState
|
||||
from .exceptions import MCPServerNotFoundError
|
||||
from .registry import MCPServerRegistry, get_registry
|
||||
from .routing import ToolInfo, ToolResult, ToolRouter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerHealth:
|
||||
"""Health status for an MCP server."""
|
||||
|
||||
name: str
|
||||
healthy: bool
|
||||
state: str
|
||||
url: str
|
||||
error: str | None = None
|
||||
tools_count: int = 0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"healthy": self.healthy,
|
||||
"state": self.state,
|
||||
"url": self.url,
|
||||
"error": self.error,
|
||||
"tools_count": self.tools_count,
|
||||
}
|
||||
|
||||
|
||||
class MCPClientManager:
|
||||
"""
|
||||
Central manager for all MCP client operations.
|
||||
|
||||
Provides a unified interface for:
|
||||
- Connecting to MCP servers
|
||||
- Discovering and calling tools
|
||||
- Health monitoring
|
||||
- Connection lifecycle management
|
||||
|
||||
This is the main entry point for MCP operations in the application.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MCPConfig | None = None,
|
||||
registry: MCPServerRegistry | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the MCP client manager.
|
||||
|
||||
Args:
|
||||
config: Optional MCP configuration. If None, loads from default.
|
||||
registry: Optional registry instance. If None, uses singleton.
|
||||
"""
|
||||
self._registry = registry or get_registry()
|
||||
self._pool = ConnectionPool()
|
||||
self._router: ToolRouter | None = None
|
||||
self._initialized = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Load configuration if provided
|
||||
if config is not None:
|
||||
self._registry.load_config(config)
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the manager is initialized."""
|
||||
return self._initialized
|
||||
|
||||
async def initialize(self, config: MCPConfig | None = None) -> None:
|
||||
"""
|
||||
Initialize the MCP client manager.
|
||||
|
||||
Loads configuration, creates connections, and discovers tools.
|
||||
|
||||
Args:
|
||||
config: Optional configuration to load
|
||||
"""
|
||||
async with self._lock:
|
||||
if self._initialized:
|
||||
logger.warning("MCPClientManager already initialized")
|
||||
return
|
||||
|
||||
logger.info("Initializing MCP Client Manager")
|
||||
|
||||
# Load configuration
|
||||
if config is not None:
|
||||
self._registry.load_config(config)
|
||||
elif len(self._registry.list_servers()) == 0:
|
||||
# Try to load from default location
|
||||
self._registry.load_config(load_mcp_config())
|
||||
|
||||
# Create router
|
||||
self._router = ToolRouter(self._registry, self._pool)
|
||||
|
||||
# Connect to all enabled servers
|
||||
await self._connect_all_servers()
|
||||
|
||||
# Discover tools from all servers
|
||||
if self._router:
|
||||
await self._router.discover_tools()
|
||||
|
||||
self._initialized = True
|
||||
logger.info(
|
||||
"MCP Client Manager initialized with %d servers",
|
||||
len(self._registry.list_enabled_servers()),
|
||||
)
|
||||
|
||||
async def _connect_all_servers(self) -> None:
|
||||
"""Connect to all enabled MCP servers."""
|
||||
enabled_servers = self._registry.get_enabled_configs()
|
||||
|
||||
for name, config in enabled_servers.items():
|
||||
try:
|
||||
await self._pool.get_connection(name, config)
|
||||
logger.info("Connected to MCP server: %s", name)
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to MCP server %s: %s", name, e)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Shutdown the MCP client manager.
|
||||
|
||||
Closes all connections and cleans up resources.
|
||||
"""
|
||||
async with self._lock:
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
logger.info("Shutting down MCP Client Manager")
|
||||
|
||||
await self._pool.close_all()
|
||||
self._initialized = False
|
||||
|
||||
logger.info("MCP Client Manager shutdown complete")
|
||||
|
||||
async def connect(self, server_name: str) -> None:
|
||||
"""
|
||||
Connect to a specific MCP server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server to connect to
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
config = self._registry.get(server_name)
|
||||
await self._pool.get_connection(server_name, config)
|
||||
logger.info("Connected to MCP server: %s", server_name)
|
||||
|
||||
async def disconnect(self, server_name: str) -> None:
|
||||
"""
|
||||
Disconnect from a specific MCP server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server to disconnect from
|
||||
"""
|
||||
await self._pool.close_connection(server_name)
|
||||
logger.info("Disconnected from MCP server: %s", server_name)
|
||||
|
||||
async def disconnect_all(self) -> None:
|
||||
"""Disconnect from all MCP servers."""
|
||||
await self._pool.close_all()
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
server: str,
|
||||
tool: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Call a tool on a specific MCP server.
|
||||
|
||||
Args:
|
||||
server: Name of the MCP server
|
||||
tool: Name of the tool to call
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
if not self._initialized or self._router is None:
|
||||
await self.initialize()
|
||||
|
||||
assert self._router is not None # Guaranteed after initialize()
|
||||
return await self._router.call_tool(
|
||||
server_name=server,
|
||||
tool_name=tool,
|
||||
arguments=args,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def route_tool(
|
||||
self,
|
||||
tool: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Route a tool call to the appropriate server automatically.
|
||||
|
||||
Args:
|
||||
tool: Name of the tool to call
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
if not self._initialized or self._router is None:
|
||||
await self.initialize()
|
||||
|
||||
assert self._router is not None # Guaranteed after initialize()
|
||||
return await self._router.route_tool(
|
||||
tool_name=tool,
|
||||
arguments=args,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def list_tools(self, server: str) -> list[ToolInfo]:
|
||||
"""
|
||||
List all tools available on a specific server.
|
||||
|
||||
Args:
|
||||
server: Name of the MCP server
|
||||
|
||||
Returns:
|
||||
List of tool information
|
||||
"""
|
||||
capabilities = await self._registry.get_capabilities(server)
|
||||
return [
|
||||
ToolInfo(
|
||||
name=t.get("name", ""),
|
||||
description=t.get("description"),
|
||||
server_name=server,
|
||||
input_schema=t.get("input_schema"),
|
||||
)
|
||||
for t in capabilities.tools
|
||||
]
|
||||
|
||||
async def list_all_tools(self) -> list[ToolInfo]:
|
||||
"""
|
||||
List all tools from all servers.
|
||||
|
||||
Returns:
|
||||
List of tool information
|
||||
"""
|
||||
if not self._initialized or self._router is None:
|
||||
await self.initialize()
|
||||
|
||||
assert self._router is not None # Guaranteed after initialize()
|
||||
return await self._router.list_all_tools()
|
||||
|
||||
async def health_check(self) -> dict[str, ServerHealth]:
|
||||
"""
|
||||
Perform health check on all MCP servers.
|
||||
|
||||
Returns:
|
||||
Dict mapping server names to health status
|
||||
"""
|
||||
results: dict[str, ServerHealth] = {}
|
||||
pool_status = self._pool.get_status()
|
||||
pool_health = await self._pool.health_check_all()
|
||||
|
||||
for server_name in self._registry.list_servers():
|
||||
try:
|
||||
config = self._registry.get(server_name)
|
||||
status = pool_status.get(server_name, {})
|
||||
healthy = pool_health.get(server_name, False)
|
||||
|
||||
capabilities = self._registry.get_cached_capabilities(server_name)
|
||||
|
||||
results[server_name] = ServerHealth(
|
||||
name=server_name,
|
||||
healthy=healthy,
|
||||
state=status.get("state", ConnectionState.DISCONNECTED.value),
|
||||
url=config.url,
|
||||
tools_count=len(capabilities.tools),
|
||||
)
|
||||
except MCPServerNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
results[server_name] = ServerHealth(
|
||||
name=server_name,
|
||||
healthy=False,
|
||||
state=ConnectionState.ERROR.value,
|
||||
url="unknown",
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def list_servers(self) -> list[str]:
|
||||
"""Get list of all registered server names."""
|
||||
return self._registry.list_servers()
|
||||
|
||||
def list_enabled_servers(self) -> list[str]:
|
||||
"""Get list of enabled server names."""
|
||||
return self._registry.list_enabled_servers()
|
||||
|
||||
def get_server_config(self, server_name: str) -> MCPServerConfig:
|
||||
"""
|
||||
Get configuration for a specific server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
|
||||
Returns:
|
||||
Server configuration
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
return self._registry.get(server_name)
|
||||
|
||||
def register_server(
|
||||
self,
|
||||
name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Register a new MCP server at runtime.
|
||||
|
||||
Args:
|
||||
name: Unique server name
|
||||
config: Server configuration
|
||||
"""
|
||||
self._registry.register(name, config)
|
||||
|
||||
def unregister_server(self, name: str) -> bool:
|
||||
"""
|
||||
Unregister an MCP server.
|
||||
|
||||
Args:
|
||||
name: Server name to unregister
|
||||
|
||||
Returns:
|
||||
True if server was found and removed
|
||||
"""
|
||||
return self._registry.unregister(name)
|
||||
|
||||
def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get status of all circuit breakers."""
|
||||
if self._router is None:
|
||||
return {}
|
||||
return self._router.get_circuit_breaker_status()
|
||||
|
||||
async def reset_circuit_breaker(self, server_name: str) -> bool:
|
||||
"""
|
||||
Reset a circuit breaker for a server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
|
||||
Returns:
|
||||
True if circuit breaker was reset
|
||||
"""
|
||||
if self._router is None:
|
||||
return False
|
||||
return await self._router.reset_circuit_breaker(server_name)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_manager_instance: MCPClientManager | None = None
|
||||
_manager_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_mcp_client() -> MCPClientManager:
|
||||
"""
|
||||
Get the global MCP client manager instance.
|
||||
|
||||
This is the main dependency injection point for FastAPI.
|
||||
Uses proper locking to avoid race conditions in async contexts.
|
||||
"""
|
||||
global _manager_instance
|
||||
|
||||
# Use lock for the entire check-and-create operation to avoid race conditions
|
||||
async with _manager_lock:
|
||||
if _manager_instance is None:
|
||||
_manager_instance = MCPClientManager()
|
||||
await _manager_instance.initialize()
|
||||
|
||||
return _manager_instance
|
||||
|
||||
|
||||
async def shutdown_mcp_client() -> None:
|
||||
"""Shutdown the global MCP client manager."""
|
||||
global _manager_instance
|
||||
|
||||
# Use lock to prevent race with get_mcp_client()
|
||||
async with _manager_lock:
|
||||
if _manager_instance is not None:
|
||||
await _manager_instance.shutdown()
|
||||
_manager_instance = None
|
||||
|
||||
|
||||
async def reset_mcp_client() -> None:
|
||||
"""
|
||||
Reset the global MCP client manager (for testing).
|
||||
|
||||
This is an async function to properly acquire the manager lock
|
||||
and avoid race conditions with get_mcp_client().
|
||||
"""
|
||||
global _manager_instance
|
||||
|
||||
async with _manager_lock:
|
||||
if _manager_instance is not None:
|
||||
# Shutdown gracefully before resetting
|
||||
try:
|
||||
await _manager_instance.shutdown()
|
||||
except Exception: # noqa: S110
|
||||
pass # Ignore errors during test cleanup
|
||||
_manager_instance = None
|
||||
232
backend/app/services/mcp/config.py
Normal file
232
backend/app/services/mcp/config.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
MCP Configuration System
|
||||
|
||||
Pydantic models for MCP server configuration with YAML file loading
|
||||
and environment variable overrides.
|
||||
"""
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class TransportType(str, Enum):
|
||||
"""Supported MCP transport types."""
|
||||
|
||||
HTTP = "http"
|
||||
STDIO = "stdio"
|
||||
SSE = "sse"
|
||||
|
||||
|
||||
class MCPServerConfig(BaseModel):
|
||||
"""Configuration for a single MCP server."""
|
||||
|
||||
url: str = Field(..., description="Server URL (supports ${ENV_VAR} syntax)")
|
||||
transport: TransportType = Field(
|
||||
default=TransportType.HTTP,
|
||||
description="Transport protocol to use",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
le=600,
|
||||
description="Request timeout in seconds",
|
||||
)
|
||||
retry_attempts: int = Field(
|
||||
default=3,
|
||||
ge=0,
|
||||
le=10,
|
||||
description="Number of retry attempts on failure",
|
||||
)
|
||||
retry_delay: float = Field(
|
||||
default=1.0,
|
||||
ge=0.1,
|
||||
le=60.0,
|
||||
description="Initial delay between retries in seconds",
|
||||
)
|
||||
retry_max_delay: float = Field(
|
||||
default=30.0,
|
||||
ge=1.0,
|
||||
le=300.0,
|
||||
description="Maximum delay between retries in seconds",
|
||||
)
|
||||
circuit_breaker_threshold: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=50,
|
||||
description="Number of failures before opening circuit",
|
||||
)
|
||||
circuit_breaker_timeout: float = Field(
|
||||
default=30.0,
|
||||
ge=5.0,
|
||||
le=300.0,
|
||||
description="Seconds to wait before attempting to close circuit",
|
||||
)
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether this server is enabled",
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="Human-readable description of the server",
|
||||
)
|
||||
|
||||
@field_validator("url", mode="before")
|
||||
@classmethod
|
||||
def expand_env_vars(cls, v: str) -> str:
|
||||
"""Expand environment variables in URL using ${VAR:-default} syntax."""
|
||||
if not isinstance(v, str):
|
||||
return v
|
||||
|
||||
result = v
|
||||
# Find all ${VAR} or ${VAR:-default} patterns
|
||||
import re
|
||||
|
||||
pattern = r"\$\{([^}]+)\}"
|
||||
matches = re.findall(pattern, v)
|
||||
|
||||
for match in matches:
|
||||
if ":-" in match:
|
||||
var_name, default = match.split(":-", 1)
|
||||
else:
|
||||
var_name, default = match, ""
|
||||
|
||||
env_value = os.environ.get(var_name.strip(), default)
|
||||
result = result.replace(f"${{{match}}}", env_value)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class MCPConfig(BaseModel):
|
||||
"""Root configuration for all MCP servers."""
|
||||
|
||||
mcp_servers: dict[str, MCPServerConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of server names to their configurations",
|
||||
)
|
||||
|
||||
# Global defaults
|
||||
default_timeout: int = Field(
|
||||
default=30,
|
||||
description="Default timeout for all servers",
|
||||
)
|
||||
default_retry_attempts: int = Field(
|
||||
default=3,
|
||||
description="Default retry attempts for all servers",
|
||||
)
|
||||
connection_pool_size: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Maximum connections per server",
|
||||
)
|
||||
health_check_interval: int = Field(
|
||||
default=30,
|
||||
ge=5,
|
||||
le=300,
|
||||
description="Seconds between health checks",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> "MCPConfig":
|
||||
"""Load configuration from a YAML file."""
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"MCP config file not found: {path}")
|
||||
|
||||
with path.open("r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if data is None:
|
||||
data = {}
|
||||
|
||||
return cls.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "MCPConfig":
|
||||
"""Load configuration from a dictionary."""
|
||||
return cls.model_validate(data)
|
||||
|
||||
def get_server(self, name: str) -> MCPServerConfig | None:
|
||||
"""Get a server configuration by name."""
|
||||
return self.mcp_servers.get(name)
|
||||
|
||||
def get_enabled_servers(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all enabled server configurations."""
|
||||
return {
|
||||
name: config for name, config in self.mcp_servers.items() if config.enabled
|
||||
}
|
||||
|
||||
def list_server_names(self) -> list[str]:
|
||||
"""Get list of all configured server names."""
|
||||
return list(self.mcp_servers.keys())
|
||||
|
||||
|
||||
# Default configuration path
|
||||
DEFAULT_CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "mcp_servers.yaml"
|
||||
|
||||
|
||||
def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
|
||||
"""
|
||||
Load MCP configuration from file or environment.
|
||||
|
||||
Priority:
|
||||
1. Explicit path parameter
|
||||
2. MCP_CONFIG_PATH environment variable
|
||||
3. Default path (backend/mcp_servers.yaml)
|
||||
4. Empty config if no file exists
|
||||
"""
|
||||
if path is None:
|
||||
path = os.environ.get("MCP_CONFIG_PATH", str(DEFAULT_CONFIG_PATH))
|
||||
|
||||
path = Path(path)
|
||||
|
||||
if not path.exists():
|
||||
# Return empty config if no file exists (allows runtime registration)
|
||||
return MCPConfig()
|
||||
|
||||
return MCPConfig.from_yaml(path)
|
||||
|
||||
|
||||
def create_default_config() -> MCPConfig:
|
||||
"""
|
||||
Create a default MCP configuration with standard servers.
|
||||
|
||||
This is useful for development and as a template.
|
||||
"""
|
||||
return MCPConfig(
|
||||
mcp_servers={
|
||||
"llm-gateway": MCPServerConfig(
|
||||
url="${LLM_GATEWAY_URL:-http://localhost:8001}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=60,
|
||||
description="LLM Gateway for multi-provider AI interactions",
|
||||
),
|
||||
"knowledge-base": MCPServerConfig(
|
||||
url="${KNOWLEDGE_BASE_URL:-http://localhost:8002}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=30,
|
||||
description="Knowledge Base for RAG and document retrieval",
|
||||
),
|
||||
"git-ops": MCPServerConfig(
|
||||
url="${GIT_OPS_URL:-http://localhost:8003}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=120,
|
||||
description="Git Operations for repository management",
|
||||
),
|
||||
"issues": MCPServerConfig(
|
||||
url="${ISSUES_URL:-http://localhost:8004}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=30,
|
||||
description="Issue Tracker for Gitea/GitHub/GitLab",
|
||||
),
|
||||
},
|
||||
default_timeout=30,
|
||||
default_retry_attempts=3,
|
||||
connection_pool_size=10,
|
||||
health_check_interval=30,
|
||||
)
|
||||
473
backend/app/services/mcp/connection.py
Normal file
473
backend/app/services/mcp/connection.py
Normal file
@@ -0,0 +1,473 @@
|
||||
"""
|
||||
MCP Connection Management
|
||||
|
||||
Handles connection lifecycle, pooling, and automatic reconnection
|
||||
for MCP servers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import MCPServerConfig, TransportType
|
||||
from .exceptions import MCPConnectionError, MCPTimeoutError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionState(str, Enum):
|
||||
"""Connection state enumeration."""
|
||||
|
||||
DISCONNECTED = "disconnected"
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
RECONNECTING = "reconnecting"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class MCPConnection:
|
||||
"""
|
||||
Manages a single connection to an MCP server.
|
||||
|
||||
Handles connection lifecycle, health checking, and automatic reconnection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the MCP server
|
||||
config: Server configuration
|
||||
"""
|
||||
self.server_name = server_name
|
||||
self.config = config
|
||||
self._state = ConnectionState.DISCONNECTED
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._last_activity: float | None = None
|
||||
self._connection_attempts = 0
|
||||
self._last_error: Exception | None = None
|
||||
|
||||
# Reconnection settings
|
||||
self._base_delay = config.retry_delay
|
||||
self._max_delay = config.retry_max_delay
|
||||
self._max_attempts = config.retry_attempts
|
||||
|
||||
@property
|
||||
def state(self) -> ConnectionState:
|
||||
"""Get current connection state."""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connection is established."""
|
||||
return self._state == ConnectionState.CONNECTED
|
||||
|
||||
@property
|
||||
def last_error(self) -> Exception | None:
|
||||
"""Get the last error that occurred."""
|
||||
return self._last_error
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
Establish connection to the MCP server.
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: If connection fails after all retries
|
||||
"""
|
||||
async with self._lock:
|
||||
if self._state == ConnectionState.CONNECTED:
|
||||
return
|
||||
|
||||
self._state = ConnectionState.CONNECTING
|
||||
self._connection_attempts = 0
|
||||
self._last_error = None
|
||||
|
||||
while self._connection_attempts < self._max_attempts:
|
||||
try:
|
||||
await self._do_connect()
|
||||
self._state = ConnectionState.CONNECTED
|
||||
self._last_activity = time.time()
|
||||
logger.info(
|
||||
"Connected to MCP server: %s at %s",
|
||||
self.server_name,
|
||||
self.config.url,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
self._connection_attempts += 1
|
||||
self._last_error = e
|
||||
logger.warning(
|
||||
"Connection attempt %d/%d failed for %s: %s",
|
||||
self._connection_attempts,
|
||||
self._max_attempts,
|
||||
self.server_name,
|
||||
e,
|
||||
)
|
||||
|
||||
if self._connection_attempts < self._max_attempts:
|
||||
delay = self._calculate_backoff_delay()
|
||||
logger.debug(
|
||||
"Retrying connection to %s in %.1fs",
|
||||
self.server_name,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All attempts failed
|
||||
self._state = ConnectionState.ERROR
|
||||
raise MCPConnectionError(
|
||||
f"Failed to connect after {self._max_attempts} attempts",
|
||||
server_name=self.server_name,
|
||||
url=self.config.url,
|
||||
cause=self._last_error,
|
||||
)
|
||||
|
||||
async def _do_connect(self) -> None:
|
||||
"""Perform the actual connection (transport-specific)."""
|
||||
if self.config.transport == TransportType.HTTP:
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self.config.url,
|
||||
timeout=httpx.Timeout(self.config.timeout),
|
||||
headers={
|
||||
"User-Agent": "Syndarix-MCP-Client/1.0",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
# Verify connectivity with a simple request
|
||||
try:
|
||||
# Try to hit the MCP capabilities endpoint
|
||||
response = await self._client.get("/mcp/capabilities")
|
||||
if response.status_code not in (200, 404):
|
||||
# 404 is acceptable - server might not have capabilities endpoint
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code != 404:
|
||||
raise
|
||||
except httpx.ConnectError as e:
|
||||
raise MCPConnectionError(
|
||||
"Failed to connect to server",
|
||||
server_name=self.server_name,
|
||||
url=self.config.url,
|
||||
cause=e,
|
||||
) from e
|
||||
else:
|
||||
# For STDIO and SSE transports, we'll implement later
|
||||
raise NotImplementedError(
|
||||
f"Transport {self.config.transport} not yet implemented"
|
||||
)
|
||||
|
||||
def _calculate_backoff_delay(self) -> float:
|
||||
"""Calculate exponential backoff delay with jitter."""
|
||||
import random
|
||||
|
||||
delay = self._base_delay * (2 ** (self._connection_attempts - 1))
|
||||
delay = min(delay, self._max_delay)
|
||||
# Add jitter (±25%)
|
||||
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||
return delay + jitter
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the MCP server."""
|
||||
async with self._lock:
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.aclose()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error closing connection to %s: %s",
|
||||
self.server_name,
|
||||
e,
|
||||
)
|
||||
finally:
|
||||
self._client = None
|
||||
|
||||
self._state = ConnectionState.DISCONNECTED
|
||||
logger.info("Disconnected from MCP server: %s", self.server_name)
|
||||
|
||||
async def reconnect(self) -> None:
|
||||
"""Reconnect to the MCP server."""
|
||||
async with self._lock:
|
||||
self._state = ConnectionState.RECONNECTING
|
||||
await self.disconnect()
|
||||
await self.connect()
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
Perform a health check on the connection.
|
||||
|
||||
Returns:
|
||||
True if connection is healthy
|
||||
"""
|
||||
if not self.is_connected or self._client is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
if self.config.transport == TransportType.HTTP:
|
||||
response = await self._client.get(
|
||||
"/health",
|
||||
timeout=5.0,
|
||||
)
|
||||
return response.status_code == 200
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Health check failed for %s: %s",
|
||||
self.server_name,
|
||||
e,
|
||||
)
|
||||
return False
|
||||
|
||||
async def execute_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute an HTTP request to the MCP server.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: Request path
|
||||
data: Optional request body
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Response data
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: If not connected
|
||||
MCPTimeoutError: If request times out
|
||||
"""
|
||||
if not self.is_connected or self._client is None:
|
||||
raise MCPConnectionError(
|
||||
"Not connected to server",
|
||||
server_name=self.server_name,
|
||||
)
|
||||
|
||||
effective_timeout = timeout or self.config.timeout
|
||||
|
||||
try:
|
||||
if method.upper() == "GET":
|
||||
response = await self._client.get(
|
||||
path,
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = await self._client.post(
|
||||
path,
|
||||
json=data,
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
else:
|
||||
response = await self._client.request(
|
||||
method.upper(),
|
||||
path,
|
||||
json=data,
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
|
||||
self._last_activity = time.time()
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
raise MCPTimeoutError(
|
||||
"Request timed out",
|
||||
server_name=self.server_name,
|
||||
timeout_seconds=effective_timeout,
|
||||
operation=f"{method} {path}",
|
||||
) from e
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise MCPConnectionError(
|
||||
f"HTTP error: {e.response.status_code}",
|
||||
server_name=self.server_name,
|
||||
url=f"{self.config.url}{path}",
|
||||
cause=e,
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise MCPConnectionError(
|
||||
f"Request failed: {e}",
|
||||
server_name=self.server_name,
|
||||
cause=e,
|
||||
) from e
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
"""
|
||||
Pool of connections to MCP servers.
|
||||
|
||||
Manages connection lifecycle and provides connection reuse.
|
||||
"""
|
||||
|
||||
def __init__(self, max_connections_per_server: int = 10) -> None:
|
||||
"""
|
||||
Initialize connection pool.
|
||||
|
||||
Args:
|
||||
max_connections_per_server: Maximum connections per server
|
||||
"""
|
||||
self._connections: dict[str, MCPConnection] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._per_server_locks: dict[str, asyncio.Lock] = {}
|
||||
self._max_per_server = max_connections_per_server
|
||||
|
||||
def _get_server_lock(self, server_name: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific server.
|
||||
|
||||
Uses setdefault for atomic dict access to prevent race conditions
|
||||
where two coroutines could create different locks for the same server.
|
||||
"""
|
||||
# setdefault is atomic - if key exists, returns existing value
|
||||
# if key doesn't exist, inserts new value and returns it
|
||||
return self._per_server_locks.setdefault(server_name, asyncio.Lock())
|
||||
|
||||
async def get_connection(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> MCPConnection:
|
||||
"""
|
||||
Get or create a connection to a server.
|
||||
|
||||
Uses per-server locking to avoid blocking all connections
|
||||
when establishing a new connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
config: Server configuration
|
||||
|
||||
Returns:
|
||||
Active connection
|
||||
"""
|
||||
# Quick check without lock - if connection exists and is connected, return it
|
||||
if server_name in self._connections:
|
||||
connection = self._connections[server_name]
|
||||
if connection.is_connected:
|
||||
return connection
|
||||
|
||||
# Need to create or reconnect - use per-server lock to avoid blocking others
|
||||
async with self._lock:
|
||||
server_lock = self._get_server_lock(server_name)
|
||||
|
||||
async with server_lock:
|
||||
# Double-check after acquiring per-server lock
|
||||
if server_name in self._connections:
|
||||
connection = self._connections[server_name]
|
||||
if connection.is_connected:
|
||||
return connection
|
||||
# Connection exists but not connected - reconnect
|
||||
await connection.connect()
|
||||
return connection
|
||||
|
||||
# Create new connection (outside global lock, under per-server lock)
|
||||
connection = MCPConnection(server_name, config)
|
||||
await connection.connect()
|
||||
|
||||
# Store connection under global lock
|
||||
async with self._lock:
|
||||
self._connections[server_name] = connection
|
||||
|
||||
return connection
|
||||
|
||||
async def release_connection(self, server_name: str) -> None:
|
||||
"""
|
||||
Release a connection (currently just tracks usage).
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
"""
|
||||
# For now, we keep connections alive
|
||||
# Future: implement connection reaping for idle connections
|
||||
|
||||
async def close_connection(self, server_name: str) -> None:
|
||||
"""
|
||||
Close and remove a connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name in self._connections:
|
||||
await self._connections[server_name].disconnect()
|
||||
del self._connections[server_name]
|
||||
# Clean up per-server lock
|
||||
if server_name in self._per_server_locks:
|
||||
del self._per_server_locks[server_name]
|
||||
|
||||
async def close_all(self) -> None:
|
||||
"""Close all connections in the pool."""
|
||||
async with self._lock:
|
||||
for connection in self._connections.values():
|
||||
try:
|
||||
await connection.disconnect()
|
||||
except Exception as e:
|
||||
logger.warning("Error closing connection: %s", e)
|
||||
|
||||
self._connections.clear()
|
||||
self._per_server_locks.clear()
|
||||
logger.info("Closed all MCP connections")
|
||||
|
||||
async def health_check_all(self) -> dict[str, bool]:
|
||||
"""
|
||||
Perform health check on all connections.
|
||||
|
||||
Returns:
|
||||
Dict mapping server names to health status
|
||||
"""
|
||||
# Copy connections under lock to prevent modification during iteration
|
||||
async with self._lock:
|
||||
connections_snapshot = dict(self._connections)
|
||||
|
||||
results = {}
|
||||
for name, connection in connections_snapshot.items():
|
||||
results[name] = await connection.health_check()
|
||||
return results
|
||||
|
||||
def get_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
Get status of all connections.
|
||||
|
||||
Returns:
|
||||
Dict mapping server names to status info
|
||||
"""
|
||||
return {
|
||||
name: {
|
||||
"state": conn.state.value,
|
||||
"is_connected": conn.is_connected,
|
||||
"url": conn.config.url,
|
||||
}
|
||||
for name, conn in self._connections.items()
|
||||
}
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> AsyncGenerator[MCPConnection, None]:
|
||||
"""
|
||||
Context manager for getting a connection.
|
||||
|
||||
Usage:
|
||||
async with pool.connection("server", config) as conn:
|
||||
result = await conn.execute_request("POST", "/tool", data)
|
||||
"""
|
||||
conn = await self.get_connection(server_name, config)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
await self.release_connection(server_name)
|
||||
201
backend/app/services/mcp/exceptions.py
Normal file
201
backend/app/services/mcp/exceptions.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
MCP Exception Classes
|
||||
|
||||
Custom exceptions for MCP client operations with detailed error context.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class MCPError(Exception):
|
||||
"""Base exception for all MCP-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.server_name = server_name
|
||||
self.details = details or {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
parts = [self.message]
|
||||
if self.server_name:
|
||||
parts.append(f"server={self.server_name}")
|
||||
if self.details:
|
||||
parts.append(f"details={self.details}")
|
||||
return " | ".join(parts)
|
||||
|
||||
|
||||
class MCPConnectionError(MCPError):
|
||||
"""Raised when connection to an MCP server fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
url: str | None = None,
|
||||
cause: Exception | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.url = url
|
||||
self.cause = cause
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.url:
|
||||
base = f"{base} | url={self.url}"
|
||||
if self.cause:
|
||||
base = f"{base} | cause={type(self.cause).__name__}: {self.cause}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPTimeoutError(MCPError):
|
||||
"""Raised when an MCP operation times out."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
operation: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.operation = operation
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.timeout_seconds is not None:
|
||||
base = f"{base} | timeout={self.timeout_seconds}s"
|
||||
if self.operation:
|
||||
base = f"{base} | operation={self.operation}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPToolError(MCPError):
|
||||
"""Raised when a tool execution fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
tool_args: dict[str, Any] | None = None,
|
||||
error_code: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.tool_name = tool_name
|
||||
self.tool_args = tool_args
|
||||
self.error_code = error_code
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.tool_name:
|
||||
base = f"{base} | tool={self.tool_name}"
|
||||
if self.error_code:
|
||||
base = f"{base} | error_code={self.error_code}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPServerNotFoundError(MCPError):
|
||||
"""Raised when a requested MCP server is not registered."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name: str,
|
||||
*,
|
||||
available_servers: list[str] | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
message = f"MCP server not found: {server_name}"
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.available_servers = available_servers or []
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.available_servers:
|
||||
base = f"{base} | available={self.available_servers}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPToolNotFoundError(MCPError):
|
||||
"""Raised when a requested tool is not found on any server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_name: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
available_tools: list[str] | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
message = f"Tool not found: {tool_name}"
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.tool_name = tool_name
|
||||
self.available_tools = available_tools or []
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.available_tools:
|
||||
base = f"{base} | available_tools={self.available_tools[:5]}..."
|
||||
return base
|
||||
|
||||
|
||||
class MCPCircuitOpenError(MCPError):
|
||||
"""Raised when a circuit breaker is open (server temporarily unavailable)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name: str,
|
||||
*,
|
||||
failure_count: int | None = None,
|
||||
reset_timeout: float | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
message = f"Circuit breaker open for server: {server_name}"
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.failure_count = failure_count
|
||||
self.reset_timeout = reset_timeout
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.failure_count is not None:
|
||||
base = f"{base} | failures={self.failure_count}"
|
||||
if self.reset_timeout is not None:
|
||||
base = f"{base} | reset_in={self.reset_timeout}s"
|
||||
return base
|
||||
|
||||
|
||||
class MCPValidationError(MCPError):
|
||||
"""Raised when tool arguments fail validation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
tool_name: str | None = None,
|
||||
field_errors: dict[str, str] | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, details=details)
|
||||
self.tool_name = tool_name
|
||||
self.field_errors = field_errors or {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.tool_name:
|
||||
base = f"{base} | tool={self.tool_name}"
|
||||
if self.field_errors:
|
||||
base = f"{base} | fields={list(self.field_errors.keys())}"
|
||||
return base
|
||||
305
backend/app/services/mcp/registry.py
Normal file
305
backend/app/services/mcp/registry.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
MCP Server Registry
|
||||
|
||||
Thread-safe singleton registry for managing MCP server configurations
|
||||
and their capabilities.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
||||
from .exceptions import MCPServerNotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServerCapabilities:
|
||||
"""Cached capabilities for an MCP server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
resources: list[dict[str, Any]] | None = None,
|
||||
prompts: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
self.tools = tools or []
|
||||
self.resources = resources or []
|
||||
self.prompts = prompts or []
|
||||
self._loaded = False
|
||||
self._load_time: float | None = None
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if capabilities have been loaded."""
|
||||
return self._loaded
|
||||
|
||||
@property
|
||||
def tool_names(self) -> list[str]:
|
||||
"""Get list of tool names."""
|
||||
return [t.get("name", "") for t in self.tools if t.get("name")]
|
||||
|
||||
def mark_loaded(self) -> None:
|
||||
"""Mark capabilities as loaded."""
|
||||
import time
|
||||
|
||||
self._loaded = True
|
||||
self._load_time = time.time()
|
||||
|
||||
|
||||
class MCPServerRegistry:
|
||||
"""
|
||||
Thread-safe singleton registry for MCP servers.
|
||||
|
||||
Manages server configurations and caches their capabilities.
|
||||
"""
|
||||
|
||||
_instance: "MCPServerRegistry | None" = None
|
||||
_lock = Lock()
|
||||
|
||||
def __new__(cls) -> "MCPServerRegistry":
|
||||
"""Ensure singleton pattern."""
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize registry (only runs once due to singleton)."""
|
||||
if getattr(self, "_initialized", False):
|
||||
return
|
||||
|
||||
self._config: MCPConfig = MCPConfig()
|
||||
self._capabilities: dict[str, ServerCapabilities] = {}
|
||||
self._capabilities_lock = asyncio.Lock()
|
||||
self._initialized = True
|
||||
|
||||
logger.info("MCP Server Registry initialized")
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "MCPServerRegistry":
|
||||
"""Get the singleton registry instance."""
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""Reset the singleton (for testing)."""
|
||||
with cls._lock:
|
||||
cls._instance = None
|
||||
|
||||
def load_config(self, config: MCPConfig | None = None) -> None:
|
||||
"""
|
||||
Load configuration into the registry.
|
||||
|
||||
Args:
|
||||
config: Optional config to load. If None, loads from default path.
|
||||
"""
|
||||
if config is None:
|
||||
config = load_mcp_config()
|
||||
|
||||
self._config = config
|
||||
self._capabilities.clear()
|
||||
|
||||
logger.info(
|
||||
"Loaded MCP configuration with %d servers",
|
||||
len(config.mcp_servers),
|
||||
)
|
||||
for name in config.list_server_names():
|
||||
logger.debug("Registered MCP server: %s", name)
|
||||
|
||||
def register(self, name: str, config: MCPServerConfig) -> None:
|
||||
"""
|
||||
Register a new MCP server.
|
||||
|
||||
Args:
|
||||
name: Unique server name
|
||||
config: Server configuration
|
||||
"""
|
||||
self._config.mcp_servers[name] = config
|
||||
self._capabilities.pop(name, None) # Clear any cached capabilities
|
||||
|
||||
logger.info("Registered MCP server: %s at %s", name, config.url)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""
|
||||
Unregister an MCP server.
|
||||
|
||||
Args:
|
||||
name: Server name to unregister
|
||||
|
||||
Returns:
|
||||
True if server was found and removed
|
||||
"""
|
||||
if name in self._config.mcp_servers:
|
||||
del self._config.mcp_servers[name]
|
||||
self._capabilities.pop(name, None)
|
||||
logger.info("Unregistered MCP server: %s", name)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get(self, name: str) -> MCPServerConfig:
|
||||
"""
|
||||
Get a server configuration by name.
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
|
||||
Returns:
|
||||
Server configuration
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
config = self._config.get_server(name)
|
||||
if config is None:
|
||||
raise MCPServerNotFoundError(
|
||||
server_name=name,
|
||||
available_servers=self.list_servers(),
|
||||
)
|
||||
return config
|
||||
|
||||
def get_or_none(self, name: str) -> MCPServerConfig | None:
|
||||
"""
|
||||
Get a server configuration by name, or None if not found.
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
|
||||
Returns:
|
||||
Server configuration or None
|
||||
"""
|
||||
return self._config.get_server(name)
|
||||
|
||||
def list_servers(self) -> list[str]:
|
||||
"""Get list of all registered server names."""
|
||||
return self._config.list_server_names()
|
||||
|
||||
def list_enabled_servers(self) -> list[str]:
|
||||
"""Get list of enabled server names."""
|
||||
return list(self._config.get_enabled_servers().keys())
|
||||
|
||||
def get_all_configs(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all server configurations."""
|
||||
return dict(self._config.mcp_servers)
|
||||
|
||||
def get_enabled_configs(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all enabled server configurations."""
|
||||
return self._config.get_enabled_servers()
|
||||
|
||||
async def get_capabilities(
|
||||
self,
|
||||
name: str,
|
||||
force_refresh: bool = False,
|
||||
) -> ServerCapabilities:
|
||||
"""
|
||||
Get capabilities for a server (lazy-loaded and cached).
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
force_refresh: If True, refresh cached capabilities
|
||||
|
||||
Returns:
|
||||
Server capabilities
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
# Verify server exists
|
||||
self.get(name)
|
||||
|
||||
async with self._capabilities_lock:
|
||||
if name not in self._capabilities or force_refresh:
|
||||
# Will be populated by connection manager when connecting
|
||||
self._capabilities[name] = ServerCapabilities()
|
||||
|
||||
return self._capabilities[name]
|
||||
|
||||
def set_capabilities(
|
||||
self,
|
||||
name: str,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
resources: list[dict[str, Any]] | None = None,
|
||||
prompts: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Set capabilities for a server (called by connection manager).
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
tools: List of tool definitions
|
||||
resources: List of resource definitions
|
||||
prompts: List of prompt definitions
|
||||
"""
|
||||
capabilities = ServerCapabilities(
|
||||
tools=tools,
|
||||
resources=resources,
|
||||
prompts=prompts,
|
||||
)
|
||||
capabilities.mark_loaded()
|
||||
self._capabilities[name] = capabilities
|
||||
|
||||
logger.debug(
|
||||
"Updated capabilities for %s: %d tools, %d resources, %d prompts",
|
||||
name,
|
||||
len(capabilities.tools),
|
||||
len(capabilities.resources),
|
||||
len(capabilities.prompts),
|
||||
)
|
||||
|
||||
def get_cached_capabilities(self, name: str) -> ServerCapabilities:
|
||||
"""
|
||||
Get cached capabilities without async loading.
|
||||
|
||||
Use this for synchronous access when you only need
|
||||
cached values (e.g., for health check responses).
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
|
||||
Returns:
|
||||
Cached capabilities or empty ServerCapabilities
|
||||
"""
|
||||
return self._capabilities.get(name, ServerCapabilities())
|
||||
|
||||
def find_server_for_tool(self, tool_name: str) -> str | None:
|
||||
"""
|
||||
Find which server provides a specific tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to find
|
||||
|
||||
Returns:
|
||||
Server name or None if not found
|
||||
"""
|
||||
for name, caps in self._capabilities.items():
|
||||
if tool_name in caps.tool_names:
|
||||
return name
|
||||
return None
|
||||
|
||||
def get_all_tools(self) -> dict[str, list[dict[str, Any]]]:
|
||||
"""
|
||||
Get all tools from all servers.
|
||||
|
||||
Returns:
|
||||
Dict mapping server name to list of tool definitions
|
||||
"""
|
||||
return {
|
||||
name: caps.tools
|
||||
for name, caps in self._capabilities.items()
|
||||
if caps.is_loaded
|
||||
}
|
||||
|
||||
@property
|
||||
def global_config(self) -> MCPConfig:
|
||||
"""Get the global MCP configuration."""
|
||||
return self._config
|
||||
|
||||
|
||||
# Module-level convenience function
|
||||
def get_registry() -> MCPServerRegistry:
|
||||
"""Get the global MCP server registry instance."""
|
||||
return MCPServerRegistry.get_instance()
|
||||
619
backend/app/services/mcp/routing.py
Normal file
619
backend/app/services/mcp/routing.py
Normal file
@@ -0,0 +1,619 @@
|
||||
"""
|
||||
MCP Tool Call Routing
|
||||
|
||||
Routes tool calls to appropriate servers with retry logic,
|
||||
circuit breakers, and request/response serialization.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .config import MCPServerConfig
|
||||
from .connection import ConnectionPool, MCPConnection
|
||||
from .exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
)
|
||||
from .registry import MCPServerRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker states."""
|
||||
|
||||
CLOSED = "closed"
|
||||
OPEN = "open"
|
||||
HALF_OPEN = "half-open"
|
||||
|
||||
|
||||
class AsyncCircuitBreaker:
|
||||
"""
|
||||
Async-compatible circuit breaker implementation.
|
||||
|
||||
Unlike pybreaker which wraps sync functions, this implementation
|
||||
provides explicit success/failure tracking for async code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fail_max: int = 5,
|
||||
reset_timeout: float = 30.0,
|
||||
name: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
fail_max: Maximum failures before opening circuit
|
||||
reset_timeout: Seconds to wait before trying again
|
||||
name: Name for logging
|
||||
"""
|
||||
self.fail_max = fail_max
|
||||
self.reset_timeout = reset_timeout
|
||||
self.name = name
|
||||
self._state = CircuitState.CLOSED
|
||||
self._fail_counter = 0
|
||||
self._last_failure_time: float | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def current_state(self) -> str:
|
||||
"""Get current state as string."""
|
||||
# Check if we should transition from OPEN to HALF_OPEN
|
||||
if self._state == CircuitState.OPEN:
|
||||
if self._should_try_reset():
|
||||
return CircuitState.HALF_OPEN.value
|
||||
return self._state.value
|
||||
|
||||
@property
|
||||
def fail_counter(self) -> int:
|
||||
"""Get current failure count."""
|
||||
return self._fail_counter
|
||||
|
||||
def _should_try_reset(self) -> bool:
|
||||
"""Check if enough time has passed to try resetting."""
|
||||
if self._last_failure_time is None:
|
||||
return True
|
||||
return (time.time() - self._last_failure_time) >= self.reset_timeout
|
||||
|
||||
async def success(self) -> None:
|
||||
"""Record a successful call."""
|
||||
async with self._lock:
|
||||
self._fail_counter = 0
|
||||
self._state = CircuitState.CLOSED
|
||||
self._last_failure_time = None
|
||||
|
||||
async def failure(self) -> None:
|
||||
"""Record a failed call."""
|
||||
async with self._lock:
|
||||
self._fail_counter += 1
|
||||
self._last_failure_time = time.time()
|
||||
|
||||
if self._fail_counter >= self.fail_max:
|
||||
self._state = CircuitState.OPEN
|
||||
logger.warning(
|
||||
"Circuit breaker %s opened after %d failures",
|
||||
self.name,
|
||||
self._fail_counter,
|
||||
)
|
||||
|
||||
def is_open(self) -> bool:
|
||||
"""Check if circuit is open (not allowing calls)."""
|
||||
if self._state == CircuitState.OPEN:
|
||||
return not self._should_try_reset()
|
||||
return False
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Manually reset the circuit breaker."""
|
||||
async with self._lock:
|
||||
self._state = CircuitState.CLOSED
|
||||
self._fail_counter = 0
|
||||
self._last_failure_time = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolInfo:
|
||||
"""Information about an available tool."""
|
||||
|
||||
name: str
|
||||
description: str | None = None
|
||||
server_name: str | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"server_name": self.server_name,
|
||||
"input_schema": self.input_schema,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Result of a tool execution."""
|
||||
|
||||
success: bool
|
||||
data: Any = None
|
||||
error: str | None = None
|
||||
error_code: str | None = None
|
||||
tool_name: str | None = None
|
||||
server_name: str | None = None
|
||||
execution_time_ms: float = 0.0
|
||||
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"success": self.success,
|
||||
"data": self.data,
|
||||
"error": self.error,
|
||||
"error_code": self.error_code,
|
||||
"tool_name": self.tool_name,
|
||||
"server_name": self.server_name,
|
||||
"execution_time_ms": self.execution_time_ms,
|
||||
"request_id": self.request_id,
|
||||
}
|
||||
|
||||
|
||||
class ToolRouter:
|
||||
"""
|
||||
Routes tool calls to the appropriate MCP server.
|
||||
|
||||
Features:
|
||||
- Tool name to server mapping
|
||||
- Retry logic with exponential backoff
|
||||
- Circuit breaker pattern for fault tolerance
|
||||
- Request/response serialization
|
||||
- Execution timing and metrics
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
registry: MCPServerRegistry,
|
||||
connection_pool: ConnectionPool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the tool router.
|
||||
|
||||
Args:
|
||||
registry: MCP server registry
|
||||
connection_pool: Connection pool for servers
|
||||
"""
|
||||
self._registry = registry
|
||||
self._pool = connection_pool
|
||||
self._circuit_breakers: dict[str, AsyncCircuitBreaker] = {}
|
||||
self._tool_to_server: dict[str, str] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def _get_circuit_breaker(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> AsyncCircuitBreaker:
|
||||
"""Get or create a circuit breaker for a server."""
|
||||
if server_name not in self._circuit_breakers:
|
||||
self._circuit_breakers[server_name] = AsyncCircuitBreaker(
|
||||
fail_max=config.circuit_breaker_threshold,
|
||||
reset_timeout=config.circuit_breaker_timeout,
|
||||
name=f"mcp-{server_name}",
|
||||
)
|
||||
return self._circuit_breakers[server_name]
|
||||
|
||||
async def register_tool_mapping(
|
||||
self,
|
||||
tool_name: str,
|
||||
server_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Register a mapping from tool name to server.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
server_name: Name of the server providing the tool
|
||||
"""
|
||||
async with self._lock:
|
||||
self._tool_to_server[tool_name] = server_name
|
||||
logger.debug("Registered tool %s -> server %s", tool_name, server_name)
|
||||
|
||||
async def discover_tools(self) -> None:
|
||||
"""
|
||||
Discover all tools from registered servers and build mappings.
|
||||
"""
|
||||
for server_name in self._registry.list_enabled_servers():
|
||||
try:
|
||||
config = self._registry.get(server_name)
|
||||
connection = await self._pool.get_connection(server_name, config)
|
||||
|
||||
# Fetch tools from server
|
||||
tools = await self._fetch_tools_from_server(connection)
|
||||
|
||||
# Update registry with capabilities
|
||||
self._registry.set_capabilities(
|
||||
server_name,
|
||||
tools=[t.to_dict() for t in tools],
|
||||
)
|
||||
|
||||
# Update tool mappings
|
||||
for tool in tools:
|
||||
await self.register_tool_mapping(tool.name, server_name)
|
||||
|
||||
logger.info(
|
||||
"Discovered %d tools from server %s",
|
||||
len(tools),
|
||||
server_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to discover tools from %s: %s",
|
||||
server_name,
|
||||
e,
|
||||
)
|
||||
|
||||
async def _fetch_tools_from_server(
|
||||
self,
|
||||
connection: MCPConnection,
|
||||
) -> list[ToolInfo]:
|
||||
"""Fetch available tools from an MCP server."""
|
||||
try:
|
||||
response = await connection.execute_request(
|
||||
"GET",
|
||||
"/mcp/tools",
|
||||
)
|
||||
|
||||
tools = []
|
||||
for tool_data in response.get("tools", []):
|
||||
tools.append(
|
||||
ToolInfo(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description"),
|
||||
server_name=connection.server_name,
|
||||
input_schema=tool_data.get("inputSchema"),
|
||||
)
|
||||
)
|
||||
return tools
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error fetching tools from %s: %s",
|
||||
connection.server_name,
|
||||
e,
|
||||
)
|
||||
return []
|
||||
|
||||
def find_server_for_tool(self, tool_name: str) -> str | None:
|
||||
"""
|
||||
Find which server provides a specific tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Server name or None if not found
|
||||
"""
|
||||
return self._tool_to_server.get(tool_name)
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
server_name: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Call a tool on a specific server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the MCP server
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
start_time = time.time()
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
logger.debug(
|
||||
"Tool call [%s]: %s.%s with args %s",
|
||||
request_id,
|
||||
server_name,
|
||||
tool_name,
|
||||
arguments,
|
||||
)
|
||||
|
||||
try:
|
||||
config = self._registry.get(server_name)
|
||||
circuit_breaker = self._get_circuit_breaker(server_name, config)
|
||||
|
||||
# Check circuit breaker state
|
||||
if circuit_breaker.is_open():
|
||||
raise MCPCircuitOpenError(
|
||||
server_name=server_name,
|
||||
failure_count=circuit_breaker.fail_counter,
|
||||
reset_timeout=config.circuit_breaker_timeout,
|
||||
)
|
||||
|
||||
# Execute with retry logic
|
||||
result = await self._execute_with_retry(
|
||||
server_name=server_name,
|
||||
config=config,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments or {},
|
||||
timeout=timeout,
|
||||
circuit_breaker=circuit_breaker,
|
||||
)
|
||||
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
|
||||
return ToolResult(
|
||||
success=True,
|
||||
data=result,
|
||||
tool_name=tool_name,
|
||||
server_name=server_name,
|
||||
execution_time_ms=execution_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
except MCPCircuitOpenError:
|
||||
raise
|
||||
except MCPError as e:
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
logger.error(
|
||||
"Tool call failed [%s]: %s.%s - %s",
|
||||
request_id,
|
||||
server_name,
|
||||
tool_name,
|
||||
e,
|
||||
)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_code=type(e).__name__,
|
||||
tool_name=tool_name,
|
||||
server_name=server_name,
|
||||
execution_time_ms=execution_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
except Exception as e:
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
logger.exception(
|
||||
"Unexpected error in tool call [%s]: %s.%s",
|
||||
request_id,
|
||||
server_name,
|
||||
tool_name,
|
||||
)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_code="UnexpectedError",
|
||||
tool_name=tool_name,
|
||||
server_name=server_name,
|
||||
execution_time_ms=execution_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
async def _execute_with_retry(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
timeout: float | None,
|
||||
circuit_breaker: AsyncCircuitBreaker,
|
||||
) -> Any:
|
||||
"""Execute tool call with retry logic."""
|
||||
last_error: Exception | None = None
|
||||
attempts = 0
|
||||
max_attempts = config.retry_attempts + 1 # +1 for initial attempt
|
||||
|
||||
while attempts < max_attempts:
|
||||
attempts += 1
|
||||
|
||||
try:
|
||||
# Use circuit breaker to track failures
|
||||
result = await self._execute_tool_call(
|
||||
server_name=server_name,
|
||||
config=config,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Success - record it
|
||||
await circuit_breaker.success()
|
||||
return result
|
||||
|
||||
except MCPCircuitOpenError:
|
||||
raise
|
||||
except MCPTimeoutError:
|
||||
# Timeout - don't retry
|
||||
await circuit_breaker.failure()
|
||||
raise
|
||||
except MCPToolError:
|
||||
# Tool-level error - don't retry (user error)
|
||||
raise
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
await circuit_breaker.failure()
|
||||
|
||||
if attempts < max_attempts:
|
||||
delay = self._calculate_retry_delay(attempts, config)
|
||||
logger.warning(
|
||||
"Tool call attempt %d/%d failed for %s.%s: %s. "
|
||||
"Retrying in %.1fs",
|
||||
attempts,
|
||||
max_attempts,
|
||||
server_name,
|
||||
tool_name,
|
||||
e,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All attempts failed
|
||||
raise MCPToolError(
|
||||
f"Tool call failed after {max_attempts} attempts",
|
||||
server_name=server_name,
|
||||
tool_name=tool_name,
|
||||
tool_args=arguments,
|
||||
details={"last_error": str(last_error)},
|
||||
)
|
||||
|
||||
def _calculate_retry_delay(
|
||||
self,
|
||||
attempt: int,
|
||||
config: MCPServerConfig,
|
||||
) -> float:
|
||||
"""Calculate exponential backoff delay with jitter."""
|
||||
import random
|
||||
|
||||
delay = config.retry_delay * (2 ** (attempt - 1))
|
||||
delay = min(delay, config.retry_max_delay)
|
||||
# Add jitter (±25%)
|
||||
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||
return max(0.1, delay + jitter)
|
||||
|
||||
async def _execute_tool_call(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
timeout: float | None,
|
||||
) -> Any:
|
||||
"""Execute a single tool call."""
|
||||
connection = await self._pool.get_connection(server_name, config)
|
||||
|
||||
# Build MCP tool call request
|
||||
request_body = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
"id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
response = await connection.execute_request(
|
||||
method="POST",
|
||||
path="/mcp",
|
||||
data=request_body,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Handle JSON-RPC response
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPToolError(
|
||||
error.get("message", "Tool execution failed"),
|
||||
server_name=server_name,
|
||||
tool_name=tool_name,
|
||||
tool_args=arguments,
|
||||
error_code=str(error.get("code", "UNKNOWN")),
|
||||
)
|
||||
|
||||
return response.get("result")
|
||||
|
||||
async def route_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Route a tool call to the appropriate server.
|
||||
|
||||
Automatically discovers which server provides the tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
|
||||
Raises:
|
||||
MCPToolNotFoundError: If no server provides the tool
|
||||
"""
|
||||
server_name = self.find_server_for_tool(tool_name)
|
||||
|
||||
if server_name is None:
|
||||
# Try to find from registry
|
||||
server_name = self._registry.find_server_for_tool(tool_name)
|
||||
|
||||
if server_name is None:
|
||||
raise MCPToolNotFoundError(
|
||||
tool_name=tool_name,
|
||||
available_tools=list(self._tool_to_server.keys()),
|
||||
)
|
||||
|
||||
return await self.call_tool(
|
||||
server_name=server_name,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def list_all_tools(self) -> list[ToolInfo]:
|
||||
"""
|
||||
Get all available tools from all servers.
|
||||
|
||||
Returns:
|
||||
List of tool information
|
||||
"""
|
||||
tools = []
|
||||
all_server_tools = self._registry.get_all_tools()
|
||||
|
||||
for server_name, server_tools in all_server_tools.items():
|
||||
for tool_data in server_tools:
|
||||
tools.append(
|
||||
ToolInfo(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description"),
|
||||
server_name=server_name,
|
||||
input_schema=tool_data.get("input_schema"),
|
||||
)
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get status of all circuit breakers."""
|
||||
return {
|
||||
name: {
|
||||
"state": cb.current_state,
|
||||
"failure_count": cb.fail_counter,
|
||||
}
|
||||
for name, cb in self._circuit_breakers.items()
|
||||
}
|
||||
|
||||
async def reset_circuit_breaker(self, server_name: str) -> bool:
|
||||
"""
|
||||
Manually reset a circuit breaker.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
|
||||
Returns:
|
||||
True if circuit breaker was reset
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name in self._circuit_breakers:
|
||||
# Reset by removing (will be recreated on next call)
|
||||
del self._circuit_breakers[server_name]
|
||||
logger.info("Reset circuit breaker for %s", server_name)
|
||||
return True
|
||||
return False
|
||||
@@ -343,7 +343,9 @@ class OAuthService:
|
||||
await oauth_account.update_tokens(
|
||||
db,
|
||||
account=existing_oauth,
|
||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
access_token_encrypted=token.get("access_token"),
|
||||
refresh_token_encrypted=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600)),
|
||||
)
|
||||
|
||||
@@ -375,7 +377,9 @@ class OAuthService:
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=provider_email,
|
||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
access_token_encrypted=token.get("access_token"),
|
||||
refresh_token_encrypted=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
@@ -644,7 +648,9 @@ class OAuthService:
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=email,
|
||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
access_token_encrypted=token.get("access_token"),
|
||||
refresh_token_encrypted=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
|
||||
170
backend/app/services/safety/__init__.py
Normal file
170
backend/app/services/safety/__init__.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Safety and Guardrails Framework
|
||||
|
||||
Comprehensive safety framework for autonomous agent operation.
|
||||
Provides multi-layered protection including:
|
||||
- Pre-execution validation
|
||||
- Cost and budget controls
|
||||
- Rate limiting
|
||||
- Loop detection and prevention
|
||||
- Human-in-the-loop approval
|
||||
- Rollback and checkpointing
|
||||
- Content filtering
|
||||
- Sandboxed execution
|
||||
- Emergency controls
|
||||
- Complete audit trail
|
||||
|
||||
Usage:
|
||||
from app.services.safety import get_safety_guardian, SafetyGuardian
|
||||
|
||||
guardian = await get_safety_guardian()
|
||||
result = await guardian.validate(action_request)
|
||||
|
||||
if result.allowed:
|
||||
# Execute action
|
||||
pass
|
||||
else:
|
||||
# Handle denial
|
||||
print(f"Action denied: {result.reasons}")
|
||||
"""
|
||||
|
||||
# Exceptions
|
||||
# Audit
|
||||
from .audit import (
|
||||
AuditLogger,
|
||||
get_audit_logger,
|
||||
reset_audit_logger,
|
||||
shutdown_audit_logger,
|
||||
)
|
||||
|
||||
# Configuration
|
||||
from .config import (
|
||||
AutonomyConfig,
|
||||
SafetyConfig,
|
||||
get_autonomy_config,
|
||||
get_default_policy,
|
||||
get_policy_for_autonomy_level,
|
||||
get_safety_config,
|
||||
load_policies_from_directory,
|
||||
load_policy_from_file,
|
||||
reset_config_cache,
|
||||
)
|
||||
from .exceptions import (
|
||||
ApprovalDeniedError,
|
||||
ApprovalRequiredError,
|
||||
ApprovalTimeoutError,
|
||||
BudgetExceededError,
|
||||
CheckpointError,
|
||||
ContentFilterError,
|
||||
EmergencyStopError,
|
||||
LoopDetectedError,
|
||||
PermissionDeniedError,
|
||||
PolicyViolationError,
|
||||
RateLimitExceededError,
|
||||
RollbackError,
|
||||
SafetyError,
|
||||
SandboxError,
|
||||
SandboxTimeoutError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
# Guardian
|
||||
from .guardian import (
|
||||
SafetyGuardian,
|
||||
get_safety_guardian,
|
||||
reset_safety_guardian,
|
||||
shutdown_safety_guardian,
|
||||
)
|
||||
|
||||
# Models
|
||||
from .models import (
|
||||
ActionMetadata,
|
||||
ActionRequest,
|
||||
ActionResult,
|
||||
ActionType,
|
||||
ApprovalRequest,
|
||||
ApprovalResponse,
|
||||
ApprovalStatus,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
AutonomyLevel,
|
||||
BudgetScope,
|
||||
BudgetStatus,
|
||||
Checkpoint,
|
||||
CheckpointType,
|
||||
GuardianResult,
|
||||
PermissionLevel,
|
||||
RateLimitConfig,
|
||||
RateLimitStatus,
|
||||
ResourceType,
|
||||
RollbackResult,
|
||||
SafetyDecision,
|
||||
SafetyPolicy,
|
||||
ValidationResult,
|
||||
ValidationRule,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionMetadata",
|
||||
"ActionRequest",
|
||||
"ActionResult",
|
||||
# Models
|
||||
"ActionType",
|
||||
"ApprovalDeniedError",
|
||||
"ApprovalRequest",
|
||||
"ApprovalRequiredError",
|
||||
"ApprovalResponse",
|
||||
"ApprovalStatus",
|
||||
"ApprovalTimeoutError",
|
||||
"AuditEvent",
|
||||
"AuditEventType",
|
||||
# Audit
|
||||
"AuditLogger",
|
||||
"AutonomyConfig",
|
||||
"AutonomyLevel",
|
||||
"BudgetExceededError",
|
||||
"BudgetScope",
|
||||
"BudgetStatus",
|
||||
"Checkpoint",
|
||||
"CheckpointError",
|
||||
"CheckpointType",
|
||||
"ContentFilterError",
|
||||
"EmergencyStopError",
|
||||
"GuardianResult",
|
||||
"LoopDetectedError",
|
||||
"PermissionDeniedError",
|
||||
"PermissionLevel",
|
||||
"PolicyViolationError",
|
||||
"RateLimitConfig",
|
||||
"RateLimitExceededError",
|
||||
"RateLimitStatus",
|
||||
"ResourceType",
|
||||
"RollbackError",
|
||||
"RollbackResult",
|
||||
# Configuration
|
||||
"SafetyConfig",
|
||||
"SafetyDecision",
|
||||
# Exceptions
|
||||
"SafetyError",
|
||||
# Guardian
|
||||
"SafetyGuardian",
|
||||
"SafetyPolicy",
|
||||
"SandboxError",
|
||||
"SandboxTimeoutError",
|
||||
"ValidationError",
|
||||
"ValidationResult",
|
||||
"ValidationRule",
|
||||
"get_audit_logger",
|
||||
"get_autonomy_config",
|
||||
"get_default_policy",
|
||||
"get_policy_for_autonomy_level",
|
||||
"get_safety_config",
|
||||
"get_safety_guardian",
|
||||
"load_policies_from_directory",
|
||||
"load_policy_from_file",
|
||||
"reset_audit_logger",
|
||||
"reset_config_cache",
|
||||
"reset_safety_guardian",
|
||||
"shutdown_audit_logger",
|
||||
"shutdown_safety_guardian",
|
||||
]
|
||||
19
backend/app/services/safety/audit/__init__.py
Normal file
19
backend/app/services/safety/audit/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Audit System
|
||||
|
||||
Comprehensive audit logging for all safety-related events.
|
||||
"""
|
||||
|
||||
from .logger import (
|
||||
AuditLogger,
|
||||
get_audit_logger,
|
||||
reset_audit_logger,
|
||||
shutdown_audit_logger,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AuditLogger",
|
||||
"get_audit_logger",
|
||||
"reset_audit_logger",
|
||||
"shutdown_audit_logger",
|
||||
]
|
||||
601
backend/app/services/safety/audit/logger.py
Normal file
601
backend/app/services/safety/audit/logger.py
Normal file
@@ -0,0 +1,601 @@
|
||||
"""
|
||||
Audit Logger
|
||||
|
||||
Comprehensive audit logging for all safety-related events.
|
||||
Provides tamper detection, structured logging, and compliance support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
AuditEvent,
|
||||
AuditEventType,
|
||||
SafetyDecision,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sentinel for distinguishing "no argument passed" from "explicitly passing None"
|
||||
_UNSET = object()
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""
|
||||
Audit logger for safety events.
|
||||
|
||||
Features:
|
||||
- Structured event logging
|
||||
- In-memory buffer with async flush
|
||||
- Tamper detection via hash chains
|
||||
- Query/search capability
|
||||
- Retention policy enforcement
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_buffer_size: int = 1000,
|
||||
flush_interval_seconds: float = 10.0,
|
||||
enable_hash_chain: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the audit logger.
|
||||
|
||||
Args:
|
||||
max_buffer_size: Maximum events to buffer before auto-flush
|
||||
flush_interval_seconds: Interval for periodic flush
|
||||
enable_hash_chain: Enable tamper detection via hash chain
|
||||
"""
|
||||
self._buffer: deque[AuditEvent] = deque(maxlen=max_buffer_size)
|
||||
self._persisted: list[AuditEvent] = []
|
||||
self._flush_interval = flush_interval_seconds
|
||||
self._enable_hash_chain = enable_hash_chain
|
||||
self._last_hash: str | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._flush_task: asyncio.Task[None] | None = None
|
||||
self._running = False
|
||||
|
||||
# Event handlers for real-time processing
|
||||
self._handlers: list[Any] = []
|
||||
|
||||
config = get_safety_config()
|
||||
self._retention_days = config.audit_retention_days
|
||||
self._include_sensitive = config.audit_include_sensitive
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the audit logger background tasks."""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._flush_task = asyncio.create_task(self._periodic_flush())
|
||||
logger.info("Audit logger started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the audit logger and flush remaining events."""
|
||||
self._running = False
|
||||
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Final flush
|
||||
await self.flush()
|
||||
logger.info("Audit logger stopped")
|
||||
|
||||
async def log(
|
||||
self,
|
||||
event_type: AuditEventType,
|
||||
*,
|
||||
agent_id: str | None = None,
|
||||
action_id: str | None = None,
|
||||
project_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
decision: SafetyDecision | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
) -> AuditEvent:
|
||||
"""
|
||||
Log an audit event.
|
||||
|
||||
Args:
|
||||
event_type: Type of audit event
|
||||
agent_id: Agent ID if applicable
|
||||
action_id: Action ID if applicable
|
||||
project_id: Project ID if applicable
|
||||
session_id: Session ID if applicable
|
||||
user_id: User ID if applicable
|
||||
decision: Safety decision if applicable
|
||||
details: Additional event details
|
||||
correlation_id: Correlation ID for tracing
|
||||
|
||||
Returns:
|
||||
The created audit event
|
||||
"""
|
||||
# Sanitize sensitive data if needed
|
||||
sanitized_details = self._sanitize_details(details) if details else {}
|
||||
|
||||
event = AuditEvent(
|
||||
id=str(uuid4()),
|
||||
event_type=event_type,
|
||||
timestamp=datetime.utcnow(),
|
||||
agent_id=agent_id,
|
||||
action_id=action_id,
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
decision=decision,
|
||||
details=sanitized_details,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
# Add hash chain for tamper detection
|
||||
if self._enable_hash_chain:
|
||||
event_hash = self._compute_hash(event)
|
||||
# Modify event.details directly (not sanitized_details)
|
||||
# to ensure the hash is stored on the actual event
|
||||
event.details["_hash"] = event_hash
|
||||
event.details["_prev_hash"] = self._last_hash
|
||||
self._last_hash = event_hash
|
||||
|
||||
self._buffer.append(event)
|
||||
|
||||
# Notify handlers
|
||||
await self._notify_handlers(event)
|
||||
|
||||
# Log to standard logger as well
|
||||
self._log_to_logger(event)
|
||||
|
||||
return event
|
||||
|
||||
async def log_action_request(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
decision: SafetyDecision,
|
||||
reasons: list[str] | None = None,
|
||||
) -> AuditEvent:
|
||||
"""Log an action request with its validation decision."""
|
||||
event_type = (
|
||||
AuditEventType.ACTION_DENIED
|
||||
if decision == SafetyDecision.DENY
|
||||
else AuditEventType.ACTION_VALIDATED
|
||||
)
|
||||
|
||||
return await self.log(
|
||||
event_type,
|
||||
agent_id=action.metadata.agent_id,
|
||||
action_id=action.id,
|
||||
project_id=action.metadata.project_id,
|
||||
session_id=action.metadata.session_id,
|
||||
user_id=action.metadata.user_id,
|
||||
decision=decision,
|
||||
details={
|
||||
"action_type": action.action_type.value,
|
||||
"tool_name": action.tool_name,
|
||||
"resource": action.resource,
|
||||
"is_destructive": action.is_destructive,
|
||||
"reasons": reasons or [],
|
||||
},
|
||||
correlation_id=action.metadata.correlation_id,
|
||||
)
|
||||
|
||||
async def log_action_executed(
|
||||
self,
|
||||
action: ActionRequest,
|
||||
success: bool,
|
||||
execution_time_ms: float,
|
||||
error: str | None = None,
|
||||
) -> AuditEvent:
|
||||
"""Log an action execution result."""
|
||||
event_type = (
|
||||
AuditEventType.ACTION_EXECUTED if success else AuditEventType.ACTION_FAILED
|
||||
)
|
||||
|
||||
return await self.log(
|
||||
event_type,
|
||||
agent_id=action.metadata.agent_id,
|
||||
action_id=action.id,
|
||||
project_id=action.metadata.project_id,
|
||||
session_id=action.metadata.session_id,
|
||||
decision=SafetyDecision.ALLOW if success else SafetyDecision.DENY,
|
||||
details={
|
||||
"action_type": action.action_type.value,
|
||||
"tool_name": action.tool_name,
|
||||
"success": success,
|
||||
"execution_time_ms": execution_time_ms,
|
||||
"error": error,
|
||||
},
|
||||
correlation_id=action.metadata.correlation_id,
|
||||
)
|
||||
|
||||
async def log_approval_event(
|
||||
self,
|
||||
event_type: AuditEventType,
|
||||
approval_id: str,
|
||||
action: ActionRequest,
|
||||
decided_by: str | None = None,
|
||||
reason: str | None = None,
|
||||
) -> AuditEvent:
|
||||
"""Log an approval-related event."""
|
||||
return await self.log(
|
||||
event_type,
|
||||
agent_id=action.metadata.agent_id,
|
||||
action_id=action.id,
|
||||
project_id=action.metadata.project_id,
|
||||
session_id=action.metadata.session_id,
|
||||
user_id=decided_by,
|
||||
details={
|
||||
"approval_id": approval_id,
|
||||
"action_type": action.action_type.value,
|
||||
"tool_name": action.tool_name,
|
||||
"decided_by": decided_by,
|
||||
"reason": reason,
|
||||
},
|
||||
correlation_id=action.metadata.correlation_id,
|
||||
)
|
||||
|
||||
async def log_budget_event(
|
||||
self,
|
||||
event_type: AuditEventType,
|
||||
agent_id: str,
|
||||
scope: str,
|
||||
current_usage: float,
|
||||
limit: float,
|
||||
unit: str = "tokens",
|
||||
) -> AuditEvent:
|
||||
"""Log a budget-related event."""
|
||||
return await self.log(
|
||||
event_type,
|
||||
agent_id=agent_id,
|
||||
details={
|
||||
"scope": scope,
|
||||
"current_usage": current_usage,
|
||||
"limit": limit,
|
||||
"unit": unit,
|
||||
"usage_percent": (current_usage / limit * 100) if limit > 0 else 0,
|
||||
},
|
||||
)
|
||||
|
||||
async def log_emergency_stop(
|
||||
self,
|
||||
stop_type: str,
|
||||
triggered_by: str,
|
||||
reason: str,
|
||||
affected_agents: list[str] | None = None,
|
||||
) -> AuditEvent:
|
||||
"""Log an emergency stop event."""
|
||||
return await self.log(
|
||||
AuditEventType.EMERGENCY_STOP,
|
||||
user_id=triggered_by,
|
||||
details={
|
||||
"stop_type": stop_type,
|
||||
"triggered_by": triggered_by,
|
||||
"reason": reason,
|
||||
"affected_agents": affected_agents or [],
|
||||
},
|
||||
)
|
||||
|
||||
async def flush(self) -> int:
|
||||
"""
|
||||
Flush buffered events to persistent storage.
|
||||
|
||||
Returns:
|
||||
Number of events flushed
|
||||
"""
|
||||
async with self._lock:
|
||||
if not self._buffer:
|
||||
return 0
|
||||
|
||||
events = list(self._buffer)
|
||||
self._buffer.clear()
|
||||
|
||||
# Persist events (in production, this would go to database/storage)
|
||||
self._persisted.extend(events)
|
||||
|
||||
# Enforce retention
|
||||
self._enforce_retention()
|
||||
|
||||
logger.debug("Flushed %d audit events", len(events))
|
||||
return len(events)
|
||||
|
||||
async def query(
|
||||
self,
|
||||
*,
|
||||
event_types: list[AuditEventType] | None = None,
|
||||
agent_id: str | None = None,
|
||||
action_id: str | None = None,
|
||||
project_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
correlation_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[AuditEvent]:
|
||||
"""
|
||||
Query audit events with filters.
|
||||
|
||||
Args:
|
||||
event_types: Filter by event types
|
||||
agent_id: Filter by agent ID
|
||||
action_id: Filter by action ID
|
||||
project_id: Filter by project ID
|
||||
session_id: Filter by session ID
|
||||
user_id: Filter by user ID
|
||||
start_time: Filter events after this time
|
||||
end_time: Filter events before this time
|
||||
correlation_id: Filter by correlation ID
|
||||
limit: Maximum results to return
|
||||
offset: Result offset for pagination
|
||||
|
||||
Returns:
|
||||
List of matching audit events
|
||||
"""
|
||||
# Combine buffer and persisted for query
|
||||
all_events = list(self._persisted) + list(self._buffer)
|
||||
|
||||
results = []
|
||||
for event in all_events:
|
||||
if event_types and event.event_type not in event_types:
|
||||
continue
|
||||
if agent_id and event.agent_id != agent_id:
|
||||
continue
|
||||
if action_id and event.action_id != action_id:
|
||||
continue
|
||||
if project_id and event.project_id != project_id:
|
||||
continue
|
||||
if session_id and event.session_id != session_id:
|
||||
continue
|
||||
if user_id and event.user_id != user_id:
|
||||
continue
|
||||
if start_time and event.timestamp < start_time:
|
||||
continue
|
||||
if end_time and event.timestamp > end_time:
|
||||
continue
|
||||
if correlation_id and event.correlation_id != correlation_id:
|
||||
continue
|
||||
|
||||
results.append(event)
|
||||
|
||||
# Sort by timestamp descending
|
||||
results.sort(key=lambda e: e.timestamp, reverse=True)
|
||||
|
||||
# Apply pagination
|
||||
return results[offset : offset + limit]
|
||||
|
||||
async def get_action_history(
|
||||
self,
|
||||
agent_id: str,
|
||||
limit: int = 100,
|
||||
) -> list[AuditEvent]:
|
||||
"""Get action history for an agent."""
|
||||
return await self.query(
|
||||
agent_id=agent_id,
|
||||
event_types=[
|
||||
AuditEventType.ACTION_REQUESTED,
|
||||
AuditEventType.ACTION_VALIDATED,
|
||||
AuditEventType.ACTION_DENIED,
|
||||
AuditEventType.ACTION_EXECUTED,
|
||||
AuditEventType.ACTION_FAILED,
|
||||
],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def verify_integrity(self) -> tuple[bool, list[str]]:
|
||||
"""
|
||||
Verify audit log integrity using hash chain.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, list of issues found)
|
||||
"""
|
||||
if not self._enable_hash_chain:
|
||||
return True, []
|
||||
|
||||
issues: list[str] = []
|
||||
all_events = list(self._persisted) + list(self._buffer)
|
||||
|
||||
prev_hash: str | None = None
|
||||
for event in sorted(all_events, key=lambda e: e.timestamp):
|
||||
stored_prev = event.details.get("_prev_hash")
|
||||
stored_hash = event.details.get("_hash")
|
||||
|
||||
if stored_prev != prev_hash:
|
||||
issues.append(
|
||||
f"Hash chain broken at event {event.id}: "
|
||||
f"expected prev_hash={prev_hash}, got {stored_prev}"
|
||||
)
|
||||
|
||||
if stored_hash:
|
||||
# Pass prev_hash to compute hash with correct chain position
|
||||
computed = self._compute_hash(event, prev_hash=prev_hash)
|
||||
if computed != stored_hash:
|
||||
issues.append(
|
||||
f"Hash mismatch at event {event.id}: "
|
||||
f"expected {computed}, got {stored_hash}"
|
||||
)
|
||||
|
||||
prev_hash = stored_hash
|
||||
|
||||
return len(issues) == 0, issues
|
||||
|
||||
def add_handler(self, handler: Any) -> None:
|
||||
"""Add a real-time event handler."""
|
||||
self._handlers.append(handler)
|
||||
|
||||
def remove_handler(self, handler: Any) -> None:
|
||||
"""Remove an event handler."""
|
||||
if handler in self._handlers:
|
||||
self._handlers.remove(handler)
|
||||
|
||||
def _sanitize_details(self, details: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Sanitize sensitive data from details."""
|
||||
if self._include_sensitive:
|
||||
return details
|
||||
|
||||
sanitized: dict[str, Any] = {}
|
||||
sensitive_keys = {
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"auth",
|
||||
"credential",
|
||||
}
|
||||
|
||||
for key, value in details.items():
|
||||
lower_key = key.lower()
|
||||
if any(s in lower_key for s in sensitive_keys):
|
||||
sanitized[key] = "[REDACTED]"
|
||||
elif isinstance(value, dict):
|
||||
sanitized[key] = self._sanitize_details(value)
|
||||
else:
|
||||
sanitized[key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
def _compute_hash(
|
||||
self, event: AuditEvent, prev_hash: str | None | object = _UNSET
|
||||
) -> str:
|
||||
"""Compute hash for an event (excluding hash fields).
|
||||
|
||||
Args:
|
||||
event: The audit event to hash.
|
||||
prev_hash: Optional previous hash to use instead of self._last_hash.
|
||||
Pass this during verification to use the correct chain.
|
||||
Use None explicitly to indicate no previous hash.
|
||||
"""
|
||||
# Use passed prev_hash if explicitly provided, otherwise use instance state
|
||||
effective_prev: str | None = (
|
||||
self._last_hash if prev_hash is _UNSET else prev_hash # type: ignore[assignment]
|
||||
)
|
||||
|
||||
data: dict[str, str | dict[str, str] | None] = {
|
||||
"id": event.id,
|
||||
"event_type": event.event_type.value,
|
||||
"timestamp": event.timestamp.isoformat(),
|
||||
"agent_id": event.agent_id,
|
||||
"action_id": event.action_id,
|
||||
"project_id": event.project_id,
|
||||
"session_id": event.session_id,
|
||||
"user_id": event.user_id,
|
||||
"decision": event.decision.value if event.decision else None,
|
||||
"details": {
|
||||
k: v for k, v in event.details.items() if not k.startswith("_")
|
||||
},
|
||||
"correlation_id": event.correlation_id,
|
||||
}
|
||||
|
||||
if effective_prev:
|
||||
data["_prev_hash"] = effective_prev
|
||||
|
||||
serialized = json.dumps(data, sort_keys=True, default=str)
|
||||
return hashlib.sha256(serialized.encode()).hexdigest()
|
||||
|
||||
def _log_to_logger(self, event: AuditEvent) -> None:
|
||||
"""Log event to standard Python logger."""
|
||||
log_data = {
|
||||
"audit_event": event.event_type.value,
|
||||
"event_id": event.id,
|
||||
"agent_id": event.agent_id,
|
||||
"action_id": event.action_id,
|
||||
"decision": event.decision.value if event.decision else None,
|
||||
}
|
||||
|
||||
# Use appropriate log level based on event type
|
||||
if event.event_type in {
|
||||
AuditEventType.ACTION_DENIED,
|
||||
AuditEventType.POLICY_VIOLATION,
|
||||
AuditEventType.EMERGENCY_STOP,
|
||||
}:
|
||||
logger.warning("Audit: %s", log_data)
|
||||
elif event.event_type in {
|
||||
AuditEventType.ACTION_FAILED,
|
||||
AuditEventType.ROLLBACK_FAILED,
|
||||
}:
|
||||
logger.error("Audit: %s", log_data)
|
||||
else:
|
||||
logger.info("Audit: %s", log_data)
|
||||
|
||||
def _enforce_retention(self) -> None:
|
||||
"""Enforce retention policy on persisted events."""
|
||||
if not self._retention_days:
|
||||
return
|
||||
|
||||
cutoff = datetime.utcnow() - timedelta(days=self._retention_days)
|
||||
before_count = len(self._persisted)
|
||||
|
||||
self._persisted = [e for e in self._persisted if e.timestamp >= cutoff]
|
||||
|
||||
removed = before_count - len(self._persisted)
|
||||
if removed > 0:
|
||||
logger.info("Removed %d expired audit events", removed)
|
||||
|
||||
async def _periodic_flush(self) -> None:
|
||||
"""Background task for periodic flushing."""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(self._flush_interval)
|
||||
await self.flush()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error in periodic audit flush: %s", e)
|
||||
|
||||
async def _notify_handlers(self, event: AuditEvent) -> None:
|
||||
"""Notify all registered handlers of a new event."""
|
||||
for handler in self._handlers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler(event)
|
||||
else:
|
||||
handler(event)
|
||||
except Exception as e:
|
||||
logger.error("Error in audit event handler: %s", e)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_audit_logger: AuditLogger | None = None
|
||||
_audit_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_audit_logger() -> AuditLogger:
|
||||
"""Get the global audit logger instance."""
|
||||
global _audit_logger
|
||||
|
||||
async with _audit_lock:
|
||||
if _audit_logger is None:
|
||||
_audit_logger = AuditLogger()
|
||||
await _audit_logger.start()
|
||||
|
||||
return _audit_logger
|
||||
|
||||
|
||||
async def shutdown_audit_logger() -> None:
|
||||
"""Shutdown the global audit logger."""
|
||||
global _audit_logger
|
||||
|
||||
async with _audit_lock:
|
||||
if _audit_logger is not None:
|
||||
await _audit_logger.stop()
|
||||
_audit_logger = None
|
||||
|
||||
|
||||
def reset_audit_logger() -> None:
|
||||
"""Reset the audit logger (for testing)."""
|
||||
global _audit_logger
|
||||
_audit_logger = None
|
||||
304
backend/app/services/safety/config.py
Normal file
304
backend/app/services/safety/config.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Safety Framework Configuration
|
||||
|
||||
Pydantic settings for the safety and guardrails framework.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from .models import AutonomyLevel, SafetyPolicy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SafetyConfig(BaseSettings):
|
||||
"""Configuration for the safety framework."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="SAFETY_",
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# General settings
|
||||
enabled: bool = Field(True, description="Enable safety framework")
|
||||
strict_mode: bool = Field(True, description="Strict mode (fail closed on errors)")
|
||||
log_level: str = Field("INFO", description="Logging level")
|
||||
|
||||
# Default autonomy level
|
||||
default_autonomy_level: AutonomyLevel = Field(
|
||||
AutonomyLevel.MILESTONE,
|
||||
description="Default autonomy level for new agents",
|
||||
)
|
||||
|
||||
# Default budget limits
|
||||
default_session_token_budget: int = Field(
|
||||
100_000, description="Default tokens per session"
|
||||
)
|
||||
default_daily_token_budget: int = Field(
|
||||
1_000_000, description="Default tokens per day"
|
||||
)
|
||||
default_session_cost_limit: float = Field(
|
||||
10.0, description="Default USD per session"
|
||||
)
|
||||
default_daily_cost_limit: float = Field(100.0, description="Default USD per day")
|
||||
|
||||
# Default rate limits
|
||||
default_actions_per_minute: int = Field(60, description="Default actions per min")
|
||||
default_llm_calls_per_minute: int = Field(20, description="Default LLM calls/min")
|
||||
default_file_ops_per_minute: int = Field(100, description="Default file ops/min")
|
||||
|
||||
# Loop detection
|
||||
loop_detection_enabled: bool = Field(True, description="Enable loop detection")
|
||||
max_repeated_actions: int = Field(5, description="Max exact repetitions")
|
||||
max_similar_actions: int = Field(10, description="Max similar actions")
|
||||
loop_history_size: int = Field(100, description="Action history size for loops")
|
||||
|
||||
# HITL settings
|
||||
hitl_enabled: bool = Field(True, description="Enable human-in-the-loop")
|
||||
hitl_default_timeout: int = Field(300, description="Default approval timeout (s)")
|
||||
hitl_notification_channels: list[str] = Field(
|
||||
default_factory=list, description="Notification channels"
|
||||
)
|
||||
|
||||
# Rollback settings
|
||||
rollback_enabled: bool = Field(True, description="Enable rollback capability")
|
||||
checkpoint_dir: str = Field(
|
||||
"/tmp/syndarix_checkpoints", # noqa: S108
|
||||
description="Directory for checkpoint storage",
|
||||
)
|
||||
checkpoint_retention_hours: int = Field(24, description="Checkpoint retention")
|
||||
auto_checkpoint_destructive: bool = Field(
|
||||
True, description="Auto-checkpoint destructive actions"
|
||||
)
|
||||
|
||||
# Sandbox settings
|
||||
sandbox_enabled: bool = Field(False, description="Enable sandbox execution")
|
||||
sandbox_timeout: int = Field(300, description="Sandbox timeout (s)")
|
||||
sandbox_memory_mb: int = Field(1024, description="Sandbox memory limit (MB)")
|
||||
sandbox_cpu_limit: float = Field(1.0, description="Sandbox CPU limit")
|
||||
sandbox_network_enabled: bool = Field(False, description="Allow sandbox network")
|
||||
|
||||
# Audit settings
|
||||
audit_enabled: bool = Field(True, description="Enable audit logging")
|
||||
audit_retention_days: int = Field(90, description="Audit log retention (days)")
|
||||
audit_include_sensitive: bool = Field(
|
||||
False, description="Include sensitive data in audit"
|
||||
)
|
||||
|
||||
# Content filtering
|
||||
content_filter_enabled: bool = Field(True, description="Enable content filtering")
|
||||
filter_pii: bool = Field(True, description="Filter PII")
|
||||
filter_secrets: bool = Field(True, description="Filter secrets")
|
||||
|
||||
# Emergency controls
|
||||
emergency_stop_enabled: bool = Field(True, description="Enable emergency stop")
|
||||
emergency_webhook_url: str | None = Field(None, description="Emergency webhook")
|
||||
|
||||
# Policy file path
|
||||
policy_file: str | None = Field(None, description="Path to policy YAML file")
|
||||
|
||||
# Validation cache
|
||||
validation_cache_ttl: int = Field(60, description="Validation cache TTL (s)")
|
||||
validation_cache_size: int = Field(1000, description="Validation cache size")
|
||||
|
||||
|
||||
class AutonomyConfig(BaseSettings):
|
||||
"""Configuration for autonomy levels."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="AUTONOMY_",
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# FULL_CONTROL settings
|
||||
full_control_cost_limit: float = Field(1.0, description="USD limit per session")
|
||||
full_control_require_all_approval: bool = Field(
|
||||
True, description="Require approval for all"
|
||||
)
|
||||
full_control_block_destructive: bool = Field(
|
||||
True, description="Block destructive actions"
|
||||
)
|
||||
|
||||
# MILESTONE settings
|
||||
milestone_cost_limit: float = Field(10.0, description="USD limit per session")
|
||||
milestone_require_critical_approval: bool = Field(
|
||||
True, description="Require approval for critical"
|
||||
)
|
||||
milestone_auto_checkpoint: bool = Field(
|
||||
True, description="Auto-checkpoint destructive"
|
||||
)
|
||||
|
||||
# AUTONOMOUS settings
|
||||
autonomous_cost_limit: float = Field(100.0, description="USD limit per session")
|
||||
autonomous_auto_approve_normal: bool = Field(
|
||||
True, description="Auto-approve normal actions"
|
||||
)
|
||||
autonomous_auto_checkpoint: bool = Field(True, description="Auto-checkpoint all")
|
||||
|
||||
|
||||
def _expand_env_vars(value: Any) -> Any:
|
||||
"""Recursively expand environment variables in values."""
|
||||
if isinstance(value, str):
|
||||
return os.path.expandvars(value)
|
||||
elif isinstance(value, dict):
|
||||
return {k: _expand_env_vars(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [_expand_env_vars(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def load_policy_from_file(file_path: str | Path) -> SafetyPolicy | None:
|
||||
"""Load a safety policy from a YAML file."""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
logger.warning("Policy file not found: %s", path)
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if data is None:
|
||||
logger.warning("Empty policy file: %s", path)
|
||||
return None
|
||||
|
||||
# Expand environment variables
|
||||
data = _expand_env_vars(data)
|
||||
|
||||
return SafetyPolicy(**data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to load policy file %s: %s", path, e)
|
||||
return None
|
||||
|
||||
|
||||
def load_policies_from_directory(directory: str | Path) -> dict[str, SafetyPolicy]:
|
||||
"""Load all safety policies from a directory."""
|
||||
policies: dict[str, SafetyPolicy] = {}
|
||||
path = Path(directory)
|
||||
|
||||
if not path.exists() or not path.is_dir():
|
||||
logger.warning("Policy directory not found: %s", path)
|
||||
return policies
|
||||
|
||||
for file_path in path.glob("*.yaml"):
|
||||
policy = load_policy_from_file(file_path)
|
||||
if policy:
|
||||
policies[policy.name] = policy
|
||||
logger.info("Loaded policy: %s from %s", policy.name, file_path.name)
|
||||
|
||||
return policies
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_safety_config() -> SafetyConfig:
|
||||
"""Get the safety configuration (cached singleton)."""
|
||||
return SafetyConfig()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_autonomy_config() -> AutonomyConfig:
|
||||
"""Get the autonomy configuration (cached singleton)."""
|
||||
return AutonomyConfig()
|
||||
|
||||
|
||||
def get_default_policy() -> SafetyPolicy:
|
||||
"""Get the default safety policy."""
|
||||
config = get_safety_config()
|
||||
|
||||
return SafetyPolicy(
|
||||
name="default",
|
||||
description="Default safety policy",
|
||||
max_tokens_per_session=config.default_session_token_budget,
|
||||
max_tokens_per_day=config.default_daily_token_budget,
|
||||
max_cost_per_session_usd=config.default_session_cost_limit,
|
||||
max_cost_per_day_usd=config.default_daily_cost_limit,
|
||||
max_actions_per_minute=config.default_actions_per_minute,
|
||||
max_llm_calls_per_minute=config.default_llm_calls_per_minute,
|
||||
max_file_operations_per_minute=config.default_file_ops_per_minute,
|
||||
max_repeated_actions=config.max_repeated_actions,
|
||||
max_similar_actions=config.max_similar_actions,
|
||||
require_sandbox=config.sandbox_enabled,
|
||||
sandbox_timeout_seconds=config.sandbox_timeout,
|
||||
sandbox_memory_mb=config.sandbox_memory_mb,
|
||||
)
|
||||
|
||||
|
||||
def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
|
||||
"""Get the safety policy for a given autonomy level."""
|
||||
autonomy = get_autonomy_config()
|
||||
|
||||
base_policy = get_default_policy()
|
||||
|
||||
if level == AutonomyLevel.FULL_CONTROL:
|
||||
return SafetyPolicy(
|
||||
name="full_control",
|
||||
description="Full control mode - all actions require approval",
|
||||
max_cost_per_session_usd=autonomy.full_control_cost_limit,
|
||||
max_cost_per_day_usd=autonomy.full_control_cost_limit * 10,
|
||||
require_approval_for=["*"], # All actions
|
||||
max_tokens_per_session=base_policy.max_tokens_per_session // 10,
|
||||
max_tokens_per_day=base_policy.max_tokens_per_day // 10,
|
||||
max_actions_per_minute=base_policy.max_actions_per_minute // 2,
|
||||
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute // 2,
|
||||
max_file_operations_per_minute=base_policy.max_file_operations_per_minute
|
||||
// 2,
|
||||
denied_tools=["delete_*", "destroy_*", "drop_*"],
|
||||
)
|
||||
|
||||
elif level == AutonomyLevel.MILESTONE:
|
||||
return SafetyPolicy(
|
||||
name="milestone",
|
||||
description="Milestone mode - approval at milestones only",
|
||||
max_cost_per_session_usd=autonomy.milestone_cost_limit,
|
||||
max_cost_per_day_usd=autonomy.milestone_cost_limit * 10,
|
||||
require_approval_for=[
|
||||
"delete_file",
|
||||
"push_to_remote",
|
||||
"deploy_*",
|
||||
"modify_critical_*",
|
||||
"create_pull_request",
|
||||
],
|
||||
max_tokens_per_session=base_policy.max_tokens_per_session,
|
||||
max_tokens_per_day=base_policy.max_tokens_per_day,
|
||||
max_actions_per_minute=base_policy.max_actions_per_minute,
|
||||
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute,
|
||||
max_file_operations_per_minute=base_policy.max_file_operations_per_minute,
|
||||
)
|
||||
|
||||
else: # AUTONOMOUS
|
||||
return SafetyPolicy(
|
||||
name="autonomous",
|
||||
description="Autonomous mode - minimal intervention",
|
||||
max_cost_per_session_usd=autonomy.autonomous_cost_limit,
|
||||
max_cost_per_day_usd=autonomy.autonomous_cost_limit * 10,
|
||||
require_approval_for=[
|
||||
"deploy_to_production",
|
||||
"delete_repository",
|
||||
"modify_production_config",
|
||||
],
|
||||
max_tokens_per_session=base_policy.max_tokens_per_session * 5,
|
||||
max_tokens_per_day=base_policy.max_tokens_per_day * 5,
|
||||
max_actions_per_minute=base_policy.max_actions_per_minute * 2,
|
||||
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute * 2,
|
||||
max_file_operations_per_minute=base_policy.max_file_operations_per_minute
|
||||
* 2,
|
||||
)
|
||||
|
||||
|
||||
def reset_config_cache() -> None:
|
||||
"""Reset configuration caches (for testing)."""
|
||||
get_safety_config.cache_clear()
|
||||
get_autonomy_config.cache_clear()
|
||||
23
backend/app/services/safety/content/__init__.py
Normal file
23
backend/app/services/safety/content/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Content filtering for safety."""
|
||||
|
||||
from .filter import (
|
||||
ContentCategory,
|
||||
ContentFilter,
|
||||
FilterAction,
|
||||
FilterMatch,
|
||||
FilterPattern,
|
||||
FilterResult,
|
||||
filter_content,
|
||||
scan_for_secrets,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ContentCategory",
|
||||
"ContentFilter",
|
||||
"FilterAction",
|
||||
"FilterMatch",
|
||||
"FilterPattern",
|
||||
"FilterResult",
|
||||
"filter_content",
|
||||
"scan_for_secrets",
|
||||
]
|
||||
550
backend/app/services/safety/content/filter.py
Normal file
550
backend/app/services/safety/content/filter.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""
|
||||
Content Filter
|
||||
|
||||
Filters and sanitizes content for safety, including PII detection and secret scanning.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field, replace
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..exceptions import ContentFilterError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentCategory(str, Enum):
|
||||
"""Categories of sensitive content."""
|
||||
|
||||
PII = "pii"
|
||||
SECRETS = "secrets"
|
||||
CREDENTIALS = "credentials"
|
||||
FINANCIAL = "financial"
|
||||
HEALTH = "health"
|
||||
PROFANITY = "profanity"
|
||||
INJECTION = "injection"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class FilterAction(str, Enum):
|
||||
"""Actions to take on detected content."""
|
||||
|
||||
ALLOW = "allow"
|
||||
REDACT = "redact"
|
||||
BLOCK = "block"
|
||||
WARN = "warn"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterMatch:
|
||||
"""A match found by a filter."""
|
||||
|
||||
category: ContentCategory
|
||||
pattern_name: str
|
||||
matched_text: str
|
||||
start_pos: int
|
||||
end_pos: int
|
||||
confidence: float = 1.0
|
||||
redacted_text: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterResult:
|
||||
"""Result of content filtering."""
|
||||
|
||||
original_content: str
|
||||
filtered_content: str
|
||||
matches: list[FilterMatch] = field(default_factory=list)
|
||||
blocked: bool = False
|
||||
block_reason: str | None = None
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def has_sensitive_content(self) -> bool:
|
||||
"""Check if any sensitive content was found."""
|
||||
return len(self.matches) > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterPattern:
|
||||
"""A pattern for detecting sensitive content."""
|
||||
|
||||
name: str
|
||||
category: ContentCategory
|
||||
pattern: str # Regex pattern
|
||||
action: FilterAction = FilterAction.REDACT
|
||||
replacement: str = "[REDACTED]"
|
||||
confidence: float = 1.0
|
||||
enabled: bool = True
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Compile the regex pattern."""
|
||||
self._compiled = re.compile(self.pattern, re.IGNORECASE | re.MULTILINE)
|
||||
|
||||
def find_matches(self, content: str) -> list[FilterMatch]:
|
||||
"""Find all matches in content."""
|
||||
matches = []
|
||||
for match in self._compiled.finditer(content):
|
||||
matches.append(
|
||||
FilterMatch(
|
||||
category=self.category,
|
||||
pattern_name=self.name,
|
||||
matched_text=match.group(),
|
||||
start_pos=match.start(),
|
||||
end_pos=match.end(),
|
||||
confidence=self.confidence,
|
||||
redacted_text=self.replacement,
|
||||
)
|
||||
)
|
||||
return matches
|
||||
|
||||
|
||||
class ContentFilter:
|
||||
"""
|
||||
Filters content for sensitive information.
|
||||
|
||||
Features:
|
||||
- PII detection (emails, phones, SSN, etc.)
|
||||
- Secret scanning (API keys, tokens, passwords)
|
||||
- Credential detection
|
||||
- Injection attack prevention
|
||||
- Custom pattern support
|
||||
- Configurable actions (allow, redact, block, warn)
|
||||
"""
|
||||
|
||||
# Default patterns for common sensitive data
|
||||
DEFAULT_PATTERNS: ClassVar[list[FilterPattern]] = [
|
||||
# PII Patterns
|
||||
FilterPattern(
|
||||
name="email",
|
||||
category=ContentCategory.PII,
|
||||
pattern=r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[EMAIL]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="phone_us",
|
||||
category=ContentCategory.PII,
|
||||
pattern=r"\b(?:\+1[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[PHONE]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="ssn",
|
||||
category=ContentCategory.PII,
|
||||
pattern=r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[SSN]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="credit_card",
|
||||
category=ContentCategory.FINANCIAL,
|
||||
pattern=r"\b(?:\d{4}[-\s]?){3}\d{4}\b",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[CREDIT_CARD]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="ip_address",
|
||||
category=ContentCategory.PII,
|
||||
pattern=r"\b(?:\d{1,3}\.){3}\d{1,3}\b",
|
||||
action=FilterAction.WARN,
|
||||
replacement="[IP]",
|
||||
confidence=0.8,
|
||||
),
|
||||
# Secret Patterns
|
||||
FilterPattern(
|
||||
name="api_key_generic",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\b(?:api[_-]?key|apikey)\s*[:=]\s*['\"]?([A-Za-z0-9_-]{20,})['\"]?",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[API_KEY]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="aws_access_key",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\bAKIA[0-9A-Z]{16}\b",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[AWS_KEY]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="aws_secret_key",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\b[A-Za-z0-9/+=]{40}\b",
|
||||
action=FilterAction.WARN,
|
||||
replacement="[AWS_SECRET]",
|
||||
confidence=0.6, # Lower confidence - might be false positive
|
||||
),
|
||||
FilterPattern(
|
||||
name="github_token",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\b(ghp|gho|ghu|ghs|ghr)_[A-Za-z0-9]{36,}\b",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[GITHUB_TOKEN]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="jwt_token",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"\beyJ[A-Za-z0-9_-]*\.eyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*\b",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[JWT]",
|
||||
),
|
||||
# Credential Patterns
|
||||
FilterPattern(
|
||||
name="password_in_url",
|
||||
category=ContentCategory.CREDENTIALS,
|
||||
pattern=r"://[^:]+:([^@]+)@",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="://[REDACTED]@",
|
||||
),
|
||||
FilterPattern(
|
||||
name="password_assignment",
|
||||
category=ContentCategory.CREDENTIALS,
|
||||
pattern=r"\b(?:password|passwd|pwd)\s*[:=]\s*['\"]?([^\s'\"]+)['\"]?",
|
||||
action=FilterAction.REDACT,
|
||||
replacement="[PASSWORD]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="private_key",
|
||||
category=ContentCategory.SECRETS,
|
||||
pattern=r"-----BEGIN (?:RSA |DSA |EC |OPENSSH )?PRIVATE KEY-----",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[PRIVATE_KEY]",
|
||||
),
|
||||
# Injection Patterns
|
||||
FilterPattern(
|
||||
name="sql_injection",
|
||||
category=ContentCategory.INJECTION,
|
||||
pattern=r"(?:'\s*(?:OR|AND)\s*')|(?:--\s*$)|(?:;\s*(?:DROP|DELETE|UPDATE|INSERT))",
|
||||
action=FilterAction.BLOCK,
|
||||
replacement="[BLOCKED]",
|
||||
),
|
||||
FilterPattern(
|
||||
name="command_injection",
|
||||
category=ContentCategory.INJECTION,
|
||||
pattern=r"[;&|`$]|\$\(|\$\{",
|
||||
action=FilterAction.WARN,
|
||||
replacement="[CMD]",
|
||||
confidence=0.5, # Low confidence - common in code
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enable_pii_filter: bool = True,
|
||||
enable_secret_filter: bool = True,
|
||||
enable_injection_filter: bool = True,
|
||||
custom_patterns: list[FilterPattern] | None = None,
|
||||
default_action: FilterAction = FilterAction.REDACT,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the ContentFilter.
|
||||
|
||||
Args:
|
||||
enable_pii_filter: Enable PII detection
|
||||
enable_secret_filter: Enable secret scanning
|
||||
enable_injection_filter: Enable injection detection
|
||||
custom_patterns: Additional custom patterns
|
||||
default_action: Default action for matches
|
||||
"""
|
||||
self._patterns: list[FilterPattern] = []
|
||||
self._default_action = default_action
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Load default patterns based on configuration
|
||||
# Use replace() to create a copy of each pattern to avoid mutating shared defaults
|
||||
for pattern in self.DEFAULT_PATTERNS:
|
||||
if pattern.category == ContentCategory.PII and not enable_pii_filter:
|
||||
continue
|
||||
if pattern.category == ContentCategory.SECRETS and not enable_secret_filter:
|
||||
continue
|
||||
if (
|
||||
pattern.category == ContentCategory.CREDENTIALS
|
||||
and not enable_secret_filter
|
||||
):
|
||||
continue
|
||||
if (
|
||||
pattern.category == ContentCategory.INJECTION
|
||||
and not enable_injection_filter
|
||||
):
|
||||
continue
|
||||
self._patterns.append(replace(pattern))
|
||||
|
||||
# Add custom patterns
|
||||
if custom_patterns:
|
||||
self._patterns.extend(custom_patterns)
|
||||
|
||||
logger.info("ContentFilter initialized with %d patterns", len(self._patterns))
|
||||
|
||||
def add_pattern(self, pattern: FilterPattern) -> None:
|
||||
"""Add a custom pattern."""
|
||||
self._patterns.append(pattern)
|
||||
logger.debug("Added pattern: %s", pattern.name)
|
||||
|
||||
def remove_pattern(self, pattern_name: str) -> bool:
|
||||
"""Remove a pattern by name."""
|
||||
for i, pattern in enumerate(self._patterns):
|
||||
if pattern.name == pattern_name:
|
||||
del self._patterns[i]
|
||||
logger.debug("Removed pattern: %s", pattern_name)
|
||||
return True
|
||||
return False
|
||||
|
||||
def enable_pattern(self, pattern_name: str, enabled: bool = True) -> bool:
|
||||
"""Enable or disable a pattern."""
|
||||
for pattern in self._patterns:
|
||||
if pattern.name == pattern_name:
|
||||
pattern.enabled = enabled
|
||||
return True
|
||||
return False
|
||||
|
||||
async def filter(
|
||||
self,
|
||||
content: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
raise_on_block: bool = False,
|
||||
) -> FilterResult:
|
||||
"""
|
||||
Filter content for sensitive information.
|
||||
|
||||
Args:
|
||||
content: Content to filter
|
||||
context: Optional context for filtering decisions
|
||||
raise_on_block: Raise exception if content is blocked
|
||||
|
||||
Returns:
|
||||
FilterResult with filtered content and match details
|
||||
|
||||
Raises:
|
||||
ContentFilterError: If content is blocked and raise_on_block=True
|
||||
"""
|
||||
all_matches: list[FilterMatch] = []
|
||||
blocked = False
|
||||
block_reason: str | None = None
|
||||
warnings: list[str] = []
|
||||
|
||||
# Find all matches
|
||||
for pattern in self._patterns:
|
||||
if not pattern.enabled:
|
||||
continue
|
||||
|
||||
matches = pattern.find_matches(content)
|
||||
for match in matches:
|
||||
all_matches.append(match)
|
||||
|
||||
if pattern.action == FilterAction.BLOCK:
|
||||
blocked = True
|
||||
block_reason = f"Blocked by pattern: {pattern.name}"
|
||||
elif pattern.action == FilterAction.WARN:
|
||||
warnings.append(
|
||||
f"Warning: {pattern.name} detected at position {match.start_pos}"
|
||||
)
|
||||
|
||||
# Sort matches by position (reverse for replacement)
|
||||
all_matches.sort(key=lambda m: m.start_pos, reverse=True)
|
||||
|
||||
# Apply redactions
|
||||
filtered_content = content
|
||||
for match in all_matches:
|
||||
matched_pattern = self._get_pattern(match.pattern_name)
|
||||
if matched_pattern and matched_pattern.action in (
|
||||
FilterAction.REDACT,
|
||||
FilterAction.BLOCK,
|
||||
):
|
||||
filtered_content = (
|
||||
filtered_content[: match.start_pos]
|
||||
+ (match.redacted_text or "[REDACTED]")
|
||||
+ filtered_content[match.end_pos :]
|
||||
)
|
||||
|
||||
# Re-sort for result
|
||||
all_matches.sort(key=lambda m: m.start_pos)
|
||||
|
||||
result = FilterResult(
|
||||
original_content=content,
|
||||
filtered_content=filtered_content if not blocked else "",
|
||||
matches=all_matches,
|
||||
blocked=blocked,
|
||||
block_reason=block_reason,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
if blocked:
|
||||
logger.warning(
|
||||
"Content blocked: %s (%d matches)",
|
||||
block_reason,
|
||||
len(all_matches),
|
||||
)
|
||||
if raise_on_block:
|
||||
raise ContentFilterError(
|
||||
block_reason or "Content blocked",
|
||||
filter_type=all_matches[0].category.value
|
||||
if all_matches
|
||||
else "unknown",
|
||||
detected_patterns=[m.pattern_name for m in all_matches]
|
||||
if all_matches
|
||||
else [],
|
||||
)
|
||||
elif all_matches:
|
||||
logger.debug(
|
||||
"Content filtered: %d matches, %d warnings",
|
||||
len(all_matches),
|
||||
len(warnings),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def filter_dict(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
keys_to_filter: list[str] | None = None,
|
||||
recursive: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Filter string values in a dictionary.
|
||||
|
||||
Args:
|
||||
data: Dictionary to filter
|
||||
keys_to_filter: Specific keys to filter (None = all)
|
||||
recursive: Filter nested dictionaries
|
||||
|
||||
Returns:
|
||||
Filtered dictionary
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
|
||||
for key, value in data.items():
|
||||
if isinstance(value, str):
|
||||
if keys_to_filter is None or key in keys_to_filter:
|
||||
filter_result = await self.filter(value)
|
||||
result[key] = filter_result.filtered_content
|
||||
else:
|
||||
result[key] = value
|
||||
elif isinstance(value, dict) and recursive:
|
||||
result[key] = await self.filter_dict(value, keys_to_filter, recursive)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [
|
||||
(await self.filter(item)).filtered_content
|
||||
if isinstance(item, str)
|
||||
else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
async def scan(
|
||||
self,
|
||||
content: str,
|
||||
categories: list[ContentCategory] | None = None,
|
||||
) -> list[FilterMatch]:
|
||||
"""
|
||||
Scan content without filtering (detection only).
|
||||
|
||||
Args:
|
||||
content: Content to scan
|
||||
categories: Limit to specific categories
|
||||
|
||||
Returns:
|
||||
List of matches found
|
||||
"""
|
||||
all_matches: list[FilterMatch] = []
|
||||
|
||||
for pattern in self._patterns:
|
||||
if not pattern.enabled:
|
||||
continue
|
||||
if categories and pattern.category not in categories:
|
||||
continue
|
||||
|
||||
matches = pattern.find_matches(content)
|
||||
all_matches.extend(matches)
|
||||
|
||||
all_matches.sort(key=lambda m: m.start_pos)
|
||||
return all_matches
|
||||
|
||||
async def validate_safe(
|
||||
self,
|
||||
content: str,
|
||||
categories: list[ContentCategory] | None = None,
|
||||
allow_warnings: bool = True,
|
||||
) -> tuple[bool, list[str]]:
|
||||
"""
|
||||
Validate that content is safe (no blocked patterns).
|
||||
|
||||
Args:
|
||||
content: Content to validate
|
||||
categories: Limit to specific categories
|
||||
allow_warnings: Allow content with warnings
|
||||
|
||||
Returns:
|
||||
Tuple of (is_safe, list of issues)
|
||||
"""
|
||||
issues: list[str] = []
|
||||
|
||||
for pattern in self._patterns:
|
||||
if not pattern.enabled:
|
||||
continue
|
||||
if categories and pattern.category not in categories:
|
||||
continue
|
||||
|
||||
matches = pattern.find_matches(content)
|
||||
for match in matches:
|
||||
if pattern.action == FilterAction.BLOCK:
|
||||
issues.append(
|
||||
f"Blocked: {pattern.name} at position {match.start_pos}"
|
||||
)
|
||||
elif pattern.action == FilterAction.WARN and not allow_warnings:
|
||||
issues.append(
|
||||
f"Warning: {pattern.name} at position {match.start_pos}"
|
||||
)
|
||||
|
||||
return len(issues) == 0, issues
|
||||
|
||||
def _get_pattern(self, name: str) -> FilterPattern | None:
|
||||
"""Get a pattern by name."""
|
||||
for pattern in self._patterns:
|
||||
if pattern.name == name:
|
||||
return pattern
|
||||
return None
|
||||
|
||||
def get_pattern_stats(self) -> dict[str, Any]:
|
||||
"""Get statistics about configured patterns."""
|
||||
by_category: dict[str, int] = {}
|
||||
by_action: dict[str, int] = {}
|
||||
|
||||
for pattern in self._patterns:
|
||||
cat = pattern.category.value
|
||||
by_category[cat] = by_category.get(cat, 0) + 1
|
||||
|
||||
act = pattern.action.value
|
||||
by_action[act] = by_action.get(act, 0) + 1
|
||||
|
||||
return {
|
||||
"total_patterns": len(self._patterns),
|
||||
"enabled_patterns": sum(1 for p in self._patterns if p.enabled),
|
||||
"by_category": by_category,
|
||||
"by_action": by_action,
|
||||
}
|
||||
|
||||
|
||||
# Convenience function for quick filtering
|
||||
async def filter_content(content: str) -> str:
|
||||
"""Quick filter content with default settings."""
|
||||
filter_instance = ContentFilter()
|
||||
result = await filter_instance.filter(content)
|
||||
return result.filtered_content
|
||||
|
||||
|
||||
async def scan_for_secrets(content: str) -> list[FilterMatch]:
|
||||
"""Quick scan for secrets only."""
|
||||
filter_instance = ContentFilter(
|
||||
enable_pii_filter=False,
|
||||
enable_injection_filter=False,
|
||||
)
|
||||
return await filter_instance.scan(
|
||||
content,
|
||||
categories=[ContentCategory.SECRETS, ContentCategory.CREDENTIALS],
|
||||
)
|
||||
15
backend/app/services/safety/costs/__init__.py
Normal file
15
backend/app/services/safety/costs/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Cost Control Module
|
||||
|
||||
Budget management and cost tracking.
|
||||
"""
|
||||
|
||||
from .controller import (
|
||||
BudgetTracker,
|
||||
CostController,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BudgetTracker",
|
||||
"CostController",
|
||||
]
|
||||
498
backend/app/services/safety/costs/controller.py
Normal file
498
backend/app/services/safety/costs/controller.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
Cost Controller
|
||||
|
||||
Budget management and cost tracking for agent operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from ..config import get_safety_config
|
||||
from ..exceptions import BudgetExceededError
|
||||
from ..models import (
|
||||
ActionRequest,
|
||||
BudgetScope,
|
||||
BudgetStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BudgetTracker:
|
||||
"""Tracks usage against a budget limit."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scope: BudgetScope,
|
||||
scope_id: str,
|
||||
tokens_limit: int,
|
||||
cost_limit_usd: float,
|
||||
reset_interval: timedelta | None = None,
|
||||
warning_threshold: float = 0.8,
|
||||
) -> None:
|
||||
self.scope = scope
|
||||
self.scope_id = scope_id
|
||||
self.tokens_limit = tokens_limit
|
||||
self.cost_limit_usd = cost_limit_usd
|
||||
self.warning_threshold = warning_threshold
|
||||
self._reset_interval = reset_interval
|
||||
|
||||
self._tokens_used = 0
|
||||
self._cost_used_usd = 0.0
|
||||
self._created_at = datetime.utcnow()
|
||||
self._last_reset = datetime.utcnow()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def add_usage(self, tokens: int, cost_usd: float) -> None:
|
||||
"""Add usage to the tracker."""
|
||||
async with self._lock:
|
||||
self._check_reset()
|
||||
self._tokens_used += tokens
|
||||
self._cost_used_usd += cost_usd
|
||||
|
||||
async def get_status(self) -> BudgetStatus:
|
||||
"""Get current budget status."""
|
||||
async with self._lock:
|
||||
self._check_reset()
|
||||
|
||||
tokens_remaining = max(0, self.tokens_limit - self._tokens_used)
|
||||
cost_remaining = max(0, self.cost_limit_usd - self._cost_used_usd)
|
||||
|
||||
token_usage_ratio = (
|
||||
self._tokens_used / self.tokens_limit if self.tokens_limit > 0 else 0
|
||||
)
|
||||
cost_usage_ratio = (
|
||||
self._cost_used_usd / self.cost_limit_usd
|
||||
if self.cost_limit_usd > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
is_warning = (
|
||||
max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold
|
||||
)
|
||||
is_exceeded = (
|
||||
self._tokens_used >= self.tokens_limit
|
||||
or self._cost_used_usd >= self.cost_limit_usd
|
||||
)
|
||||
|
||||
reset_at = None
|
||||
if self._reset_interval:
|
||||
reset_at = self._last_reset + self._reset_interval
|
||||
|
||||
return BudgetStatus(
|
||||
scope=self.scope,
|
||||
scope_id=self.scope_id,
|
||||
tokens_used=self._tokens_used,
|
||||
tokens_limit=self.tokens_limit,
|
||||
cost_used_usd=self._cost_used_usd,
|
||||
cost_limit_usd=self.cost_limit_usd,
|
||||
tokens_remaining=tokens_remaining,
|
||||
cost_remaining_usd=cost_remaining,
|
||||
warning_threshold=self.warning_threshold,
|
||||
is_warning=is_warning,
|
||||
is_exceeded=is_exceeded,
|
||||
reset_at=reset_at,
|
||||
)
|
||||
|
||||
async def check_budget(
|
||||
self, estimated_tokens: int, estimated_cost_usd: float
|
||||
) -> bool:
|
||||
"""Check if there's enough budget for an operation."""
|
||||
async with self._lock:
|
||||
self._check_reset()
|
||||
|
||||
would_exceed_tokens = (
|
||||
self._tokens_used + estimated_tokens
|
||||
) > self.tokens_limit
|
||||
would_exceed_cost = (
|
||||
self._cost_used_usd + estimated_cost_usd
|
||||
) > self.cost_limit_usd
|
||||
|
||||
return not (would_exceed_tokens or would_exceed_cost)
|
||||
|
||||
def _check_reset(self) -> None:
|
||||
"""Check if budget should reset."""
|
||||
if self._reset_interval is None:
|
||||
return
|
||||
|
||||
now = datetime.utcnow()
|
||||
if now >= self._last_reset + self._reset_interval:
|
||||
logger.info(
|
||||
"Resetting budget for %s:%s",
|
||||
self.scope.value,
|
||||
self.scope_id,
|
||||
)
|
||||
self._tokens_used = 0
|
||||
self._cost_used_usd = 0.0
|
||||
self._last_reset = now
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Manually reset the budget."""
|
||||
async with self._lock:
|
||||
self._tokens_used = 0
|
||||
self._cost_used_usd = 0.0
|
||||
self._last_reset = datetime.utcnow()
|
||||
|
||||
|
||||
class CostController:
|
||||
"""
|
||||
Controls costs and budgets for agent operations.
|
||||
|
||||
Features:
|
||||
- Per-agent, per-project, per-session budgets
|
||||
- Real-time cost tracking
|
||||
- Budget alerts at configurable thresholds
|
||||
- Cost prediction for planned actions
|
||||
- Budget rollover policies
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_session_tokens: int | None = None,
|
||||
default_session_cost_usd: float | None = None,
|
||||
default_daily_tokens: int | None = None,
|
||||
default_daily_cost_usd: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the CostController.
|
||||
|
||||
Args:
|
||||
default_session_tokens: Default token budget per session
|
||||
default_session_cost_usd: Default USD budget per session
|
||||
default_daily_tokens: Default token budget per day
|
||||
default_daily_cost_usd: Default USD budget per day
|
||||
"""
|
||||
config = get_safety_config()
|
||||
|
||||
self._default_session_tokens = (
|
||||
default_session_tokens or config.default_session_token_budget
|
||||
)
|
||||
self._default_session_cost = (
|
||||
default_session_cost_usd or config.default_session_cost_limit
|
||||
)
|
||||
self._default_daily_tokens = (
|
||||
default_daily_tokens or config.default_daily_token_budget
|
||||
)
|
||||
self._default_daily_cost = (
|
||||
default_daily_cost_usd or config.default_daily_cost_limit
|
||||
)
|
||||
|
||||
self._trackers: dict[str, BudgetTracker] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Alert handlers
|
||||
self._alert_handlers: list[Any] = []
|
||||
|
||||
# Track which budgets have had warning alerts sent (to avoid spam)
|
||||
self._warned_budgets: set[str] = set()
|
||||
|
||||
async def get_or_create_tracker(
|
||||
self,
|
||||
scope: BudgetScope,
|
||||
scope_id: str,
|
||||
) -> BudgetTracker:
|
||||
"""Get or create a budget tracker."""
|
||||
key = f"{scope.value}:{scope_id}"
|
||||
|
||||
async with self._lock:
|
||||
if key not in self._trackers:
|
||||
if scope == BudgetScope.SESSION:
|
||||
tracker = BudgetTracker(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
tokens_limit=self._default_session_tokens,
|
||||
cost_limit_usd=self._default_session_cost,
|
||||
)
|
||||
elif scope == BudgetScope.DAILY:
|
||||
tracker = BudgetTracker(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
tokens_limit=self._default_daily_tokens,
|
||||
cost_limit_usd=self._default_daily_cost,
|
||||
reset_interval=timedelta(days=1),
|
||||
)
|
||||
else:
|
||||
# Default
|
||||
tracker = BudgetTracker(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
tokens_limit=self._default_session_tokens,
|
||||
cost_limit_usd=self._default_session_cost,
|
||||
)
|
||||
|
||||
self._trackers[key] = tracker
|
||||
|
||||
return self._trackers[key]
|
||||
|
||||
async def check_budget(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str | None,
|
||||
estimated_tokens: int,
|
||||
estimated_cost_usd: float,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if there's enough budget for an operation.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
session_id: Optional session ID
|
||||
estimated_tokens: Estimated token usage
|
||||
estimated_cost_usd: Estimated USD cost
|
||||
|
||||
Returns:
|
||||
True if budget is available
|
||||
"""
|
||||
# Check session budget
|
||||
if session_id:
|
||||
session_tracker = await self.get_or_create_tracker(
|
||||
BudgetScope.SESSION, session_id
|
||||
)
|
||||
if not await session_tracker.check_budget(
|
||||
estimated_tokens, estimated_cost_usd
|
||||
):
|
||||
return False
|
||||
|
||||
# Check agent daily budget
|
||||
agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
|
||||
if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def check_action(self, action: ActionRequest) -> bool:
|
||||
"""
|
||||
Check if an action is within budget.
|
||||
|
||||
Args:
|
||||
action: The action to check
|
||||
|
||||
Returns:
|
||||
True if within budget
|
||||
"""
|
||||
return await self.check_budget(
|
||||
agent_id=action.metadata.agent_id,
|
||||
session_id=action.metadata.session_id,
|
||||
estimated_tokens=action.estimated_cost_tokens,
|
||||
estimated_cost_usd=action.estimated_cost_usd,
|
||||
)
|
||||
|
||||
async def require_budget(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str | None,
|
||||
estimated_tokens: int,
|
||||
estimated_cost_usd: float,
|
||||
) -> None:
|
||||
"""
|
||||
Require budget or raise exception.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
session_id: Optional session ID
|
||||
estimated_tokens: Estimated token usage
|
||||
estimated_cost_usd: Estimated USD cost
|
||||
|
||||
Raises:
|
||||
BudgetExceededError: If budget is exceeded
|
||||
"""
|
||||
if not await self.check_budget(
|
||||
agent_id, session_id, estimated_tokens, estimated_cost_usd
|
||||
):
|
||||
# Determine which budget was exceeded
|
||||
if session_id:
|
||||
session_tracker = await self.get_or_create_tracker(
|
||||
BudgetScope.SESSION, session_id
|
||||
)
|
||||
session_status = await session_tracker.get_status()
|
||||
if session_status.is_exceeded:
|
||||
raise BudgetExceededError(
|
||||
"Session budget exceeded",
|
||||
budget_type="session",
|
||||
current_usage=session_status.tokens_used,
|
||||
budget_limit=session_status.tokens_limit,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
agent_tracker = await self.get_or_create_tracker(
|
||||
BudgetScope.DAILY, agent_id
|
||||
)
|
||||
agent_status = await agent_tracker.get_status()
|
||||
raise BudgetExceededError(
|
||||
"Daily budget exceeded",
|
||||
budget_type="daily",
|
||||
current_usage=agent_status.tokens_used,
|
||||
budget_limit=agent_status.tokens_limit,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def record_usage(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str | None,
|
||||
tokens: int,
|
||||
cost_usd: float,
|
||||
) -> None:
|
||||
"""
|
||||
Record actual usage.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
session_id: Optional session ID
|
||||
tokens: Actual token usage
|
||||
cost_usd: Actual USD cost
|
||||
"""
|
||||
# Update session budget
|
||||
if session_id:
|
||||
session_key = f"session:{session_id}"
|
||||
session_tracker = await self.get_or_create_tracker(
|
||||
BudgetScope.SESSION, session_id
|
||||
)
|
||||
await session_tracker.add_usage(tokens, cost_usd)
|
||||
|
||||
# Check for warning (only alert once per budget to avoid spam)
|
||||
status = await session_tracker.get_status()
|
||||
if status.is_warning and not status.is_exceeded:
|
||||
if session_key not in self._warned_budgets:
|
||||
self._warned_budgets.add(session_key)
|
||||
await self._send_alert(
|
||||
"warning",
|
||||
f"Session {session_id} at {status.tokens_used}/{status.tokens_limit} tokens",
|
||||
status,
|
||||
)
|
||||
elif not status.is_warning:
|
||||
# Clear warning flag if usage dropped below threshold (e.g., after reset)
|
||||
self._warned_budgets.discard(session_key)
|
||||
|
||||
# Update agent daily budget
|
||||
daily_key = f"daily:{agent_id}"
|
||||
agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
|
||||
await agent_tracker.add_usage(tokens, cost_usd)
|
||||
|
||||
# Check for warning (only alert once per budget to avoid spam)
|
||||
status = await agent_tracker.get_status()
|
||||
if status.is_warning and not status.is_exceeded:
|
||||
if daily_key not in self._warned_budgets:
|
||||
self._warned_budgets.add(daily_key)
|
||||
await self._send_alert(
|
||||
"warning",
|
||||
f"Agent {agent_id} at {status.tokens_used}/{status.tokens_limit} daily tokens",
|
||||
status,
|
||||
)
|
||||
elif not status.is_warning:
|
||||
# Clear warning flag if usage dropped below threshold (e.g., after reset)
|
||||
self._warned_budgets.discard(daily_key)
|
||||
|
||||
async def get_status(
|
||||
self,
|
||||
scope: BudgetScope,
|
||||
scope_id: str,
|
||||
) -> BudgetStatus | None:
|
||||
"""
|
||||
Get budget status.
|
||||
|
||||
Args:
|
||||
scope: Budget scope
|
||||
scope_id: ID within scope
|
||||
|
||||
Returns:
|
||||
Budget status or None if not tracked
|
||||
"""
|
||||
key = f"{scope.value}:{scope_id}"
|
||||
async with self._lock:
|
||||
tracker = self._trackers.get(key)
|
||||
# Get status while holding lock to prevent TOCTOU race
|
||||
if tracker:
|
||||
return await tracker.get_status()
|
||||
return None
|
||||
|
||||
async def get_all_statuses(self) -> list[BudgetStatus]:
|
||||
"""Get status of all tracked budgets."""
|
||||
statuses = []
|
||||
async with self._lock:
|
||||
# Get all statuses while holding lock to prevent TOCTOU race
|
||||
for tracker in self._trackers.values():
|
||||
statuses.append(await tracker.get_status())
|
||||
return statuses
|
||||
|
||||
async def set_budget(
|
||||
self,
|
||||
scope: BudgetScope,
|
||||
scope_id: str,
|
||||
tokens_limit: int,
|
||||
cost_limit_usd: float,
|
||||
) -> None:
|
||||
"""
|
||||
Set a custom budget limit.
|
||||
|
||||
Args:
|
||||
scope: Budget scope
|
||||
scope_id: ID within scope
|
||||
tokens_limit: Token limit
|
||||
cost_limit_usd: USD limit
|
||||
"""
|
||||
key = f"{scope.value}:{scope_id}"
|
||||
|
||||
reset_interval = None
|
||||
if scope == BudgetScope.DAILY:
|
||||
reset_interval = timedelta(days=1)
|
||||
elif scope == BudgetScope.WEEKLY:
|
||||
reset_interval = timedelta(weeks=1)
|
||||
elif scope == BudgetScope.MONTHLY:
|
||||
reset_interval = timedelta(days=30)
|
||||
|
||||
async with self._lock:
|
||||
self._trackers[key] = BudgetTracker(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
tokens_limit=tokens_limit,
|
||||
cost_limit_usd=cost_limit_usd,
|
||||
reset_interval=reset_interval,
|
||||
)
|
||||
|
||||
async def reset_budget(self, scope: BudgetScope, scope_id: str) -> bool:
|
||||
"""
|
||||
Reset a budget tracker.
|
||||
|
||||
Args:
|
||||
scope: Budget scope
|
||||
scope_id: ID within scope
|
||||
|
||||
Returns:
|
||||
True if tracker was found and reset
|
||||
"""
|
||||
key = f"{scope.value}:{scope_id}"
|
||||
async with self._lock:
|
||||
tracker = self._trackers.get(key)
|
||||
# Reset while holding lock to prevent TOCTOU race
|
||||
if tracker:
|
||||
await tracker.reset()
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_alert_handler(self, handler: Any) -> None:
|
||||
"""Add an alert handler."""
|
||||
self._alert_handlers.append(handler)
|
||||
|
||||
def remove_alert_handler(self, handler: Any) -> None:
|
||||
"""Remove an alert handler."""
|
||||
if handler in self._alert_handlers:
|
||||
self._alert_handlers.remove(handler)
|
||||
|
||||
async def _send_alert(
|
||||
self,
|
||||
alert_type: str,
|
||||
message: str,
|
||||
status: BudgetStatus,
|
||||
) -> None:
|
||||
"""Send alert to all handlers."""
|
||||
for handler in self._alert_handlers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler(alert_type, message, status)
|
||||
else:
|
||||
handler(alert_type, message, status)
|
||||
except Exception as e:
|
||||
logger.error("Error in alert handler: %s", e)
|
||||
23
backend/app/services/safety/emergency/__init__.py
Normal file
23
backend/app/services/safety/emergency/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Emergency controls for agent safety."""
|
||||
|
||||
from .controls import (
|
||||
EmergencyControls,
|
||||
EmergencyEvent,
|
||||
EmergencyReason,
|
||||
EmergencyState,
|
||||
EmergencyTrigger,
|
||||
check_emergency_allowed,
|
||||
emergency_stop_global,
|
||||
get_emergency_controls,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EmergencyControls",
|
||||
"EmergencyEvent",
|
||||
"EmergencyReason",
|
||||
"EmergencyState",
|
||||
"EmergencyTrigger",
|
||||
"check_emergency_allowed",
|
||||
"emergency_stop_global",
|
||||
"get_emergency_controls",
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user