Compare commits
91 Commits
main
...
e5975fa5d0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e5975fa5d0 | ||
|
|
731a188a76 | ||
|
|
fe2104822e | ||
|
|
664415111a | ||
|
|
acd18ff694 | ||
|
|
da5affd613 | ||
|
|
a79d923dc1 | ||
|
|
c72f6aa2f9 | ||
|
|
4f24cebf11 | ||
|
|
e0739a786c | ||
|
|
64576da7dc | ||
|
|
4a55bd63a3 | ||
|
|
a78b903f5a | ||
|
|
c7b2c82700 | ||
|
|
50b865b23b | ||
|
|
6f5dd58b54 | ||
|
|
0ceee8545e | ||
|
|
62aea06e0d | ||
|
|
24f1cc637e | ||
|
|
8b6cca5d4d | ||
|
|
c9700f760e | ||
|
|
6f509e71ce | ||
|
|
f5a86953c6 | ||
|
|
246d2a6752 | ||
|
|
36ab7069cf | ||
|
|
a4c91cb8c3 | ||
|
|
a7ba0f9bd8 | ||
|
|
f3fb4ecbeb | ||
|
|
5c35702caf | ||
|
|
7280b182bd | ||
|
|
06b2491c1f | ||
|
|
b8265783f3 | ||
|
|
63066c50ba | ||
|
|
ddf9b5fe25 | ||
|
|
c3b66cccfc | ||
|
|
896f0d92e5 | ||
|
|
2ccaeb23f2 | ||
|
|
04c939d4c2 | ||
|
|
71c94c3b5a | ||
|
|
d71891ac4e | ||
|
|
3492941aec | ||
|
|
81e8d7e73d | ||
|
|
f0b04d53af | ||
|
|
35af7daf90 | ||
|
|
5fab15a11e | ||
|
|
ab913575e1 | ||
|
|
82cb6386a6 | ||
|
|
2d05035c1d | ||
|
|
15d747eb28 | ||
|
|
3d6fa6b791 | ||
|
|
3ea1874638 | ||
|
|
e1657d5ad8 | ||
|
|
83fa51fd4a | ||
|
|
db868c53c6 | ||
|
|
68f1865a1e | ||
|
|
5b1e2852ea | ||
|
|
d0a88d1fd1 | ||
|
|
e85788f79f | ||
|
|
25d42ee2a6 | ||
|
|
e41ceafaef | ||
|
|
43fa69db7d | ||
|
|
29309e5cfd | ||
|
|
cea97afe25 | ||
|
|
b43fa8ace2 | ||
|
|
742ce4c9c8 | ||
|
|
6ea9edf3d1 | ||
|
|
25b8f1723e | ||
|
|
73d10f364c | ||
|
|
2310c8cdfd | ||
|
|
2f7124959d | ||
|
|
2104ae38ec | ||
|
|
2055320058 | ||
|
|
11da0d57a8 | ||
|
|
acfda1e9a9 | ||
|
|
3c24a8c522 | ||
|
|
ec111f9ce6 | ||
|
|
520a4d60fb | ||
|
|
6e645835dc | ||
|
|
fcda8f0f96 | ||
|
|
d6db6af964 | ||
|
|
88cf4e0abc | ||
|
|
f138417486 | ||
|
|
de47d9ee43 | ||
|
|
406b25cda0 | ||
|
|
bd702734c2 | ||
|
|
5594655fba | ||
|
|
ebd307cab4 | ||
|
|
6e3cdebbfb | ||
|
|
a6a336b66e | ||
|
|
9901dc7f51 | ||
|
|
ac64d9505e |
@@ -1,15 +1,22 @@
|
|||||||
# Common settings
|
# Common settings
|
||||||
PROJECT_NAME=App
|
PROJECT_NAME=Syndarix
|
||||||
VERSION=1.0.0
|
VERSION=1.0.0
|
||||||
|
|
||||||
# Database settings
|
# Database settings
|
||||||
POSTGRES_USER=postgres
|
POSTGRES_USER=postgres
|
||||||
POSTGRES_PASSWORD=postgres
|
POSTGRES_PASSWORD=postgres
|
||||||
POSTGRES_DB=app
|
POSTGRES_DB=syndarix
|
||||||
POSTGRES_HOST=db
|
POSTGRES_HOST=db
|
||||||
POSTGRES_PORT=5432
|
POSTGRES_PORT=5432
|
||||||
DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
|
DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
|
||||||
|
|
||||||
|
# Redis settings (cache, pub/sub, Celery broker)
|
||||||
|
REDIS_URL=redis://redis:6379/0
|
||||||
|
|
||||||
|
# Celery settings (optional - defaults to REDIS_URL if not set)
|
||||||
|
# CELERY_BROKER_URL=redis://redis:6379/0
|
||||||
|
# CELERY_RESULT_BACKEND=redis://redis:6379/0
|
||||||
|
|
||||||
# Backend settings
|
# Backend settings
|
||||||
BACKEND_PORT=8000
|
BACKEND_PORT=8000
|
||||||
# CRITICAL: Generate a secure SECRET_KEY for production!
|
# CRITICAL: Generate a secure SECRET_KEY for production!
|
||||||
|
|||||||
460
.gitea/workflows/ci.yaml
Normal file
460
.gitea/workflows/ci.yaml
Normal file
@@ -0,0 +1,460 @@
|
|||||||
|
# Syndarix CI/CD Pipeline
|
||||||
|
# Gitea Actions workflow for continuous integration and deployment
|
||||||
|
#
|
||||||
|
# Pipeline Structure:
|
||||||
|
# - lint: Fast feedback (linting and type checking)
|
||||||
|
# - test: Run test suites (depends on lint)
|
||||||
|
# - build: Build Docker images (depends on test)
|
||||||
|
# - deploy: Deploy to production (depends on build, only on main)
|
||||||
|
|
||||||
|
name: CI/CD Pipeline
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
- dev
|
||||||
|
- 'feature/**'
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
- dev
|
||||||
|
|
||||||
|
env:
|
||||||
|
PYTHON_VERSION: "3.12"
|
||||||
|
NODE_VERSION: "20"
|
||||||
|
UV_VERSION: "0.4.x"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
# ===========================================================================
|
||||||
|
# LINT JOB - Fast feedback first
|
||||||
|
# ===========================================================================
|
||||||
|
lint:
|
||||||
|
name: Lint & Type Check
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
component: [backend, frontend]
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
# ----- Backend Linting -----
|
||||||
|
- name: Set up Python
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
with:
|
||||||
|
version: ${{ env.UV_VERSION }}
|
||||||
|
|
||||||
|
- name: Cache uv dependencies
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cache/uv
|
||||||
|
backend/.venv
|
||||||
|
key: uv-${{ runner.os }}-${{ hashFiles('backend/uv.lock') }}
|
||||||
|
restore-keys: |
|
||||||
|
uv-${{ runner.os }}-
|
||||||
|
|
||||||
|
- name: Install backend dependencies
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
working-directory: backend
|
||||||
|
run: uv sync --extra dev --frozen
|
||||||
|
|
||||||
|
- name: Run ruff linting
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
working-directory: backend
|
||||||
|
run: uv run ruff check app
|
||||||
|
|
||||||
|
- name: Run ruff format check
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
working-directory: backend
|
||||||
|
run: uv run ruff format --check app
|
||||||
|
|
||||||
|
- name: Run mypy type checking
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
working-directory: backend
|
||||||
|
run: uv run mypy app
|
||||||
|
|
||||||
|
# ----- Frontend Linting -----
|
||||||
|
- name: Set up Node.js
|
||||||
|
if: matrix.component == 'frontend'
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: ${{ env.NODE_VERSION }}
|
||||||
|
|
||||||
|
- name: Cache npm dependencies
|
||||||
|
if: matrix.component == 'frontend'
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.npm
|
||||||
|
frontend/node_modules
|
||||||
|
key: npm-${{ runner.os }}-${{ hashFiles('frontend/package-lock.json') }}
|
||||||
|
restore-keys: |
|
||||||
|
npm-${{ runner.os }}-
|
||||||
|
|
||||||
|
- name: Install frontend dependencies
|
||||||
|
if: matrix.component == 'frontend'
|
||||||
|
working-directory: frontend
|
||||||
|
run: npm ci
|
||||||
|
|
||||||
|
- name: Run ESLint
|
||||||
|
if: matrix.component == 'frontend'
|
||||||
|
working-directory: frontend
|
||||||
|
run: npm run lint
|
||||||
|
|
||||||
|
- name: Run TypeScript type check
|
||||||
|
if: matrix.component == 'frontend'
|
||||||
|
working-directory: frontend
|
||||||
|
run: npm run type-check
|
||||||
|
|
||||||
|
- name: Run Prettier format check
|
||||||
|
if: matrix.component == 'frontend'
|
||||||
|
working-directory: frontend
|
||||||
|
run: npm run format:check
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# TEST JOB - Run test suites
|
||||||
|
# ===========================================================================
|
||||||
|
test:
|
||||||
|
name: Test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: lint
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
component: [backend, frontend]
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
# ----- Backend Tests -----
|
||||||
|
- name: Set up Python
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
with:
|
||||||
|
version: ${{ env.UV_VERSION }}
|
||||||
|
|
||||||
|
- name: Cache uv dependencies
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cache/uv
|
||||||
|
backend/.venv
|
||||||
|
key: uv-${{ runner.os }}-${{ hashFiles('backend/uv.lock') }}
|
||||||
|
restore-keys: |
|
||||||
|
uv-${{ runner.os }}-
|
||||||
|
|
||||||
|
- name: Install backend dependencies
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
working-directory: backend
|
||||||
|
run: uv sync --extra dev --frozen
|
||||||
|
|
||||||
|
- name: Run pytest with coverage
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
working-directory: backend
|
||||||
|
env:
|
||||||
|
IS_TEST: "True"
|
||||||
|
run: |
|
||||||
|
uv run pytest --cov=app --cov-report=xml --cov-report=term-missing --cov-fail-under=90
|
||||||
|
|
||||||
|
- name: Upload backend coverage report
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: backend-coverage
|
||||||
|
path: backend/coverage.xml
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
# ----- Frontend Tests -----
|
||||||
|
- name: Set up Node.js
|
||||||
|
if: matrix.component == 'frontend'
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: ${{ env.NODE_VERSION }}
|
||||||
|
|
||||||
|
- name: Cache npm dependencies
|
||||||
|
if: matrix.component == 'frontend'
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.npm
|
||||||
|
frontend/node_modules
|
||||||
|
key: npm-${{ runner.os }}-${{ hashFiles('frontend/package-lock.json') }}
|
||||||
|
restore-keys: |
|
||||||
|
npm-${{ runner.os }}-
|
||||||
|
|
||||||
|
- name: Install frontend dependencies
|
||||||
|
if: matrix.component == 'frontend'
|
||||||
|
working-directory: frontend
|
||||||
|
run: npm ci
|
||||||
|
|
||||||
|
- name: Run Jest unit tests
|
||||||
|
if: matrix.component == 'frontend'
|
||||||
|
working-directory: frontend
|
||||||
|
run: npm test -- --coverage --passWithNoTests
|
||||||
|
|
||||||
|
- name: Upload frontend coverage report
|
||||||
|
if: matrix.component == 'frontend'
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: frontend-coverage
|
||||||
|
path: frontend/coverage/
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# BUILD JOB - Build Docker images
|
||||||
|
# ===========================================================================
|
||||||
|
build:
|
||||||
|
name: Build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: test
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
component: [backend, frontend]
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Cache Docker layers
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: /tmp/.buildx-cache
|
||||||
|
key: docker-${{ matrix.component }}-${{ github.sha }}
|
||||||
|
restore-keys: |
|
||||||
|
docker-${{ matrix.component }}-
|
||||||
|
|
||||||
|
- name: Build backend Docker image
|
||||||
|
if: matrix.component == 'backend'
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: ./backend
|
||||||
|
file: ./backend/Dockerfile
|
||||||
|
target: production
|
||||||
|
push: false
|
||||||
|
tags: syndarix-backend:${{ github.sha }}
|
||||||
|
cache-from: type=local,src=/tmp/.buildx-cache
|
||||||
|
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||||
|
|
||||||
|
- name: Build frontend Docker image
|
||||||
|
if: matrix.component == 'frontend'
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: ./frontend
|
||||||
|
file: ./frontend/Dockerfile
|
||||||
|
target: runner
|
||||||
|
push: false
|
||||||
|
tags: syndarix-frontend:${{ github.sha }}
|
||||||
|
build-args: |
|
||||||
|
NEXT_PUBLIC_API_URL=http://localhost:8000
|
||||||
|
cache-from: type=local,src=/tmp/.buildx-cache
|
||||||
|
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
|
||||||
|
|
||||||
|
# Prevent cache from growing indefinitely
|
||||||
|
- name: Move cache
|
||||||
|
run: |
|
||||||
|
rm -rf /tmp/.buildx-cache
|
||||||
|
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# DEPLOY JOB - Deploy to production (only on main branch)
|
||||||
|
# ===========================================================================
|
||||||
|
deploy:
|
||||||
|
name: Deploy
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: build
|
||||||
|
if: github.ref == 'refs/heads/main' && github.event_name == 'push'
|
||||||
|
environment: production
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Deploy notification
|
||||||
|
run: |
|
||||||
|
echo "Deployment to production would happen here"
|
||||||
|
echo "Branch: ${{ github.ref }}"
|
||||||
|
echo "Commit: ${{ github.sha }}"
|
||||||
|
echo "Actor: ${{ github.actor }}"
|
||||||
|
|
||||||
|
# TODO: Add actual deployment steps when infrastructure is ready
|
||||||
|
# Options:
|
||||||
|
# - SSH to production server and run docker-compose pull && docker-compose up -d
|
||||||
|
# - Use Kubernetes deployment
|
||||||
|
# - Use cloud provider deployment (AWS ECS, GCP Cloud Run, etc.)
|
||||||
|
# - Trigger webhook to deployment orchestrator
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# SECURITY SCAN JOB - Run on main and dev branches
|
||||||
|
# ===========================================================================
|
||||||
|
security:
|
||||||
|
name: Security Scan
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: lint
|
||||||
|
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev'
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
with:
|
||||||
|
version: ${{ env.UV_VERSION }}
|
||||||
|
|
||||||
|
- name: Install backend dependencies
|
||||||
|
working-directory: backend
|
||||||
|
run: uv sync --extra dev --frozen
|
||||||
|
|
||||||
|
- name: Run Bandit security scan (via ruff)
|
||||||
|
working-directory: backend
|
||||||
|
run: |
|
||||||
|
# Ruff includes flake8-bandit (S rules) for security scanning
|
||||||
|
# Run with explicit security rules only
|
||||||
|
uv run ruff check app --select=S --ignore=S101,S104,S105,S106,S603,S607
|
||||||
|
|
||||||
|
- name: Run pip-audit for dependency vulnerabilities
|
||||||
|
working-directory: backend
|
||||||
|
run: |
|
||||||
|
# pip-audit checks for known vulnerabilities in Python dependencies
|
||||||
|
uv run pip-audit --require-hashes --disable-pip -r <(uv pip compile pyproject.toml) || true
|
||||||
|
# Note: Using || true temporarily while setting up proper remediation
|
||||||
|
|
||||||
|
- name: Check for secrets in code
|
||||||
|
run: |
|
||||||
|
# Basic check for common secret patterns
|
||||||
|
# In production, use tools like gitleaks or trufflehog
|
||||||
|
echo "Checking for potential hardcoded secrets..."
|
||||||
|
! grep -rn --include="*.py" --include="*.ts" --include="*.tsx" --include="*.js" \
|
||||||
|
-E "(api_key|apikey|secret_key|secretkey|password|passwd|token)\s*=\s*['\"][^'\"]{8,}['\"]" \
|
||||||
|
backend/app frontend/src || echo "No obvious secrets found"
|
||||||
|
|
||||||
|
- name: Set up Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: ${{ env.NODE_VERSION }}
|
||||||
|
|
||||||
|
- name: Install frontend dependencies
|
||||||
|
working-directory: frontend
|
||||||
|
run: npm ci
|
||||||
|
|
||||||
|
- name: Run npm audit
|
||||||
|
working-directory: frontend
|
||||||
|
run: |
|
||||||
|
npm audit --audit-level=high || true
|
||||||
|
# Note: Using || true to not fail on moderate vulnerabilities
|
||||||
|
# In production, consider stricter settings
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# E2E TEST JOB - Run end-to-end tests with Playwright
|
||||||
|
# ===========================================================================
|
||||||
|
e2e-tests:
|
||||||
|
name: E2E Tests
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [lint, test]
|
||||||
|
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev' || github.event_name == 'pull_request'
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: pgvector/pgvector:pg17
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
POSTGRES_DB: syndarix_test
|
||||||
|
ports:
|
||||||
|
- 5432:5432
|
||||||
|
options: >-
|
||||||
|
--health-cmd "pg_isready -U postgres"
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
redis:
|
||||||
|
image: redis:7-alpine
|
||||||
|
ports:
|
||||||
|
- 6379:6379
|
||||||
|
options: >-
|
||||||
|
--health-cmd "redis-cli ping"
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
with:
|
||||||
|
version: ${{ env.UV_VERSION }}
|
||||||
|
|
||||||
|
- name: Install backend dependencies
|
||||||
|
working-directory: backend
|
||||||
|
run: uv sync --extra dev --frozen
|
||||||
|
|
||||||
|
- name: Set up Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: ${{ env.NODE_VERSION }}
|
||||||
|
|
||||||
|
- name: Install frontend dependencies
|
||||||
|
working-directory: frontend
|
||||||
|
run: npm ci
|
||||||
|
|
||||||
|
- name: Install Playwright browsers
|
||||||
|
working-directory: frontend
|
||||||
|
run: npx playwright install --with-deps chromium
|
||||||
|
|
||||||
|
- name: Start backend server
|
||||||
|
working-directory: backend
|
||||||
|
env:
|
||||||
|
DATABASE_URL: postgresql://postgres:postgres@localhost:5432/syndarix_test
|
||||||
|
REDIS_URL: redis://localhost:6379/0
|
||||||
|
SECRET_KEY: test-secret-key-for-e2e-tests-only
|
||||||
|
ENVIRONMENT: test
|
||||||
|
IS_TEST: "True"
|
||||||
|
run: |
|
||||||
|
# Run migrations
|
||||||
|
uv run python -c "from app.database import create_tables; import asyncio; asyncio.run(create_tables())" || true
|
||||||
|
# Start backend in background
|
||||||
|
uv run uvicorn app.main:app --host 0.0.0.0 --port 8000 &
|
||||||
|
# Wait for backend to be ready
|
||||||
|
sleep 10
|
||||||
|
|
||||||
|
- name: Run Playwright E2E tests
|
||||||
|
working-directory: frontend
|
||||||
|
env:
|
||||||
|
NEXT_PUBLIC_API_URL: http://localhost:8000
|
||||||
|
run: |
|
||||||
|
npm run build
|
||||||
|
npm run test:e2e -- --project=chromium
|
||||||
|
|
||||||
|
- name: Upload Playwright report
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
if: always()
|
||||||
|
with:
|
||||||
|
name: playwright-report
|
||||||
|
path: frontend/playwright-report/
|
||||||
|
retention-days: 7
|
||||||
316
CLAUDE.md
316
CLAUDE.md
@@ -1,243 +1,173 @@
|
|||||||
# CLAUDE.md
|
# CLAUDE.md
|
||||||
|
|
||||||
Claude Code context for FastAPI + Next.js Full-Stack Template.
|
Claude Code context for **Syndarix** - AI-Powered Software Consulting Agency.
|
||||||
|
|
||||||
**See [AGENTS.md](./AGENTS.md) for project context, architecture, and development commands.**
|
**Built on PragmaStack.** See [AGENTS.md](./AGENTS.md) for base template context.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Syndarix Project Context
|
||||||
|
|
||||||
|
### Vision
|
||||||
|
|
||||||
|
Syndarix is an autonomous platform that orchestrates specialized AI agents to deliver complete software solutions with minimal human intervention. It acts as a virtual consulting agency with AI agents playing roles like Product Owner, Architect, Engineers, QA, etc.
|
||||||
|
|
||||||
|
### Repository
|
||||||
|
|
||||||
|
- **URL:** https://gitea.pragmazest.com/cardosofelipe/syndarix
|
||||||
|
- **Issue Tracker:** Gitea Issues (primary)
|
||||||
|
- **CI/CD:** Gitea Actions
|
||||||
|
|
||||||
|
### Core Concepts
|
||||||
|
|
||||||
|
**Agent Types & Instances:**
|
||||||
|
- Agent Type = Template (base model, failover, expertise, personality)
|
||||||
|
- Agent Instance = Spawned from type, assigned to project
|
||||||
|
- Multiple instances of same type can work together
|
||||||
|
|
||||||
|
**Project Workflow:**
|
||||||
|
1. Requirements discovery with Product Owner agent
|
||||||
|
2. Architecture spike (PO + BA + Architect brainstorm)
|
||||||
|
3. Implementation planning and backlog creation
|
||||||
|
4. Autonomous sprint execution with checkpoints
|
||||||
|
5. Demo and client feedback
|
||||||
|
|
||||||
|
**Autonomy Levels:**
|
||||||
|
- `FULL_CONTROL`: Approve every action
|
||||||
|
- `MILESTONE`: Approve sprint boundaries
|
||||||
|
- `AUTONOMOUS`: Only major decisions
|
||||||
|
|
||||||
|
**MCP-First Architecture:**
|
||||||
|
All integrations via Model Context Protocol servers with explicit scoping:
|
||||||
|
```python
|
||||||
|
# All tools take project_id for scoping
|
||||||
|
search_knowledge(project_id="proj-123", query="auth flow")
|
||||||
|
create_issue(project_id="proj-123", title="Add login")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Directory Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
docs/
|
||||||
|
├── development/ # Workflow and coding standards
|
||||||
|
├── requirements/ # Requirements documents
|
||||||
|
├── architecture/ # Architecture documentation
|
||||||
|
├── adrs/ # Architecture Decision Records
|
||||||
|
└── spikes/ # Spike research documents
|
||||||
|
```
|
||||||
|
|
||||||
|
### Current Phase
|
||||||
|
|
||||||
|
**Backlog Population** - Creating detailed issues for Phase 0-1 implementation.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Development Standards
|
||||||
|
|
||||||
|
**CRITICAL: These rules are mandatory. See linked docs for full details.**
|
||||||
|
|
||||||
|
### Quick Reference
|
||||||
|
|
||||||
|
| Topic | Documentation |
|
||||||
|
|-------|---------------|
|
||||||
|
| **Workflow & Branching** | [docs/development/WORKFLOW.md](./docs/development/WORKFLOW.md) |
|
||||||
|
| **Coding Standards** | [docs/development/CODING_STANDARDS.md](./docs/development/CODING_STANDARDS.md) |
|
||||||
|
| **Design System** | [frontend/docs/design-system/](./frontend/docs/design-system/) |
|
||||||
|
| **Backend E2E Testing** | [backend/docs/E2E_TESTING.md](./backend/docs/E2E_TESTING.md) |
|
||||||
|
| **Demo Mode** | [frontend/docs/DEMO_MODE.md](./frontend/docs/DEMO_MODE.md) |
|
||||||
|
|
||||||
|
### Essential Rules Summary
|
||||||
|
|
||||||
|
1. **Issue-Driven Development**: Every piece of work MUST have an issue first
|
||||||
|
2. **Branch per Feature**: `feature/<issue-number>-<description>`, single branch for design+implementation
|
||||||
|
3. **Testing Required**: All code must be tested, aim for >90% coverage
|
||||||
|
4. **Code Review**: Must pass multi-agent review before merge
|
||||||
|
5. **No Direct Commits**: Never commit directly to `main` or `dev`
|
||||||
|
|
||||||
|
### Common Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Backend
|
||||||
|
IS_TEST=True uv run pytest # Run tests
|
||||||
|
uv run ruff check src/ # Lint
|
||||||
|
uv run mypy src/ # Type check
|
||||||
|
python migrate.py auto "message" # Database migration
|
||||||
|
|
||||||
|
# Frontend
|
||||||
|
npm test # Unit tests
|
||||||
|
npm run lint # Lint
|
||||||
|
npm run type-check # Type check
|
||||||
|
npm run generate:api # Regenerate API client
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Claude Code-Specific Guidance
|
## Claude Code-Specific Guidance
|
||||||
|
|
||||||
### Critical User Preferences
|
### Critical User Preferences
|
||||||
|
|
||||||
#### File Operations - NEVER Use Heredoc/Cat Append
|
**File Operations:**
|
||||||
**ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF` commands.**
|
- ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF`
|
||||||
|
- Never use heredoc - it triggers manual approval dialogs
|
||||||
|
|
||||||
This triggers manual approval dialogs and disrupts workflow.
|
**Work Style:**
|
||||||
|
|
||||||
```bash
|
|
||||||
# WRONG ❌
|
|
||||||
cat >> file.txt << EOF
|
|
||||||
content
|
|
||||||
EOF
|
|
||||||
|
|
||||||
# CORRECT ✅ - Use Read, then Write tools
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Work Style
|
|
||||||
- User prefers autonomous operation without frequent interruptions
|
- User prefers autonomous operation without frequent interruptions
|
||||||
- Ask for batch permissions upfront for long work sessions
|
- Ask for batch permissions upfront for long work sessions
|
||||||
- Work independently, document decisions clearly
|
- Work independently, document decisions clearly
|
||||||
- Only use emojis if the user explicitly requests it
|
- Only use emojis if the user explicitly requests it
|
||||||
|
|
||||||
### When Working with This Stack
|
### Critical Pattern: Auth Store DI
|
||||||
|
|
||||||
**Dependency Management:**
|
|
||||||
- Backend uses **uv** (modern Python package manager), not pip
|
|
||||||
- Always use `uv run` prefix: `IS_TEST=True uv run pytest`
|
|
||||||
- Or use Makefile commands: `make test`, `make install-dev`
|
|
||||||
- Add dependencies: `uv add <package>` or `uv add --dev <package>`
|
|
||||||
|
|
||||||
**Database Migrations:**
|
|
||||||
- Use the `migrate.py` helper script, not Alembic directly
|
|
||||||
- Generate + apply: `python migrate.py auto "message"`
|
|
||||||
- Never commit migrations without testing them first
|
|
||||||
- Check current state: `python migrate.py current`
|
|
||||||
|
|
||||||
**Frontend API Client Generation:**
|
|
||||||
- Run `npm run generate:api` after backend schema changes
|
|
||||||
- Client is auto-generated from OpenAPI spec
|
|
||||||
- Located in `frontend/src/lib/api/generated/`
|
|
||||||
- NEVER manually edit generated files
|
|
||||||
|
|
||||||
**Testing Commands:**
|
|
||||||
- Backend unit/integration: `IS_TEST=True uv run pytest` (always prefix with `IS_TEST=True`)
|
|
||||||
- Backend E2E (requires Docker): `make test-e2e`
|
|
||||||
- Frontend unit: `npm test`
|
|
||||||
- Frontend E2E: `npm run test:e2e`
|
|
||||||
- Use `make test` or `make test-cov` in backend for convenience
|
|
||||||
|
|
||||||
**Backend E2E Testing (requires Docker):**
|
|
||||||
- Install deps: `make install-e2e`
|
|
||||||
- Run all E2E tests: `make test-e2e`
|
|
||||||
- Run schema tests only: `make test-e2e-schema`
|
|
||||||
- Run all tests: `make test-all` (unit + E2E)
|
|
||||||
- Uses Testcontainers (real PostgreSQL) + Schemathesis (OpenAPI contract testing)
|
|
||||||
- Markers: `@pytest.mark.e2e`, `@pytest.mark.postgres`, `@pytest.mark.schemathesis`
|
|
||||||
- See: `backend/docs/E2E_TESTING.md` for complete guide
|
|
||||||
|
|
||||||
### 🔴 CRITICAL: Auth Store Dependency Injection Pattern
|
|
||||||
|
|
||||||
**ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly!**
|
**ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly!**
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
// ❌ WRONG - Bypasses dependency injection
|
// ❌ WRONG
|
||||||
import { useAuthStore } from '@/lib/stores/authStore';
|
import { useAuthStore } from '@/lib/stores/authStore';
|
||||||
const { user, isAuthenticated } = useAuthStore();
|
|
||||||
|
|
||||||
// ✅ CORRECT - Uses dependency injection
|
// ✅ CORRECT
|
||||||
import { useAuth } from '@/lib/auth/AuthContext';
|
import { useAuth } from '@/lib/auth/AuthContext';
|
||||||
const { user, isAuthenticated } = useAuth();
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Why This Matters:**
|
See [CODING_STANDARDS.md](./docs/development/CODING_STANDARDS.md#auth-store-dependency-injection) for details.
|
||||||
- E2E tests inject mock stores via `window.__TEST_AUTH_STORE__`
|
|
||||||
- Unit tests inject via `<AuthProvider store={mockStore}>`
|
|
||||||
- Direct `useAuthStore` imports bypass this injection → **tests fail**
|
|
||||||
- ESLint will catch violations (added Nov 2025)
|
|
||||||
|
|
||||||
**Exceptions:**
|
|
||||||
1. `AuthContext.tsx` - DI boundary, legitimately needs real store
|
|
||||||
2. `client.ts` - Non-React context, uses dynamic import + `__TEST_AUTH_STORE__` check
|
|
||||||
|
|
||||||
### E2E Test Best Practices
|
|
||||||
|
|
||||||
When writing or fixing Playwright tests:
|
|
||||||
|
|
||||||
**Navigation Pattern:**
|
|
||||||
```typescript
|
|
||||||
// ✅ CORRECT - Use Promise.all for Next.js Link clicks
|
|
||||||
await Promise.all([
|
|
||||||
page.waitForURL('/target', { timeout: 10000 }),
|
|
||||||
link.click()
|
|
||||||
]);
|
|
||||||
```
|
|
||||||
|
|
||||||
**Selectors:**
|
|
||||||
- Use ID-based selectors for validation errors: `#email-error`
|
|
||||||
- Error IDs use dashes not underscores: `#new-password-error`
|
|
||||||
- Target `.border-destructive[role="alert"]` to avoid Next.js route announcer conflicts
|
|
||||||
- Avoid generic `[role="alert"]` which matches multiple elements
|
|
||||||
|
|
||||||
**URL Assertions:**
|
|
||||||
```typescript
|
|
||||||
// ✅ Use regex to handle query params
|
|
||||||
await expect(page).toHaveURL(/\/auth\/login/);
|
|
||||||
|
|
||||||
// ❌ Don't use exact strings (fails with query params)
|
|
||||||
await expect(page).toHaveURL('/auth/login');
|
|
||||||
```
|
|
||||||
|
|
||||||
**Configuration:**
|
|
||||||
- Uses 12 workers in non-CI mode (`playwright.config.ts`)
|
|
||||||
- Reduces to 2 workers in CI for stability
|
|
||||||
- Tests are designed to be non-flaky with proper waits
|
|
||||||
|
|
||||||
### Important Implementation Details
|
|
||||||
|
|
||||||
**Authentication Testing:**
|
|
||||||
- Backend fixtures in `tests/conftest.py`:
|
|
||||||
- `async_test_db`: Fresh SQLite per test
|
|
||||||
- `async_test_user` / `async_test_superuser`: Pre-created users
|
|
||||||
- `user_token` / `superuser_token`: Access tokens for API calls
|
|
||||||
- Always use `@pytest.mark.asyncio` for async tests
|
|
||||||
- Use `@pytest_asyncio.fixture` for async fixtures
|
|
||||||
|
|
||||||
**Database Testing:**
|
|
||||||
```python
|
|
||||||
# Mock database exceptions correctly
|
|
||||||
from unittest.mock import patch, AsyncMock
|
|
||||||
|
|
||||||
async def mock_commit():
|
|
||||||
raise OperationalError("Connection lost", {}, Exception())
|
|
||||||
|
|
||||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
|
||||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
|
||||||
with pytest.raises(OperationalError):
|
|
||||||
await crud_method(session, obj_in=data)
|
|
||||||
mock_rollback.assert_called_once()
|
|
||||||
```
|
|
||||||
|
|
||||||
**Frontend Component Development:**
|
|
||||||
- Follow design system docs in `frontend/docs/design-system/`
|
|
||||||
- Read `08-ai-guidelines.md` for AI code generation rules
|
|
||||||
- Use parent-controlled spacing (see `04-spacing-philosophy.md`)
|
|
||||||
- WCAG AA compliance required (see `07-accessibility.md`)
|
|
||||||
|
|
||||||
**Security Considerations:**
|
|
||||||
- Backend has comprehensive security tests (JWT attacks, session hijacking)
|
|
||||||
- Never skip security headers in production
|
|
||||||
- Rate limiting is configured in route decorators: `@limiter.limit("10/minute")`
|
|
||||||
- Session revocation is database-backed, not just JWT expiry
|
|
||||||
|
|
||||||
### Common Workflows Guidance
|
|
||||||
|
|
||||||
**When Adding a New Feature:**
|
|
||||||
1. Start with backend schema and CRUD
|
|
||||||
2. Implement API route with proper authorization
|
|
||||||
3. Write backend tests (aim for >90% coverage)
|
|
||||||
4. Generate frontend API client: `npm run generate:api`
|
|
||||||
5. Implement frontend components
|
|
||||||
6. Write frontend unit tests
|
|
||||||
7. Add E2E tests for critical flows
|
|
||||||
8. Update relevant documentation
|
|
||||||
|
|
||||||
**When Fixing Tests:**
|
|
||||||
- Backend: Check test database isolation and async fixture usage
|
|
||||||
- Frontend unit: Verify mocking of `useAuth()` not `useAuthStore`
|
|
||||||
- E2E: Use `Promise.all()` pattern and regex URL assertions
|
|
||||||
|
|
||||||
**When Debugging:**
|
|
||||||
- Backend: Check `IS_TEST=True` environment variable is set
|
|
||||||
- Frontend: Run `npm run type-check` first
|
|
||||||
- E2E: Use `npm run test:e2e:debug` for step-by-step debugging
|
|
||||||
- Check logs: Backend has detailed error logging
|
|
||||||
|
|
||||||
**Demo Mode (Frontend-Only Showcase):**
|
|
||||||
- Enable: `echo "NEXT_PUBLIC_DEMO_MODE=true" > frontend/.env.local`
|
|
||||||
- Uses MSW (Mock Service Worker) to intercept API calls in browser
|
|
||||||
- Zero backend required - perfect for Vercel deployments
|
|
||||||
- **Fully Automated**: MSW handlers auto-generated from OpenAPI spec
|
|
||||||
- Run `npm run generate:api` → updates both API client AND MSW handlers
|
|
||||||
- No manual synchronization needed!
|
|
||||||
- Demo credentials (any password ≥8 chars works):
|
|
||||||
- User: `demo@example.com` / `DemoPass123`
|
|
||||||
- Admin: `admin@example.com` / `AdminPass123`
|
|
||||||
- **Safe**: MSW never runs during tests (Jest or Playwright)
|
|
||||||
- **Coverage**: Mock files excluded from linting and coverage
|
|
||||||
- **Documentation**: `frontend/docs/DEMO_MODE.md` for complete guide
|
|
||||||
|
|
||||||
### Tool Usage Preferences
|
### Tool Usage Preferences
|
||||||
|
|
||||||
**Prefer specialized tools over bash:**
|
**Prefer specialized tools over bash:**
|
||||||
- Use Read/Write/Edit tools for file operations
|
- Use Read/Write/Edit tools for file operations
|
||||||
- Never use `cat`, `echo >`, or heredoc for file manipulation
|
|
||||||
- Use Task tool with `subagent_type=Explore` for codebase exploration
|
- Use Task tool with `subagent_type=Explore` for codebase exploration
|
||||||
- Use Grep tool for code search, not bash `grep`
|
- Use Grep tool for code search, not bash `grep`
|
||||||
|
|
||||||
**When to use parallel tool calls:**
|
**Parallel tool calls for:**
|
||||||
- Independent git commands: `git status`, `git diff`, `git log`
|
- Independent git commands
|
||||||
- Reading multiple unrelated files
|
- Reading multiple unrelated files
|
||||||
- Running multiple test suites simultaneously
|
- Running multiple test suites
|
||||||
- Independent validation steps
|
- Independent validation steps
|
||||||
|
|
||||||
## Custom Skills
|
---
|
||||||
|
|
||||||
No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill.
|
## Key Extensions (from PragmaStack base)
|
||||||
|
|
||||||
**Potential skill ideas for this project:**
|
- Celery + Redis for agent job queue
|
||||||
- API endpoint generator workflow (schema → CRUD → route → tests → frontend client)
|
- WebSocket/SSE for real-time updates
|
||||||
- Component generator with design system compliance
|
- pgvector for RAG knowledge base
|
||||||
- Database migration troubleshooting helper
|
- MCP server integration layer
|
||||||
- Test coverage analyzer and improvement suggester
|
|
||||||
- E2E test generator for new features
|
---
|
||||||
|
|
||||||
## Additional Resources
|
## Additional Resources
|
||||||
|
|
||||||
**Comprehensive Documentation:**
|
**Documentation:**
|
||||||
- [AGENTS.md](./AGENTS.md) - Framework-agnostic AI assistant context
|
- [AGENTS.md](./AGENTS.md) - Framework-agnostic AI assistant context
|
||||||
- [README.md](./README.md) - User-facing project overview
|
- [README.md](./README.md) - User-facing project overview
|
||||||
- `backend/docs/` - Backend architecture, coding standards, common pitfalls
|
- [docs/development/](./docs/development/) - Development workflow and standards
|
||||||
- `frontend/docs/design-system/` - Complete design system guide
|
- [backend/docs/](./backend/docs/) - Backend architecture and guides
|
||||||
|
- [frontend/docs/design-system/](./frontend/docs/design-system/) - Complete design system
|
||||||
|
|
||||||
**API Documentation (when running):**
|
**API Documentation (when running):**
|
||||||
- Swagger UI: http://localhost:8000/docs
|
- Swagger UI: http://localhost:8000/docs
|
||||||
- ReDoc: http://localhost:8000/redoc
|
- ReDoc: http://localhost:8000/redoc
|
||||||
- OpenAPI JSON: http://localhost:8000/api/v1/openapi.json
|
- OpenAPI JSON: http://localhost:8000/api/v1/openapi.json
|
||||||
|
|
||||||
**Testing Documentation:**
|
|
||||||
- Backend tests: `backend/tests/` (97% coverage)
|
|
||||||
- Frontend E2E: `frontend/e2e/README.md`
|
|
||||||
- Design system: `frontend/docs/design-system/08-ai-guidelines.md`
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**For project architecture, development commands, and general context, see [AGENTS.md](./AGENTS.md).**
|
**For project architecture, development commands, and general context, see [AGENTS.md](./AGENTS.md).**
|
||||||
|
|||||||
724
README.md
724
README.md
@@ -1,659 +1,175 @@
|
|||||||
# <img src="frontend/public/logo.svg" alt="PragmaStack" width="32" height="32" style="vertical-align: middle" /> PragmaStack
|
# Syndarix
|
||||||
|
|
||||||
> **The Pragmatic Full-Stack Template. Production-ready, security-first, and opinionated.**
|
> **Your AI-Powered Software Consulting Agency**
|
||||||
|
>
|
||||||
|
> An autonomous platform that orchestrates specialized AI agents to deliver complete software solutions with minimal human intervention.
|
||||||
|
|
||||||
[](./backend/tests)
|
[](https://gitea.pragmazest.com/cardosofelipe/fast-next-template)
|
||||||
[](./frontend/tests)
|
|
||||||
[](./frontend/e2e)
|
|
||||||
[](./LICENSE)
|
[](./LICENSE)
|
||||||
[](./CONTRIBUTING.md)
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Why PragmaStack?
|
## Vision
|
||||||
|
|
||||||
Building a modern full-stack application often leads to "analysis paralysis" or "boilerplate fatigue". You spend weeks setting up authentication, testing, and linting before writing a single line of business logic.
|
Syndarix transforms the software development lifecycle by providing a **virtual consulting team** of AI agents that collaboratively plan, design, implement, test, and deliver complete software solutions.
|
||||||
|
|
||||||
**PragmaStack cuts through the noise.**
|
**The Problem:** Even with AI coding assistants, developers spend as much time managing AI as doing the work themselves. Context switching, babysitting, and knowledge fragmentation limit productivity.
|
||||||
|
|
||||||
We provide a **pragmatic**, opinionated foundation that prioritizes:
|
**The Solution:** A structured, autonomous agency where specialized AI agents handle different roles (Product Owner, Architect, Engineers, QA, etc.) with proper workflows, reviews, and quality gates.
|
||||||
- **Speed**: Ship features, not config files.
|
|
||||||
- **Robustness**: Security and testing are not optional.
|
|
||||||
- **Clarity**: Code that is easy to read and maintain.
|
|
||||||
|
|
||||||
Whether you're building a SaaS, an internal tool, or a side project, PragmaStack gives you a solid starting point without the bloat.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## ✨ Features
|
## Key Features
|
||||||
|
|
||||||
### 🔐 **Authentication & Security**
|
### Multi-Agent Orchestration
|
||||||
- JWT-based authentication with access + refresh tokens
|
- Configurable agent **types** with base model, failover, expertise, and personality
|
||||||
- **OAuth/Social Login** (Google, GitHub) with PKCE support
|
- Spawn multiple **instances** from the same type (e.g., Dave, Ellis, Kate as Software Developers)
|
||||||
- **OAuth 2.0 Authorization Server** (MCP-ready) for third-party integrations
|
- Agent-to-agent communication and collaboration
|
||||||
- Session management with device tracking and revocation
|
- Per-instance customization with domain-specific knowledge
|
||||||
- Password reset flow (email integration ready)
|
|
||||||
- Secure password hashing (bcrypt)
|
|
||||||
- CSRF protection, rate limiting, and security headers
|
|
||||||
- Comprehensive security tests (JWT algorithm attacks, session hijacking, privilege escalation)
|
|
||||||
|
|
||||||
### 🔌 **OAuth Provider Mode (MCP Integration)**
|
### Complete SDLC Support
|
||||||
Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-party clients:
|
- **Requirements Discovery** → **Architecture Spike** → **Implementation Planning**
|
||||||
- **RFC 7636**: Authorization Code Flow with PKCE (S256 only)
|
- **Sprint Management** with automated ceremonies
|
||||||
- **RFC 8414**: Server metadata discovery at `/.well-known/oauth-authorization-server`
|
- **Issue Tracking** with Epic/Story/Task hierarchy
|
||||||
- **RFC 7662**: Token introspection endpoint
|
- **Git Integration** with proper branch/PR workflows
|
||||||
- **RFC 7009**: Token revocation endpoint
|
- **CI/CD Pipelines** with automated testing
|
||||||
- **JWT access tokens**: Self-contained, configurable lifetime
|
|
||||||
- **Opaque refresh tokens**: Secure rotation, database-backed revocation
|
|
||||||
- **Consent management**: Users can review and revoke app permissions
|
|
||||||
- **Client management**: Admin endpoints for registering OAuth clients
|
|
||||||
- **Scopes**: `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
|
|
||||||
|
|
||||||
### 👥 **Multi-Tenancy & Organizations**
|
### Configurable Autonomy
|
||||||
- Full organization system with role-based access control (Owner, Admin, Member)
|
- From `FULL_CONTROL` (approve everything) to `AUTONOMOUS` (only major milestones)
|
||||||
- Invite/remove members, manage permissions
|
- Client can intervene at any point
|
||||||
- Organization-scoped data access
|
- Transparent progress visibility
|
||||||
- User can belong to multiple organizations
|
|
||||||
|
|
||||||
### 🛠️ **Admin Panel**
|
### MCP-First Architecture
|
||||||
- Complete user management (CRUD, activate/deactivate, bulk operations)
|
- All integrations via **Model Context Protocol (MCP)** servers
|
||||||
- Organization management (create, edit, delete, member management)
|
- Unified Knowledge Base with project/agent scoping
|
||||||
- Session monitoring across all users
|
- Git providers (Gitea, GitHub, GitLab) via MCP
|
||||||
- Real-time statistics dashboard
|
- Extensible through custom MCP tools
|
||||||
- Admin-only routes with proper authorization
|
|
||||||
|
|
||||||
### 🎨 **Modern Frontend**
|
### Project Complexity Wizard
|
||||||
- Next.js 16 with App Router and React 19
|
- **Script** → Minimal process, no repo needed
|
||||||
- **PragmaStack Design System** built on shadcn/ui + TailwindCSS
|
- **Simple** → Single sprint, basic backlog
|
||||||
- Pre-configured theme with dark mode support (coming soon)
|
- **Medium/Complex** → Full AGILE workflow with multiple sprints
|
||||||
- Responsive, accessible components (WCAG AA compliant)
|
|
||||||
- Rich marketing landing page with animated components
|
|
||||||
- Live component showcase and documentation at `/dev`
|
|
||||||
|
|
||||||
### 🌍 **Internationalization (i18n)**
|
|
||||||
- Built-in multi-language support with next-intl v4
|
|
||||||
- Locale-based routing (`/en/*`, `/it/*`)
|
|
||||||
- Seamless language switching with LocaleSwitcher component
|
|
||||||
- SEO-friendly URLs and metadata per locale
|
|
||||||
- Translation files for English and Italian (easily extensible)
|
|
||||||
- Type-safe translations throughout the app
|
|
||||||
|
|
||||||
### 🎯 **Content & UX Features**
|
|
||||||
- **Toast notifications** with Sonner for elegant user feedback
|
|
||||||
- **Smooth animations** powered by Framer Motion
|
|
||||||
- **Markdown rendering** with syntax highlighting (GitHub Flavored Markdown)
|
|
||||||
- **Charts and visualizations** ready with Recharts
|
|
||||||
- **SEO optimization** with dynamic sitemap and robots.txt generation
|
|
||||||
- **Session tracking UI** with device information and revocation controls
|
|
||||||
|
|
||||||
### 🧪 **Comprehensive Testing**
|
|
||||||
- **Backend Testing**: ~97% unit test coverage
|
|
||||||
- Unit, integration, and security tests
|
|
||||||
- Async database testing with SQLAlchemy
|
|
||||||
- API endpoint testing with fixtures
|
|
||||||
- Security vulnerability tests (JWT attacks, session hijacking, privilege escalation)
|
|
||||||
- **Frontend Unit Tests**: ~97% coverage with Jest
|
|
||||||
- Component testing
|
|
||||||
- Hook testing
|
|
||||||
- Utility function testing
|
|
||||||
- **End-to-End Tests**: Playwright with zero flaky tests
|
|
||||||
- Complete user flows (auth, navigation, settings)
|
|
||||||
- Parallel execution for speed
|
|
||||||
- Visual regression testing ready
|
|
||||||
|
|
||||||
### 📚 **Developer Experience**
|
|
||||||
- Auto-generated TypeScript API client from OpenAPI spec
|
|
||||||
- Interactive API documentation (Swagger + ReDoc)
|
|
||||||
- Database migrations with Alembic helper script
|
|
||||||
- Hot reload in development for both frontend and backend
|
|
||||||
- Comprehensive code documentation and design system docs
|
|
||||||
- Live component playground at `/dev` with code examples
|
|
||||||
- Docker support for easy deployment
|
|
||||||
- VSCode workspace settings included
|
|
||||||
|
|
||||||
### 📊 **Ready for Production**
|
|
||||||
- Docker + docker-compose setup
|
|
||||||
- Environment-based configuration
|
|
||||||
- Database connection pooling
|
|
||||||
- Error handling and logging
|
|
||||||
- Health check endpoints
|
|
||||||
- Production security headers
|
|
||||||
- Rate limiting on sensitive endpoints
|
|
||||||
- SEO optimization with dynamic sitemaps and robots.txt
|
|
||||||
- Multi-language SEO with locale-specific metadata
|
|
||||||
- Performance monitoring and bundle analysis
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 📸 Screenshots
|
## Technology Stack
|
||||||
|
|
||||||
<details>
|
Built on [PragmaStack](https://gitea.pragmazest.com/cardosofelipe/fast-next-template):
|
||||||
<summary>Click to view screenshots</summary>
|
|
||||||
|
|
||||||
### Landing Page
|
| Component | Technology |
|
||||||

|
|-----------|------------|
|
||||||
|
| Backend | FastAPI 0.115+ (Python 3.11+) |
|
||||||
|
| Frontend | Next.js 16 (React 19) |
|
||||||
|
| Database | PostgreSQL 15+ with pgvector |
|
||||||
|
| ORM | SQLAlchemy 2.0 |
|
||||||
|
| State Management | Zustand + TanStack Query |
|
||||||
|
| UI | shadcn/ui + Tailwind 4 |
|
||||||
|
| Auth | JWT dual-token + OAuth 2.0 |
|
||||||
|
| Testing | pytest + Jest + Playwright |
|
||||||
|
|
||||||
|
### Syndarix Extensions
|
||||||
|
| Component | Technology |
|
||||||
### Authentication
|
|-----------|------------|
|
||||||

|
| Task Queue | Celery + Redis |
|
||||||
|
| Real-time | FastAPI WebSocket / SSE |
|
||||||
|
| Vector DB | pgvector (PostgreSQL extension) |
|
||||||
|
| MCP SDK | Anthropic MCP SDK |
|
||||||
### Admin Dashboard
|
|
||||||

|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Design System
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🎭 Demo Mode
|
## Project Status
|
||||||
|
|
||||||
**Try the frontend without a backend!** Perfect for:
|
**Phase:** Architecture & Planning
|
||||||
- **Free deployment** on Vercel (no backend costs)
|
|
||||||
- **Portfolio showcasing** with live demos
|
See [docs/requirements/](./docs/requirements/) for the comprehensive requirements document.
|
||||||
- **Client presentations** without infrastructure setup
|
|
||||||
|
### Current Milestones
|
||||||
|
- [x] Fork PragmaStack as foundation
|
||||||
|
- [x] Create requirements document
|
||||||
|
- [ ] Execute architecture spikes
|
||||||
|
- [ ] Create ADRs for key decisions
|
||||||
|
- [ ] Begin MVP implementation
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
- [Requirements Document](./docs/requirements/SYNDARIX_REQUIREMENTS.md)
|
||||||
|
- [Architecture Decisions](./docs/adrs/) (coming soon)
|
||||||
|
- [Spike Research](./docs/spikes/) (coming soon)
|
||||||
|
- [Architecture Overview](./docs/architecture/) (coming soon)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
- Docker & Docker Compose
|
||||||
|
- Node.js 20+
|
||||||
|
- Python 3.11+
|
||||||
|
- PostgreSQL 15+ (or use Docker)
|
||||||
|
|
||||||
### Quick Start
|
### Quick Start
|
||||||
|
|
||||||
```bash
|
|
||||||
cd frontend
|
|
||||||
echo "NEXT_PUBLIC_DEMO_MODE=true" > .env.local
|
|
||||||
npm run dev
|
|
||||||
```
|
|
||||||
|
|
||||||
**Demo Credentials:**
|
|
||||||
- Regular user: `demo@example.com` / `DemoPass123`
|
|
||||||
- Admin user: `admin@example.com` / `AdminPass123`
|
|
||||||
|
|
||||||
Demo mode uses [Mock Service Worker (MSW)](https://mswjs.io/) to intercept API calls in the browser. Your code remains unchanged - the same components work with both real and mocked backends.
|
|
||||||
|
|
||||||
**Key Features:**
|
|
||||||
- ✅ Zero backend required
|
|
||||||
- ✅ All features functional (auth, admin, stats)
|
|
||||||
- ✅ Realistic network delays and errors
|
|
||||||
- ✅ Does NOT interfere with tests (97%+ coverage maintained)
|
|
||||||
- ✅ One-line toggle: `NEXT_PUBLIC_DEMO_MODE=true`
|
|
||||||
|
|
||||||
📖 **[Complete Demo Mode Documentation](./frontend/docs/DEMO_MODE.md)**
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🚀 Tech Stack
|
|
||||||
|
|
||||||
### Backend
|
|
||||||
- **[FastAPI](https://fastapi.tiangolo.com/)** - Modern async Python web framework
|
|
||||||
- **[SQLAlchemy 2.0](https://www.sqlalchemy.org/)** - Powerful ORM with async support
|
|
||||||
- **[PostgreSQL](https://www.postgresql.org/)** - Robust relational database
|
|
||||||
- **[Alembic](https://alembic.sqlalchemy.org/)** - Database migrations
|
|
||||||
- **[Pydantic v2](https://docs.pydantic.dev/)** - Data validation with type hints
|
|
||||||
- **[pytest](https://pytest.org/)** - Testing framework with async support
|
|
||||||
|
|
||||||
### Frontend
|
|
||||||
- **[Next.js 16](https://nextjs.org/)** - React framework with App Router
|
|
||||||
- **[React 19](https://react.dev/)** - UI library
|
|
||||||
- **[TypeScript](https://www.typescriptlang.org/)** - Type-safe JavaScript
|
|
||||||
- **[TailwindCSS](https://tailwindcss.com/)** - Utility-first CSS framework
|
|
||||||
- **[shadcn/ui](https://ui.shadcn.com/)** - Beautiful, accessible component library
|
|
||||||
- **[next-intl](https://next-intl.dev/)** - Internationalization (i18n) with type safety
|
|
||||||
- **[TanStack Query](https://tanstack.com/query)** - Powerful data fetching/caching
|
|
||||||
- **[Zustand](https://zustand-demo.pmnd.rs/)** - Lightweight state management
|
|
||||||
- **[Framer Motion](https://www.framer.com/motion/)** - Production-ready animation library
|
|
||||||
- **[Sonner](https://sonner.emilkowal.ski/)** - Beautiful toast notifications
|
|
||||||
- **[Recharts](https://recharts.org/)** - Composable charting library
|
|
||||||
- **[React Markdown](https://github.com/remarkjs/react-markdown)** - Markdown rendering with GFM support
|
|
||||||
- **[Playwright](https://playwright.dev/)** - End-to-end testing
|
|
||||||
|
|
||||||
### DevOps
|
|
||||||
- **[Docker](https://www.docker.com/)** - Containerization
|
|
||||||
- **[docker-compose](https://docs.docker.com/compose/)** - Multi-container orchestration
|
|
||||||
- **GitHub Actions** (coming soon) - CI/CD pipelines
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 📋 Prerequisites
|
|
||||||
|
|
||||||
- **Docker & Docker Compose** (recommended) - [Install Docker](https://docs.docker.com/get-docker/)
|
|
||||||
- **OR manually:**
|
|
||||||
- Python 3.12+
|
|
||||||
- Node.js 18+ (Node 20+ recommended)
|
|
||||||
- PostgreSQL 15+
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🏃 Quick Start (Docker)
|
|
||||||
|
|
||||||
The fastest way to get started is with Docker:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Clone the repository
|
# Clone the repository
|
||||||
git clone https://github.com/cardosofelipe/pragma-stack.git
|
git clone https://gitea.pragmazest.com/cardosofelipe/syndarix.git
|
||||||
cd fast-next-template
|
cd syndarix
|
||||||
|
|
||||||
# Copy environment file
|
# Copy environment template
|
||||||
cp .env.template .env
|
cp .env.template .env
|
||||||
|
|
||||||
# Start all services (backend, frontend, database)
|
# Start development environment
|
||||||
docker-compose up
|
docker-compose -f docker-compose.dev.yml up -d
|
||||||
|
|
||||||
# In another terminal, run database migrations
|
# Run database migrations
|
||||||
docker-compose exec backend alembic upgrade head
|
make migrate
|
||||||
|
|
||||||
# Create first superuser (optional)
|
# Start the development servers
|
||||||
docker-compose exec backend python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
|
make dev
|
||||||
```
|
|
||||||
|
|
||||||
**That's it! 🎉**
|
|
||||||
|
|
||||||
- Frontend: http://localhost:3000
|
|
||||||
- Backend API: http://localhost:8000
|
|
||||||
- API Docs: http://localhost:8000/docs
|
|
||||||
|
|
||||||
Default superuser credentials:
|
|
||||||
- Email: `admin@example.com`
|
|
||||||
- Password: `admin123`
|
|
||||||
|
|
||||||
**⚠️ Change these immediately in production!**
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🛠️ Manual Setup (Development)
|
|
||||||
|
|
||||||
### Backend Setup
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
|
|
||||||
# Create virtual environment
|
|
||||||
python -m venv .venv
|
|
||||||
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
|
||||||
|
|
||||||
# Install dependencies
|
|
||||||
pip install -r requirements.txt
|
|
||||||
|
|
||||||
# Setup environment
|
|
||||||
cp .env.example .env
|
|
||||||
# Edit .env with your database credentials
|
|
||||||
|
|
||||||
# Run migrations
|
|
||||||
alembic upgrade head
|
|
||||||
|
|
||||||
# Initialize database with first superuser
|
|
||||||
python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
|
|
||||||
|
|
||||||
# Start development server
|
|
||||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
|
||||||
```
|
|
||||||
|
|
||||||
### Frontend Setup
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd frontend
|
|
||||||
|
|
||||||
# Install dependencies
|
|
||||||
npm install
|
|
||||||
|
|
||||||
# Setup environment
|
|
||||||
cp .env.local.example .env.local
|
|
||||||
# Edit .env.local with your backend URL
|
|
||||||
|
|
||||||
# Generate API client
|
|
||||||
npm run generate:api
|
|
||||||
|
|
||||||
# Start development server
|
|
||||||
npm run dev
|
|
||||||
```
|
|
||||||
|
|
||||||
Visit http://localhost:3000 to see your app!
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 📂 Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
├── backend/ # FastAPI backend
|
|
||||||
│ ├── app/
|
|
||||||
│ │ ├── api/ # API routes and dependencies
|
|
||||||
│ │ ├── core/ # Core functionality (auth, config, database)
|
|
||||||
│ │ ├── crud/ # Database operations
|
|
||||||
│ │ ├── models/ # SQLAlchemy models
|
|
||||||
│ │ ├── schemas/ # Pydantic schemas
|
|
||||||
│ │ ├── services/ # Business logic
|
|
||||||
│ │ └── utils/ # Utilities
|
|
||||||
│ ├── tests/ # Backend tests (97% coverage)
|
|
||||||
│ ├── alembic/ # Database migrations
|
|
||||||
│ └── docs/ # Backend documentation
|
|
||||||
│
|
|
||||||
├── frontend/ # Next.js frontend
|
|
||||||
│ ├── src/
|
|
||||||
│ │ ├── app/ # Next.js App Router pages
|
|
||||||
│ │ ├── components/ # React components
|
|
||||||
│ │ ├── lib/ # Libraries and utilities
|
|
||||||
│ │ │ ├── api/ # API client (auto-generated)
|
|
||||||
│ │ │ └── stores/ # Zustand stores
|
|
||||||
│ │ └── hooks/ # Custom React hooks
|
|
||||||
│ ├── e2e/ # Playwright E2E tests
|
|
||||||
│ ├── tests/ # Unit tests (Jest)
|
|
||||||
│ └── docs/ # Frontend documentation
|
|
||||||
│ └── design-system/ # Comprehensive design system docs
|
|
||||||
│
|
|
||||||
├── docker-compose.yml # Docker orchestration
|
|
||||||
├── docker-compose.dev.yml # Development with hot reload
|
|
||||||
└── README.md # You are here!
|
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🧪 Testing
|
## Architecture Overview
|
||||||
|
|
||||||
This template takes testing seriously with comprehensive coverage across all layers:
|
|
||||||
|
|
||||||
### Backend Unit & Integration Tests
|
|
||||||
|
|
||||||
**High coverage (~97%)** across all critical paths including security-focused tests.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
|
|
||||||
# Run all tests
|
|
||||||
IS_TEST=True pytest
|
|
||||||
|
|
||||||
# Run with coverage report
|
|
||||||
IS_TEST=True pytest --cov=app --cov-report=term-missing
|
|
||||||
|
|
||||||
# Run specific test file
|
|
||||||
IS_TEST=True pytest tests/api/test_auth.py -v
|
|
||||||
|
|
||||||
# Generate HTML coverage report
|
|
||||||
IS_TEST=True pytest --cov=app --cov-report=html
|
|
||||||
open htmlcov/index.html
|
|
||||||
```
|
```
|
||||||
|
+====================================================================+
|
||||||
**Test types:**
|
| SYNDARIX CORE |
|
||||||
- **Unit tests**: CRUD operations, utilities, business logic
|
+====================================================================+
|
||||||
- **Integration tests**: API endpoints with database
|
| +------------------+ +------------------+ +------------------+ |
|
||||||
- **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation
|
| | Agent Orchestrator| | Project Manager | | Workflow Engine | |
|
||||||
- **Error handling tests**: Database failures, validation errors
|
| +------------------+ +------------------+ +------------------+ |
|
||||||
|
+====================================================================+
|
||||||
### Frontend Unit Tests
|
|
|
||||||
|
v
|
||||||
**High coverage (~97%)** with Jest and React Testing Library.
|
+====================================================================+
|
||||||
|
| MCP ORCHESTRATION LAYER |
|
||||||
```bash
|
| All integrations via unified MCP servers with project scoping |
|
||||||
cd frontend
|
+====================================================================+
|
||||||
|
|
|
||||||
# Run unit tests
|
+------------------------+------------------------+
|
||||||
npm test
|
| | |
|
||||||
|
+----v----+ +----v----+ +----v----+ +----v----+ +----v----+
|
||||||
# Run with coverage
|
| LLM | | Git | |Knowledge| | File | | Code |
|
||||||
npm run test:coverage
|
| Providers| | MCP | |Base MCP | |Sys. MCP | |Analysis |
|
||||||
|
+---------+ +---------+ +---------+ +---------+ +---------+
|
||||||
# Watch mode
|
|
||||||
npm run test:watch
|
|
||||||
```
|
|
||||||
|
|
||||||
**Test types:**
|
|
||||||
- Component rendering and interactions
|
|
||||||
- Custom hooks behavior
|
|
||||||
- State management
|
|
||||||
- Utility functions
|
|
||||||
- API integration mocks
|
|
||||||
|
|
||||||
### End-to-End Tests
|
|
||||||
|
|
||||||
**Zero flaky tests** with Playwright covering complete user journeys.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd frontend
|
|
||||||
|
|
||||||
# Run E2E tests
|
|
||||||
npm run test:e2e
|
|
||||||
|
|
||||||
# Run E2E tests in UI mode (recommended for development)
|
|
||||||
npm run test:e2e:ui
|
|
||||||
|
|
||||||
# Run specific test file
|
|
||||||
npx playwright test auth-login.spec.ts
|
|
||||||
|
|
||||||
# Generate test report
|
|
||||||
npx playwright show-report
|
|
||||||
```
|
|
||||||
|
|
||||||
**Test coverage:**
|
|
||||||
- Complete authentication flows
|
|
||||||
- Navigation and routing
|
|
||||||
- Form submissions and validation
|
|
||||||
- Settings and profile management
|
|
||||||
- Session management
|
|
||||||
- Admin panel workflows (in progress)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🤖 AI-Friendly Documentation
|
|
||||||
|
|
||||||
This project includes comprehensive documentation designed for AI coding assistants:
|
|
||||||
|
|
||||||
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI assistant context for PragmaStack
|
|
||||||
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance
|
|
||||||
|
|
||||||
These files provide AI assistants with the **PragmaStack** architecture, patterns, and best practices.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🗄️ Database Migrations
|
|
||||||
|
|
||||||
The template uses Alembic for database migrations:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
|
|
||||||
# Generate migration from model changes
|
|
||||||
python migrate.py generate "description of changes"
|
|
||||||
|
|
||||||
# Apply migrations
|
|
||||||
python migrate.py apply
|
|
||||||
|
|
||||||
# Or do both in one command
|
|
||||||
python migrate.py auto "description"
|
|
||||||
|
|
||||||
# View migration history
|
|
||||||
python migrate.py list
|
|
||||||
|
|
||||||
# Check current revision
|
|
||||||
python migrate.py current
|
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 📖 Documentation
|
## Contributing
|
||||||
|
|
||||||
### AI Assistant Documentation
|
See [CONTRIBUTING.md](./CONTRIBUTING.md) for guidelines.
|
||||||
|
|
||||||
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI coding assistant context
|
|
||||||
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance and preferences
|
|
||||||
|
|
||||||
### Backend Documentation
|
|
||||||
|
|
||||||
- **[ARCHITECTURE.md](./backend/docs/ARCHITECTURE.md)** - System architecture and design patterns
|
|
||||||
- **[CODING_STANDARDS.md](./backend/docs/CODING_STANDARDS.md)** - Code quality standards
|
|
||||||
- **[COMMON_PITFALLS.md](./backend/docs/COMMON_PITFALLS.md)** - Common mistakes to avoid
|
|
||||||
- **[FEATURE_EXAMPLE.md](./backend/docs/FEATURE_EXAMPLE.md)** - Step-by-step feature guide
|
|
||||||
|
|
||||||
### Frontend Documentation
|
|
||||||
|
|
||||||
- **[PragmaStack Design System](./frontend/docs/design-system/)** - Complete design system guide
|
|
||||||
- Quick start, foundations (colors, typography, spacing)
|
|
||||||
- Component library guide
|
|
||||||
- Layout patterns, spacing philosophy
|
|
||||||
- Forms, accessibility, AI guidelines
|
|
||||||
- **[E2E Testing Guide](./frontend/e2e/README.md)** - E2E testing setup and best practices
|
|
||||||
|
|
||||||
### API Documentation
|
|
||||||
|
|
||||||
When the backend is running:
|
|
||||||
- **Swagger UI**: http://localhost:8000/docs
|
|
||||||
- **ReDoc**: http://localhost:8000/redoc
|
|
||||||
- **OpenAPI JSON**: http://localhost:8000/api/v1/openapi.json
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🚢 Deployment
|
## License
|
||||||
|
|
||||||
### Docker Production Deployment
|
MIT License - see [LICENSE](./LICENSE) for details.
|
||||||
|
|
||||||
```bash
|
|
||||||
# Build and start all services
|
|
||||||
docker-compose up -d
|
|
||||||
|
|
||||||
# Run migrations
|
|
||||||
docker-compose exec backend alembic upgrade head
|
|
||||||
|
|
||||||
# View logs
|
|
||||||
docker-compose logs -f
|
|
||||||
|
|
||||||
# Stop services
|
|
||||||
docker-compose down
|
|
||||||
```
|
|
||||||
|
|
||||||
### Production Checklist
|
|
||||||
|
|
||||||
- [ ] Change default superuser credentials
|
|
||||||
- [ ] Set strong `SECRET_KEY` in backend `.env`
|
|
||||||
- [ ] Configure production database (PostgreSQL)
|
|
||||||
- [ ] Set `ENVIRONMENT=production` in backend
|
|
||||||
- [ ] Configure CORS origins for your domain
|
|
||||||
- [ ] Setup SSL/TLS certificates
|
|
||||||
- [ ] Configure email service for password resets
|
|
||||||
- [ ] Setup monitoring and logging
|
|
||||||
- [ ] Configure backup strategy
|
|
||||||
- [ ] Review and adjust rate limits
|
|
||||||
- [ ] Test security headers
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🛣️ Roadmap & Status
|
## Acknowledgments
|
||||||
|
|
||||||
### ✅ Completed
|
- Built on [PragmaStack](https://gitea.pragmazest.com/cardosofelipe/fast-next-template)
|
||||||
- [x] Authentication system (JWT, refresh tokens, session management, OAuth)
|
- Powered by Claude and the Anthropic API
|
||||||
- [x] User management (CRUD, profile, password change)
|
|
||||||
- [x] Organization system with RBAC (Owner, Admin, Member)
|
|
||||||
- [x] Admin panel (users, organizations, sessions, statistics)
|
|
||||||
- [x] **Internationalization (i18n)** with next-intl (English + Italian)
|
|
||||||
- [x] Backend testing infrastructure (~97% coverage)
|
|
||||||
- [x] Frontend unit testing infrastructure (~97% coverage)
|
|
||||||
- [x] Frontend E2E testing (Playwright, zero flaky tests)
|
|
||||||
- [x] Design system documentation
|
|
||||||
- [x] **Marketing landing page** with animated components
|
|
||||||
- [x] **`/dev` documentation portal** with live component examples
|
|
||||||
- [x] **Toast notifications** system (Sonner)
|
|
||||||
- [x] **Charts and visualizations** (Recharts)
|
|
||||||
- [x] **Animation system** (Framer Motion)
|
|
||||||
- [x] **Markdown rendering** with syntax highlighting
|
|
||||||
- [x] **SEO optimization** (sitemap, robots.txt, locale-aware metadata)
|
|
||||||
- [x] Database migrations with helper script
|
|
||||||
- [x] Docker deployment
|
|
||||||
- [x] API documentation (OpenAPI/Swagger)
|
|
||||||
|
|
||||||
### 🚧 In Progress
|
|
||||||
- [ ] Email integration (templates ready, SMTP pending)
|
|
||||||
|
|
||||||
### 🔮 Planned
|
|
||||||
- [ ] GitHub Actions CI/CD pipelines
|
|
||||||
- [ ] Dynamic test coverage badges from CI
|
|
||||||
- [ ] E2E test coverage reporting
|
|
||||||
- [ ] OAuth token encryption at rest (security hardening)
|
|
||||||
- [ ] Additional languages (Spanish, French, German, etc.)
|
|
||||||
- [ ] SSO/SAML authentication
|
|
||||||
- [ ] Real-time notifications with WebSockets
|
|
||||||
- [ ] Webhook system
|
|
||||||
- [ ] File upload/storage (S3-compatible)
|
|
||||||
- [ ] Audit logging system
|
|
||||||
- [ ] API versioning example
|
|
||||||
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🤝 Contributing
|
|
||||||
|
|
||||||
Contributions are welcome! Whether you're fixing bugs, improving documentation, or proposing new features, we'd love your help.
|
|
||||||
|
|
||||||
### How to Contribute
|
|
||||||
|
|
||||||
1. **Fork the repository**
|
|
||||||
2. **Create a feature branch** (`git checkout -b feature/amazing-feature`)
|
|
||||||
3. **Make your changes**
|
|
||||||
- Follow existing code style
|
|
||||||
- Add tests for new features
|
|
||||||
- Update documentation as needed
|
|
||||||
4. **Run tests** to ensure everything works
|
|
||||||
5. **Commit your changes** (`git commit -m 'Add amazing feature'`)
|
|
||||||
6. **Push to your branch** (`git push origin feature/amazing-feature`)
|
|
||||||
7. **Open a Pull Request**
|
|
||||||
|
|
||||||
### Development Guidelines
|
|
||||||
|
|
||||||
- Write tests for new features (aim for >90% coverage)
|
|
||||||
- Follow the existing architecture patterns
|
|
||||||
- Update documentation when adding features
|
|
||||||
- Keep commits atomic and well-described
|
|
||||||
- Be respectful and constructive in discussions
|
|
||||||
|
|
||||||
### Reporting Issues
|
|
||||||
|
|
||||||
Found a bug? Have a suggestion? [Open an issue](https://github.com/cardosofelipe/pragma-stack/issues)!
|
|
||||||
|
|
||||||
Please include:
|
|
||||||
- Clear description of the issue/suggestion
|
|
||||||
- Steps to reproduce (for bugs)
|
|
||||||
- Expected vs. actual behavior
|
|
||||||
- Environment details (OS, Python/Node version, etc.)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 📄 License
|
|
||||||
|
|
||||||
This project is licensed under the **MIT License** - see the [LICENSE](./LICENSE) file for details.
|
|
||||||
|
|
||||||
**TL;DR**: You can use this template for any purpose, commercial or non-commercial. Attribution is appreciated but not required!
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🙏 Acknowledgments
|
|
||||||
|
|
||||||
This template is built on the shoulders of giants:
|
|
||||||
|
|
||||||
- [FastAPI](https://fastapi.tiangolo.com/) by Sebastián Ramírez
|
|
||||||
- [Next.js](https://nextjs.org/) by Vercel
|
|
||||||
- [shadcn/ui](https://ui.shadcn.com/) by shadcn
|
|
||||||
- [TanStack Query](https://tanstack.com/query) by Tanner Linsley
|
|
||||||
- [Playwright](https://playwright.dev/) by Microsoft
|
|
||||||
- And countless other open-source projects that make modern development possible
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 💬 Questions?
|
|
||||||
|
|
||||||
- **Documentation**: Check the `/docs` folders in backend and frontend
|
|
||||||
- **Issues**: [GitHub Issues](https://github.com/cardosofelipe/pragma-stack/issues)
|
|
||||||
- **Discussions**: [GitHub Discussions](https://github.com/cardosofelipe/pragma-stack/discussions)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ⭐ Star This Repo
|
|
||||||
|
|
||||||
If this template saves you time, consider giving it a star! It helps others discover the project and motivates continued development.
|
|
||||||
|
|
||||||
**Happy coding! 🚀**
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
Made with ❤️ by a developer who got tired of rebuilding the same boilerplate
|
|
||||||
</div>
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# PragmaStack Backend API
|
# Syndarix Backend API
|
||||||
|
|
||||||
> The pragmatic, production-ready FastAPI backend for PragmaStack.
|
> The pragmatic, production-ready FastAPI backend for Syndarix.
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ def include_object(object, name, type_, reflected, compare_to):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
# Interpret the config file for Python logging.
|
# Interpret the config file for Python logging.
|
||||||
# This line sets up loggers basically.
|
# This line sets up loggers basically.
|
||||||
if config.config_file_name is not None:
|
if config.config_file_name is not None:
|
||||||
|
|||||||
@@ -5,258 +5,442 @@ Revises:
|
|||||||
Create Date: 2025-11-27 09:08:09.464506
|
Create Date: 2025-11-27 09:08:09.464506
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = '0001'
|
revision: str = "0001"
|
||||||
down_revision: Union[str, None] = None
|
down_revision: str | None = None
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.create_table('oauth_states',
|
op.create_table(
|
||||||
sa.Column('state', sa.String(length=255), nullable=False),
|
"oauth_states",
|
||||||
sa.Column('code_verifier', sa.String(length=128), nullable=True),
|
sa.Column("state", sa.String(length=255), nullable=False),
|
||||||
sa.Column('nonce', sa.String(length=255), nullable=True),
|
sa.Column("code_verifier", sa.String(length=128), nullable=True),
|
||||||
sa.Column('provider', sa.String(length=50), nullable=False),
|
sa.Column("nonce", sa.String(length=255), nullable=True),
|
||||||
sa.Column('redirect_uri', sa.String(length=500), nullable=True),
|
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||||
sa.Column('user_id', sa.UUID(), nullable=True),
|
sa.Column("redirect_uri", sa.String(length=500), nullable=True),
|
||||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.PrimaryKeyConstraint('id')
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_oauth_states_state'), 'oauth_states', ['state'], unique=True)
|
op.create_index(
|
||||||
op.create_table('organizations',
|
op.f("ix_oauth_states_state"), "oauth_states", ["state"], unique=True
|
||||||
sa.Column('name', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('slug', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('description', sa.Text(), nullable=True),
|
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('settings', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_organizations_is_active'), 'organizations', ['is_active'], unique=False)
|
op.create_table(
|
||||||
op.create_index(op.f('ix_organizations_name'), 'organizations', ['name'], unique=False)
|
"organizations",
|
||||||
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'], unique=False)
|
sa.Column("name", sa.String(length=255), nullable=False),
|
||||||
op.create_index(op.f('ix_organizations_slug'), 'organizations', ['slug'], unique=True)
|
sa.Column("slug", sa.String(length=255), nullable=False),
|
||||||
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'], unique=False)
|
sa.Column("description", sa.Text(), nullable=True),
|
||||||
op.create_table('users',
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
sa.Column('email', sa.String(length=255), nullable=False),
|
sa.Column("settings", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||||
sa.Column('password_hash', sa.String(length=255), nullable=True),
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
sa.Column('first_name', sa.String(length=100), nullable=False),
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.Column('last_name', sa.String(length=100), nullable=True),
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.Column('phone_number', sa.String(length=20), nullable=True),
|
sa.PrimaryKeyConstraint("id"),
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('preferences', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
|
||||||
sa.Column('locale', sa.String(length=10), nullable=True),
|
|
||||||
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_users_deleted_at'), 'users', ['deleted_at'], unique=False)
|
op.create_index(
|
||||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
op.f("ix_organizations_is_active"), "organizations", ["is_active"], unique=False
|
||||||
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
|
|
||||||
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
|
|
||||||
op.create_index(op.f('ix_users_locale'), 'users', ['locale'], unique=False)
|
|
||||||
op.create_table('oauth_accounts',
|
|
||||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('provider', sa.String(length=50), nullable=False),
|
|
||||||
sa.Column('provider_user_id', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('provider_email', sa.String(length=255), nullable=True),
|
|
||||||
sa.Column('access_token_encrypted', sa.String(length=2048), nullable=True),
|
|
||||||
sa.Column('refresh_token_encrypted', sa.String(length=2048), nullable=True),
|
|
||||||
sa.Column('token_expires_at', sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id'),
|
|
||||||
sa.UniqueConstraint('provider', 'provider_user_id', name='uq_oauth_provider_user')
|
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_oauth_accounts_provider'), 'oauth_accounts', ['provider'], unique=False)
|
op.create_index(
|
||||||
op.create_index(op.f('ix_oauth_accounts_provider_email'), 'oauth_accounts', ['provider_email'], unique=False)
|
op.f("ix_organizations_name"), "organizations", ["name"], unique=False
|
||||||
op.create_index(op.f('ix_oauth_accounts_user_id'), 'oauth_accounts', ['user_id'], unique=False)
|
|
||||||
op.create_index('ix_oauth_accounts_user_provider', 'oauth_accounts', ['user_id', 'provider'], unique=False)
|
|
||||||
op.create_table('oauth_clients',
|
|
||||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
|
||||||
sa.Column('client_secret_hash', sa.String(length=255), nullable=True),
|
|
||||||
sa.Column('client_name', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('client_description', sa.String(length=1000), nullable=True),
|
|
||||||
sa.Column('client_type', sa.String(length=20), nullable=False),
|
|
||||||
sa.Column('redirect_uris', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
|
||||||
sa.Column('allowed_scopes', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
|
||||||
sa.Column('access_token_lifetime', sa.String(length=10), nullable=False),
|
|
||||||
sa.Column('refresh_token_lifetime', sa.String(length=10), nullable=False),
|
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('owner_user_id', sa.UUID(), nullable=True),
|
|
||||||
sa.Column('mcp_server_url', sa.String(length=2048), nullable=True),
|
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['owner_user_id'], ['users.id'], ondelete='SET NULL'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_oauth_clients_client_id'), 'oauth_clients', ['client_id'], unique=True)
|
op.create_index(
|
||||||
op.create_index(op.f('ix_oauth_clients_is_active'), 'oauth_clients', ['is_active'], unique=False)
|
"ix_organizations_name_active",
|
||||||
op.create_table('user_organizations',
|
"organizations",
|
||||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
["name", "is_active"],
|
||||||
sa.Column('organization_id', sa.UUID(), nullable=False),
|
unique=False,
|
||||||
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False),
|
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('custom_permissions', sa.String(length=500), nullable=True),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('user_id', 'organization_id')
|
|
||||||
)
|
)
|
||||||
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'], unique=False)
|
op.create_index(
|
||||||
op.create_index('ix_user_org_role', 'user_organizations', ['role'], unique=False)
|
op.f("ix_organizations_slug"), "organizations", ["slug"], unique=True
|
||||||
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'], unique=False)
|
|
||||||
op.create_index(op.f('ix_user_organizations_is_active'), 'user_organizations', ['is_active'], unique=False)
|
|
||||||
op.create_table('user_sessions',
|
|
||||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('device_name', sa.String(length=255), nullable=True),
|
|
||||||
sa.Column('device_id', sa.String(length=255), nullable=True),
|
|
||||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
|
||||||
sa.Column('user_agent', sa.String(length=500), nullable=True),
|
|
||||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('location_city', sa.String(length=100), nullable=True),
|
|
||||||
sa.Column('location_country', sa.String(length=100), nullable=True),
|
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_user_sessions_is_active'), 'user_sessions', ['is_active'], unique=False)
|
op.create_index(
|
||||||
op.create_index('ix_user_sessions_jti_active', 'user_sessions', ['refresh_token_jti', 'is_active'], unique=False)
|
"ix_organizations_slug_active",
|
||||||
op.create_index(op.f('ix_user_sessions_refresh_token_jti'), 'user_sessions', ['refresh_token_jti'], unique=True)
|
"organizations",
|
||||||
op.create_index('ix_user_sessions_user_active', 'user_sessions', ['user_id', 'is_active'], unique=False)
|
["slug", "is_active"],
|
||||||
op.create_index(op.f('ix_user_sessions_user_id'), 'user_sessions', ['user_id'], unique=False)
|
unique=False,
|
||||||
op.create_table('oauth_authorization_codes',
|
|
||||||
sa.Column('code', sa.String(length=128), nullable=False),
|
|
||||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
|
||||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('redirect_uri', sa.String(length=2048), nullable=False),
|
|
||||||
sa.Column('scope', sa.String(length=1000), nullable=False),
|
|
||||||
sa.Column('code_challenge', sa.String(length=128), nullable=True),
|
|
||||||
sa.Column('code_challenge_method', sa.String(length=10), nullable=True),
|
|
||||||
sa.Column('state', sa.String(length=256), nullable=True),
|
|
||||||
sa.Column('nonce', sa.String(length=256), nullable=True),
|
|
||||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('used', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
)
|
||||||
op.create_index('ix_oauth_authorization_codes_client_user', 'oauth_authorization_codes', ['client_id', 'user_id'], unique=False)
|
op.create_table(
|
||||||
op.create_index(op.f('ix_oauth_authorization_codes_code'), 'oauth_authorization_codes', ['code'], unique=True)
|
"users",
|
||||||
op.create_index('ix_oauth_authorization_codes_expires_at', 'oauth_authorization_codes', ['expires_at'], unique=False)
|
sa.Column("email", sa.String(length=255), nullable=False),
|
||||||
op.create_table('oauth_consents',
|
sa.Column("password_hash", sa.String(length=255), nullable=True),
|
||||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
sa.Column("first_name", sa.String(length=100), nullable=False),
|
||||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
sa.Column("last_name", sa.String(length=100), nullable=True),
|
||||||
sa.Column('granted_scopes', sa.String(length=1000), nullable=False),
|
sa.Column("phone_number", sa.String(length=20), nullable=True),
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column(
|
||||||
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
|
"preferences", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
),
|
||||||
sa.PrimaryKeyConstraint('id')
|
sa.Column("locale", sa.String(length=10), nullable=True),
|
||||||
|
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
)
|
)
|
||||||
op.create_index('ix_oauth_consents_user_client', 'oauth_consents', ['user_id', 'client_id'], unique=True)
|
op.create_index(op.f("ix_users_deleted_at"), "users", ["deleted_at"], unique=False)
|
||||||
op.create_table('oauth_provider_refresh_tokens',
|
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||||
sa.Column('token_hash', sa.String(length=64), nullable=False),
|
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
|
||||||
sa.Column('jti', sa.String(length=64), nullable=False),
|
op.create_index(
|
||||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
|
||||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
)
|
||||||
sa.Column('scope', sa.String(length=1000), nullable=False),
|
op.create_index(op.f("ix_users_locale"), "users", ["locale"], unique=False)
|
||||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
op.create_table(
|
||||||
sa.Column('revoked', sa.Boolean(), nullable=False),
|
"oauth_accounts",
|
||||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True),
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
sa.Column('device_info', sa.String(length=500), nullable=True),
|
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
sa.Column("provider_user_id", sa.String(length=255), nullable=False),
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
sa.Column("provider_email", sa.String(length=255), nullable=True),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
|
||||||
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
|
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
sa.PrimaryKeyConstraint('id')
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"provider", "provider_user_id", name="uq_oauth_provider_user"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_accounts_provider"), "oauth_accounts", ["provider"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_accounts_provider_email"),
|
||||||
|
"oauth_accounts",
|
||||||
|
["provider_email"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_accounts_user_id"), "oauth_accounts", ["user_id"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_accounts_user_provider",
|
||||||
|
"oauth_accounts",
|
||||||
|
["user_id", "provider"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_clients",
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("client_name", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("client_description", sa.String(length=1000), nullable=True),
|
||||||
|
sa.Column("client_type", sa.String(length=20), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"redirect_uris", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"allowed_scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||||
|
),
|
||||||
|
sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
|
||||||
|
sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("owner_user_id", sa.UUID(), nullable=True),
|
||||||
|
sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["owner_user_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=True
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_clients_is_active"), "oauth_clients", ["is_active"], unique=False
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"user_organizations",
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("organization_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"role",
|
||||||
|
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("user_id", "organization_id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_org_org_active",
|
||||||
|
"user_organizations",
|
||||||
|
["organization_id", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index("ix_user_org_role", "user_organizations", ["role"], unique=False)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_org_user_active",
|
||||||
|
"user_organizations",
|
||||||
|
["user_id", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_organizations_is_active"),
|
||||||
|
"user_organizations",
|
||||||
|
["is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"user_sessions",
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("device_name", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("device_id", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||||
|
sa.Column("user_agent", sa.String(length=500), nullable=True),
|
||||||
|
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("location_city", sa.String(length=100), nullable=True),
|
||||||
|
sa.Column("location_country", sa.String(length=100), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_sessions_is_active"), "user_sessions", ["is_active"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_sessions_jti_active",
|
||||||
|
"user_sessions",
|
||||||
|
["refresh_token_jti", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_sessions_refresh_token_jti"),
|
||||||
|
"user_sessions",
|
||||||
|
["refresh_token_jti"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_sessions_user_active",
|
||||||
|
"user_sessions",
|
||||||
|
["user_id", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_sessions_user_id"), "user_sessions", ["user_id"], unique=False
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
sa.Column("code", sa.String(length=128), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("redirect_uri", sa.String(length=2048), nullable=False),
|
||||||
|
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||||
|
sa.Column("code_challenge", sa.String(length=128), nullable=True),
|
||||||
|
sa.Column("code_challenge_method", sa.String(length=10), nullable=True),
|
||||||
|
sa.Column("state", sa.String(length=256), nullable=True),
|
||||||
|
sa.Column("nonce", sa.String(length=256), nullable=True),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("used", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_authorization_codes_client_user",
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
["client_id", "user_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_authorization_codes_code"),
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
["code"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_authorization_codes_expires_at",
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
["expires_at"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_consents",
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("granted_scopes", sa.String(length=1000), nullable=False),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_consents_user_client",
|
||||||
|
"oauth_consents",
|
||||||
|
["user_id", "client_id"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
sa.Column("token_hash", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("jti", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("revoked", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("device_info", sa.String(length=500), nullable=True),
|
||||||
|
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_client_user",
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["client_id", "user_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["expires_at"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["jti"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["revoked"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["token_hash"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["user_id", "revoked"],
|
||||||
|
unique=False,
|
||||||
)
|
)
|
||||||
op.create_index('ix_oauth_provider_refresh_tokens_client_user', 'oauth_provider_refresh_tokens', ['client_id', 'user_id'], unique=False)
|
|
||||||
op.create_index('ix_oauth_provider_refresh_tokens_expires_at', 'oauth_provider_refresh_tokens', ['expires_at'], unique=False)
|
|
||||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_jti'), 'oauth_provider_refresh_tokens', ['jti'], unique=True)
|
|
||||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), 'oauth_provider_refresh_tokens', ['revoked'], unique=False)
|
|
||||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), 'oauth_provider_refresh_tokens', ['token_hash'], unique=True)
|
|
||||||
op.create_index('ix_oauth_provider_refresh_tokens_user_revoked', 'oauth_provider_refresh_tokens', ['user_id', 'revoked'], unique=False)
|
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_index('ix_oauth_provider_refresh_tokens_user_revoked', table_name='oauth_provider_refresh_tokens')
|
op.drop_index(
|
||||||
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), table_name='oauth_provider_refresh_tokens')
|
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||||
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), table_name='oauth_provider_refresh_tokens')
|
table_name="oauth_provider_refresh_tokens",
|
||||||
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_jti'), table_name='oauth_provider_refresh_tokens')
|
)
|
||||||
op.drop_index('ix_oauth_provider_refresh_tokens_expires_at', table_name='oauth_provider_refresh_tokens')
|
op.drop_index(
|
||||||
op.drop_index('ix_oauth_provider_refresh_tokens_client_user', table_name='oauth_provider_refresh_tokens')
|
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||||
op.drop_table('oauth_provider_refresh_tokens')
|
table_name="oauth_provider_refresh_tokens",
|
||||||
op.drop_index('ix_oauth_consents_user_client', table_name='oauth_consents')
|
)
|
||||||
op.drop_table('oauth_consents')
|
op.drop_index(
|
||||||
op.drop_index('ix_oauth_authorization_codes_expires_at', table_name='oauth_authorization_codes')
|
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||||
op.drop_index(op.f('ix_oauth_authorization_codes_code'), table_name='oauth_authorization_codes')
|
table_name="oauth_provider_refresh_tokens",
|
||||||
op.drop_index('ix_oauth_authorization_codes_client_user', table_name='oauth_authorization_codes')
|
)
|
||||||
op.drop_table('oauth_authorization_codes')
|
op.drop_index(
|
||||||
op.drop_index(op.f('ix_user_sessions_user_id'), table_name='user_sessions')
|
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||||
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions')
|
table_name="oauth_provider_refresh_tokens",
|
||||||
op.drop_index(op.f('ix_user_sessions_refresh_token_jti'), table_name='user_sessions')
|
)
|
||||||
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions')
|
op.drop_index(
|
||||||
op.drop_index(op.f('ix_user_sessions_is_active'), table_name='user_sessions')
|
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||||
op.drop_table('user_sessions')
|
table_name="oauth_provider_refresh_tokens",
|
||||||
op.drop_index(op.f('ix_user_organizations_is_active'), table_name='user_organizations')
|
)
|
||||||
op.drop_index('ix_user_org_user_active', table_name='user_organizations')
|
op.drop_index(
|
||||||
op.drop_index('ix_user_org_role', table_name='user_organizations')
|
"ix_oauth_provider_refresh_tokens_client_user",
|
||||||
op.drop_index('ix_user_org_org_active', table_name='user_organizations')
|
table_name="oauth_provider_refresh_tokens",
|
||||||
op.drop_table('user_organizations')
|
)
|
||||||
op.drop_index(op.f('ix_oauth_clients_is_active'), table_name='oauth_clients')
|
op.drop_table("oauth_provider_refresh_tokens")
|
||||||
op.drop_index(op.f('ix_oauth_clients_client_id'), table_name='oauth_clients')
|
op.drop_index("ix_oauth_consents_user_client", table_name="oauth_consents")
|
||||||
op.drop_table('oauth_clients')
|
op.drop_table("oauth_consents")
|
||||||
op.drop_index('ix_oauth_accounts_user_provider', table_name='oauth_accounts')
|
op.drop_index(
|
||||||
op.drop_index(op.f('ix_oauth_accounts_user_id'), table_name='oauth_accounts')
|
"ix_oauth_authorization_codes_expires_at",
|
||||||
op.drop_index(op.f('ix_oauth_accounts_provider_email'), table_name='oauth_accounts')
|
table_name="oauth_authorization_codes",
|
||||||
op.drop_index(op.f('ix_oauth_accounts_provider'), table_name='oauth_accounts')
|
)
|
||||||
op.drop_table('oauth_accounts')
|
op.drop_index(
|
||||||
op.drop_index(op.f('ix_users_locale'), table_name='users')
|
op.f("ix_oauth_authorization_codes_code"),
|
||||||
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
|
table_name="oauth_authorization_codes",
|
||||||
op.drop_index(op.f('ix_users_is_active'), table_name='users')
|
)
|
||||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
op.drop_index(
|
||||||
op.drop_index(op.f('ix_users_deleted_at'), table_name='users')
|
"ix_oauth_authorization_codes_client_user",
|
||||||
op.drop_table('users')
|
table_name="oauth_authorization_codes",
|
||||||
op.drop_index('ix_organizations_slug_active', table_name='organizations')
|
)
|
||||||
op.drop_index(op.f('ix_organizations_slug'), table_name='organizations')
|
op.drop_table("oauth_authorization_codes")
|
||||||
op.drop_index('ix_organizations_name_active', table_name='organizations')
|
op.drop_index(op.f("ix_user_sessions_user_id"), table_name="user_sessions")
|
||||||
op.drop_index(op.f('ix_organizations_name'), table_name='organizations')
|
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
|
||||||
op.drop_index(op.f('ix_organizations_is_active'), table_name='organizations')
|
op.drop_index(
|
||||||
op.drop_table('organizations')
|
op.f("ix_user_sessions_refresh_token_jti"), table_name="user_sessions"
|
||||||
op.drop_index(op.f('ix_oauth_states_state'), table_name='oauth_states')
|
)
|
||||||
op.drop_table('oauth_states')
|
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
|
||||||
|
op.drop_index(op.f("ix_user_sessions_is_active"), table_name="user_sessions")
|
||||||
|
op.drop_table("user_sessions")
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_user_organizations_is_active"), table_name="user_organizations"
|
||||||
|
)
|
||||||
|
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
|
||||||
|
op.drop_index("ix_user_org_role", table_name="user_organizations")
|
||||||
|
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
|
||||||
|
op.drop_table("user_organizations")
|
||||||
|
op.drop_index(op.f("ix_oauth_clients_is_active"), table_name="oauth_clients")
|
||||||
|
op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients")
|
||||||
|
op.drop_table("oauth_clients")
|
||||||
|
op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
|
||||||
|
op.drop_index(op.f("ix_oauth_accounts_user_id"), table_name="oauth_accounts")
|
||||||
|
op.drop_index(op.f("ix_oauth_accounts_provider_email"), table_name="oauth_accounts")
|
||||||
|
op.drop_index(op.f("ix_oauth_accounts_provider"), table_name="oauth_accounts")
|
||||||
|
op.drop_table("oauth_accounts")
|
||||||
|
op.drop_index(op.f("ix_users_locale"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_is_active"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_deleted_at"), table_name="users")
|
||||||
|
op.drop_table("users")
|
||||||
|
op.drop_index("ix_organizations_slug_active", table_name="organizations")
|
||||||
|
op.drop_index(op.f("ix_organizations_slug"), table_name="organizations")
|
||||||
|
op.drop_index("ix_organizations_name_active", table_name="organizations")
|
||||||
|
op.drop_index(op.f("ix_organizations_name"), table_name="organizations")
|
||||||
|
op.drop_index(op.f("ix_organizations_is_active"), table_name="organizations")
|
||||||
|
op.drop_table("organizations")
|
||||||
|
op.drop_index(op.f("ix_oauth_states_state"), table_name="oauth_states")
|
||||||
|
op.drop_table("oauth_states")
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|||||||
@@ -114,8 +114,13 @@ def upgrade() -> None:
|
|||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# Drop indexes in reverse order
|
# Drop indexes in reverse order
|
||||||
op.drop_index("ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes")
|
op.drop_index(
|
||||||
op.drop_index("ix_perf_oauth_refresh_tokens_expires", table_name="oauth_provider_refresh_tokens")
|
"ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes"
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_perf_oauth_refresh_tokens_expires",
|
||||||
|
table_name="oauth_provider_refresh_tokens",
|
||||||
|
)
|
||||||
op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
|
op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
|
||||||
op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
|
op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
|
||||||
op.drop_index("ix_perf_users_active", table_name="users")
|
op.drop_index("ix_perf_users_active", table_name="users")
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
"""Enable pgvector extension
|
||||||
|
|
||||||
|
Revision ID: 0003
|
||||||
|
Revises: 0002
|
||||||
|
Create Date: 2025-12-30
|
||||||
|
|
||||||
|
This migration enables the pgvector extension for PostgreSQL, which provides
|
||||||
|
vector similarity search capabilities required for the RAG (Retrieval-Augmented
|
||||||
|
Generation) knowledge base system.
|
||||||
|
|
||||||
|
Vector Dimension Reference (per ADR-008 and SPIKE-006):
|
||||||
|
---------------------------------------------------------
|
||||||
|
The dimension size depends on the embedding model used:
|
||||||
|
|
||||||
|
| Model | Dimensions | Use Case |
|
||||||
|
|----------------------------|------------|------------------------------|
|
||||||
|
| text-embedding-3-small | 1536 | General docs, conversations |
|
||||||
|
| text-embedding-3-large | 256-3072 | High accuracy (configurable) |
|
||||||
|
| voyage-code-3 | 1024 | Code files (Python, JS, etc) |
|
||||||
|
| voyage-3-large | 1024 | High quality general purpose |
|
||||||
|
| nomic-embed-text (Ollama) | 768 | Local/fallback embedding |
|
||||||
|
|
||||||
|
Recommended defaults for Syndarix:
|
||||||
|
- Documentation/conversations: 1536 (text-embedding-3-small)
|
||||||
|
- Code files: 1024 (voyage-code-3)
|
||||||
|
|
||||||
|
Prerequisites:
|
||||||
|
--------------
|
||||||
|
This migration requires PostgreSQL with the pgvector extension installed.
|
||||||
|
The Docker Compose configuration uses `pgvector/pgvector:pg17` which includes
|
||||||
|
the extension pre-installed.
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- ADR-008: Knowledge Base and RAG Architecture
|
||||||
|
- SPIKE-006: Knowledge Base with pgvector for RAG System
|
||||||
|
- https://github.com/pgvector/pgvector
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0003"
|
||||||
|
down_revision: str | None = "0002"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Enable the pgvector extension.
|
||||||
|
|
||||||
|
The CREATE EXTENSION IF NOT EXISTS statement is idempotent - it will
|
||||||
|
succeed whether the extension already exists or not.
|
||||||
|
"""
|
||||||
|
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Drop the pgvector extension.
|
||||||
|
|
||||||
|
Note: This will fail if any tables with vector columns exist.
|
||||||
|
Future migrations that create vector columns should be downgraded first.
|
||||||
|
"""
|
||||||
|
op.execute("DROP EXTENSION IF EXISTS vector")
|
||||||
507
backend/app/alembic/versions/0004_add_syndarix_models.py
Normal file
507
backend/app/alembic/versions/0004_add_syndarix_models.py
Normal file
@@ -0,0 +1,507 @@
|
|||||||
|
"""Add Syndarix models
|
||||||
|
|
||||||
|
Revision ID: 0004
|
||||||
|
Revises: 0003
|
||||||
|
Create Date: 2025-12-31
|
||||||
|
|
||||||
|
This migration creates the core Syndarix domain tables:
|
||||||
|
- projects: Client engagement projects
|
||||||
|
- agent_types: Agent template configurations
|
||||||
|
- agent_instances: Spawned agent instances assigned to projects
|
||||||
|
- sprints: Sprint containers for issues
|
||||||
|
- issues: Work items (epics, stories, tasks, bugs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0004"
|
||||||
|
down_revision: str | None = "0003"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Create Syndarix domain tables."""
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Create projects table
|
||||||
|
# Note: ENUM types are created automatically by sa.Enum() during table creation
|
||||||
|
# =========================================================================
|
||||||
|
op.create_table(
|
||||||
|
"projects",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("slug", sa.String(255), nullable=False),
|
||||||
|
sa.Column("description", sa.Text(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"autonomy_level",
|
||||||
|
sa.Enum(
|
||||||
|
"full_control",
|
||||||
|
"milestone",
|
||||||
|
"autonomous",
|
||||||
|
name="autonomy_level",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
server_default="milestone",
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"status",
|
||||||
|
sa.Enum(
|
||||||
|
"active",
|
||||||
|
"paused",
|
||||||
|
"completed",
|
||||||
|
"archived",
|
||||||
|
name="project_status",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
server_default="active",
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"complexity",
|
||||||
|
sa.Enum(
|
||||||
|
"script",
|
||||||
|
"simple",
|
||||||
|
"medium",
|
||||||
|
"complex",
|
||||||
|
name="project_complexity",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
server_default="medium",
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"client_mode",
|
||||||
|
sa.Enum("technical", "auto", name="client_mode"),
|
||||||
|
nullable=False,
|
||||||
|
server_default="auto",
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"settings",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="{}",
|
||||||
|
),
|
||||||
|
sa.Column("owner_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["owner_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
sa.UniqueConstraint("slug"),
|
||||||
|
)
|
||||||
|
# Single column indexes
|
||||||
|
op.create_index("ix_projects_name", "projects", ["name"])
|
||||||
|
op.create_index("ix_projects_slug", "projects", ["slug"])
|
||||||
|
op.create_index("ix_projects_status", "projects", ["status"])
|
||||||
|
op.create_index("ix_projects_autonomy_level", "projects", ["autonomy_level"])
|
||||||
|
op.create_index("ix_projects_complexity", "projects", ["complexity"])
|
||||||
|
op.create_index("ix_projects_client_mode", "projects", ["client_mode"])
|
||||||
|
op.create_index("ix_projects_owner_id", "projects", ["owner_id"])
|
||||||
|
# Composite indexes
|
||||||
|
op.create_index("ix_projects_slug_status", "projects", ["slug", "status"])
|
||||||
|
op.create_index("ix_projects_owner_status", "projects", ["owner_id", "status"])
|
||||||
|
op.create_index(
|
||||||
|
"ix_projects_autonomy_status", "projects", ["autonomy_level", "status"]
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_projects_complexity_status", "projects", ["complexity", "status"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Create agent_types table
|
||||||
|
# =========================================================================
|
||||||
|
op.create_table(
|
||||||
|
"agent_types",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("slug", sa.String(255), nullable=False),
|
||||||
|
sa.Column("description", sa.Text(), nullable=True),
|
||||||
|
# Areas of expertise (e.g., ["python", "fastapi", "databases"])
|
||||||
|
sa.Column(
|
||||||
|
"expertise",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="[]",
|
||||||
|
),
|
||||||
|
# System prompt defining personality and behavior (required)
|
||||||
|
sa.Column("personality_prompt", sa.Text(), nullable=False),
|
||||||
|
# LLM model configuration
|
||||||
|
sa.Column("primary_model", sa.String(100), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"fallback_models",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="[]",
|
||||||
|
),
|
||||||
|
# Model parameters (temperature, max_tokens, etc.)
|
||||||
|
sa.Column(
|
||||||
|
"model_params",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="{}",
|
||||||
|
),
|
||||||
|
# MCP servers this agent can connect to
|
||||||
|
sa.Column(
|
||||||
|
"mcp_servers",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="[]",
|
||||||
|
),
|
||||||
|
# Tool permissions configuration
|
||||||
|
sa.Column(
|
||||||
|
"tool_permissions",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="{}",
|
||||||
|
),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("slug"),
|
||||||
|
)
|
||||||
|
# Single column indexes
|
||||||
|
op.create_index("ix_agent_types_name", "agent_types", ["name"])
|
||||||
|
op.create_index("ix_agent_types_slug", "agent_types", ["slug"])
|
||||||
|
op.create_index("ix_agent_types_is_active", "agent_types", ["is_active"])
|
||||||
|
# Composite indexes
|
||||||
|
op.create_index("ix_agent_types_slug_active", "agent_types", ["slug", "is_active"])
|
||||||
|
op.create_index("ix_agent_types_name_active", "agent_types", ["name", "is_active"])
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Create agent_instances table
|
||||||
|
# =========================================================================
|
||||||
|
op.create_table(
|
||||||
|
"agent_instances",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("name", sa.String(100), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"status",
|
||||||
|
sa.Enum(
|
||||||
|
"idle",
|
||||||
|
"working",
|
||||||
|
"waiting",
|
||||||
|
"paused",
|
||||||
|
"terminated",
|
||||||
|
name="agent_status",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
server_default="idle",
|
||||||
|
),
|
||||||
|
sa.Column("current_task", sa.Text(), nullable=True),
|
||||||
|
# Short-term memory (conversation context, recent decisions)
|
||||||
|
sa.Column(
|
||||||
|
"short_term_memory",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="{}",
|
||||||
|
),
|
||||||
|
# Reference to long-term memory in vector store
|
||||||
|
sa.Column("long_term_memory_ref", sa.String(500), nullable=True),
|
||||||
|
# Session ID for active MCP connections
|
||||||
|
sa.Column("session_id", sa.String(255), nullable=True),
|
||||||
|
# Activity tracking
|
||||||
|
sa.Column("last_activity_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("terminated_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
# Usage metrics
|
||||||
|
sa.Column("tasks_completed", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
sa.Column("tokens_used", sa.BigInteger(), nullable=False, server_default="0"),
|
||||||
|
sa.Column(
|
||||||
|
"cost_incurred",
|
||||||
|
sa.Numeric(precision=10, scale=4),
|
||||||
|
nullable=False,
|
||||||
|
server_default="0",
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["agent_type_id"], ["agent_types.id"], ondelete="RESTRICT"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
# Single column indexes
|
||||||
|
op.create_index("ix_agent_instances_name", "agent_instances", ["name"])
|
||||||
|
op.create_index("ix_agent_instances_status", "agent_instances", ["status"])
|
||||||
|
op.create_index(
|
||||||
|
"ix_agent_instances_agent_type_id", "agent_instances", ["agent_type_id"]
|
||||||
|
)
|
||||||
|
op.create_index("ix_agent_instances_project_id", "agent_instances", ["project_id"])
|
||||||
|
op.create_index("ix_agent_instances_session_id", "agent_instances", ["session_id"])
|
||||||
|
op.create_index(
|
||||||
|
"ix_agent_instances_last_activity_at", "agent_instances", ["last_activity_at"]
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_agent_instances_terminated_at", "agent_instances", ["terminated_at"]
|
||||||
|
)
|
||||||
|
# Composite indexes
|
||||||
|
op.create_index(
|
||||||
|
"ix_agent_instances_project_status",
|
||||||
|
"agent_instances",
|
||||||
|
["project_id", "status"],
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_agent_instances_type_status",
|
||||||
|
"agent_instances",
|
||||||
|
["agent_type_id", "status"],
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_agent_instances_project_type",
|
||||||
|
"agent_instances",
|
||||||
|
["project_id", "agent_type_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Create sprints table (before issues for FK reference)
|
||||||
|
# =========================================================================
|
||||||
|
op.create_table(
|
||||||
|
"sprints",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("number", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("goal", sa.Text(), nullable=True),
|
||||||
|
sa.Column("start_date", sa.Date(), nullable=False),
|
||||||
|
sa.Column("end_date", sa.Date(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"status",
|
||||||
|
sa.Enum(
|
||||||
|
"planned",
|
||||||
|
"active",
|
||||||
|
"in_review",
|
||||||
|
"completed",
|
||||||
|
"cancelled",
|
||||||
|
name="sprint_status",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
server_default="planned",
|
||||||
|
),
|
||||||
|
sa.Column("planned_points", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("velocity", sa.Integer(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("project_id", "number", name="uq_sprint_project_number"),
|
||||||
|
)
|
||||||
|
# Single column indexes
|
||||||
|
op.create_index("ix_sprints_project_id", "sprints", ["project_id"])
|
||||||
|
op.create_index("ix_sprints_status", "sprints", ["status"])
|
||||||
|
op.create_index("ix_sprints_start_date", "sprints", ["start_date"])
|
||||||
|
op.create_index("ix_sprints_end_date", "sprints", ["end_date"])
|
||||||
|
# Composite indexes
|
||||||
|
op.create_index("ix_sprints_project_status", "sprints", ["project_id", "status"])
|
||||||
|
op.create_index("ix_sprints_project_number", "sprints", ["project_id", "number"])
|
||||||
|
op.create_index("ix_sprints_date_range", "sprints", ["start_date", "end_date"])
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Create issues table
|
||||||
|
# =========================================================================
|
||||||
|
op.create_table(
|
||||||
|
"issues",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
# Parent issue for hierarchy (Epic -> Story -> Task)
|
||||||
|
sa.Column("parent_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||||
|
# Issue type (epic, story, task, bug)
|
||||||
|
sa.Column(
|
||||||
|
"type",
|
||||||
|
sa.Enum(
|
||||||
|
"epic",
|
||||||
|
"story",
|
||||||
|
"task",
|
||||||
|
"bug",
|
||||||
|
name="issue_type",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
server_default="task",
|
||||||
|
),
|
||||||
|
# Reporter (who created this issue)
|
||||||
|
sa.Column("reporter_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||||
|
# Issue content
|
||||||
|
sa.Column("title", sa.String(500), nullable=False),
|
||||||
|
sa.Column("body", sa.Text(), nullable=False, server_default=""),
|
||||||
|
# Status and priority
|
||||||
|
sa.Column(
|
||||||
|
"status",
|
||||||
|
sa.Enum(
|
||||||
|
"open",
|
||||||
|
"in_progress",
|
||||||
|
"in_review",
|
||||||
|
"blocked",
|
||||||
|
"closed",
|
||||||
|
name="issue_status",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
server_default="open",
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"priority",
|
||||||
|
sa.Enum(
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
|
"high",
|
||||||
|
"critical",
|
||||||
|
name="issue_priority",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
server_default="medium",
|
||||||
|
),
|
||||||
|
# Labels for categorization
|
||||||
|
sa.Column(
|
||||||
|
"labels",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="[]",
|
||||||
|
),
|
||||||
|
# Assignment - agent or human (mutually exclusive)
|
||||||
|
sa.Column("assigned_agent_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||||
|
sa.Column("human_assignee", sa.String(255), nullable=True),
|
||||||
|
# Sprint association
|
||||||
|
sa.Column("sprint_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||||
|
# Estimation
|
||||||
|
sa.Column("story_points", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("due_date", sa.Date(), nullable=True),
|
||||||
|
# External tracker integration (String for flexibility)
|
||||||
|
sa.Column("external_tracker_type", sa.String(50), nullable=True),
|
||||||
|
sa.Column("external_issue_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("remote_url", sa.String(1000), nullable=True),
|
||||||
|
sa.Column("external_issue_number", sa.Integer(), nullable=True),
|
||||||
|
# Sync status
|
||||||
|
sa.Column(
|
||||||
|
"sync_status",
|
||||||
|
sa.Enum(
|
||||||
|
"synced",
|
||||||
|
"pending",
|
||||||
|
"conflict",
|
||||||
|
"error",
|
||||||
|
name="sync_status",
|
||||||
|
),
|
||||||
|
nullable=False,
|
||||||
|
server_default="synced",
|
||||||
|
),
|
||||||
|
sa.Column("last_synced_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("external_updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
# Lifecycle
|
||||||
|
sa.Column("closed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["parent_id"], ["issues.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["sprint_id"], ["sprints.id"], ondelete="SET NULL"),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["assigned_agent_id"], ["agent_instances.id"], ondelete="SET NULL"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# Single column indexes
|
||||||
|
op.create_index("ix_issues_project_id", "issues", ["project_id"])
|
||||||
|
op.create_index("ix_issues_parent_id", "issues", ["parent_id"])
|
||||||
|
op.create_index("ix_issues_type", "issues", ["type"])
|
||||||
|
op.create_index("ix_issues_reporter_id", "issues", ["reporter_id"])
|
||||||
|
op.create_index("ix_issues_status", "issues", ["status"])
|
||||||
|
op.create_index("ix_issues_priority", "issues", ["priority"])
|
||||||
|
op.create_index("ix_issues_assigned_agent_id", "issues", ["assigned_agent_id"])
|
||||||
|
op.create_index("ix_issues_human_assignee", "issues", ["human_assignee"])
|
||||||
|
op.create_index("ix_issues_sprint_id", "issues", ["sprint_id"])
|
||||||
|
op.create_index("ix_issues_due_date", "issues", ["due_date"])
|
||||||
|
op.create_index(
|
||||||
|
"ix_issues_external_tracker_type", "issues", ["external_tracker_type"]
|
||||||
|
)
|
||||||
|
op.create_index("ix_issues_sync_status", "issues", ["sync_status"])
|
||||||
|
op.create_index("ix_issues_closed_at", "issues", ["closed_at"])
|
||||||
|
# Composite indexes
|
||||||
|
op.create_index("ix_issues_project_status", "issues", ["project_id", "status"])
|
||||||
|
op.create_index("ix_issues_project_priority", "issues", ["project_id", "priority"])
|
||||||
|
op.create_index("ix_issues_project_sprint", "issues", ["project_id", "sprint_id"])
|
||||||
|
op.create_index("ix_issues_project_type", "issues", ["project_id", "type"])
|
||||||
|
op.create_index(
|
||||||
|
"ix_issues_project_agent", "issues", ["project_id", "assigned_agent_id"]
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_issues_project_status_priority",
|
||||||
|
"issues",
|
||||||
|
["project_id", "status", "priority"],
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_issues_external_tracker_id",
|
||||||
|
"issues",
|
||||||
|
["external_tracker_type", "external_issue_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Drop Syndarix domain tables."""
|
||||||
|
# Drop tables in reverse order (respecting FK constraints)
|
||||||
|
op.drop_table("issues")
|
||||||
|
op.drop_table("sprints")
|
||||||
|
op.drop_table("agent_instances")
|
||||||
|
op.drop_table("agent_types")
|
||||||
|
op.drop_table("projects")
|
||||||
|
|
||||||
|
# Drop ENUM types
|
||||||
|
op.execute("DROP TYPE IF EXISTS sprint_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS sync_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS issue_priority")
|
||||||
|
op.execute("DROP TYPE IF EXISTS issue_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS issue_type")
|
||||||
|
op.execute("DROP TYPE IF EXISTS agent_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS client_mode")
|
||||||
|
op.execute("DROP TYPE IF EXISTS project_complexity")
|
||||||
|
op.execute("DROP TYPE IF EXISTS project_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS autonomy_level")
|
||||||
@@ -151,3 +151,83 @@ async def get_optional_current_user(
|
|||||||
return user
|
return user
|
||||||
except (TokenExpiredError, TokenInvalidError):
|
except (TokenExpiredError, TokenInvalidError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user_sse(
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
authorization: str | None = Header(None),
|
||||||
|
token: str | None = None, # Query parameter - passed directly from route
|
||||||
|
) -> User:
|
||||||
|
"""
|
||||||
|
Get the current authenticated user for SSE endpoints.
|
||||||
|
|
||||||
|
SSE (Server-Sent Events) via EventSource API doesn't support custom headers,
|
||||||
|
so this dependency accepts tokens from either:
|
||||||
|
1. Authorization header (preferred, for non-EventSource clients)
|
||||||
|
2. Query parameter 'token' (fallback for EventSource compatibility)
|
||||||
|
|
||||||
|
Security note: Query parameter tokens appear in server logs and browser history.
|
||||||
|
Consider implementing short-lived SSE-specific tokens for production if this
|
||||||
|
is a concern. The current approach is acceptable for internal/trusted networks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
authorization: Authorization header (Bearer token)
|
||||||
|
token: Query parameter token (fallback for EventSource)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User: The authenticated user
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If authentication fails
|
||||||
|
"""
|
||||||
|
# Try Authorization header first (preferred)
|
||||||
|
auth_token = None
|
||||||
|
if authorization:
|
||||||
|
scheme, param = get_authorization_scheme_param(authorization)
|
||||||
|
if scheme.lower() == "bearer" and param:
|
||||||
|
auth_token = param
|
||||||
|
|
||||||
|
# Fall back to query parameter if no header token
|
||||||
|
if not auth_token and token:
|
||||||
|
auth_token = token
|
||||||
|
|
||||||
|
if not auth_token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Not authenticated",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Decode token and get user ID
|
||||||
|
token_data = get_token_data(auth_token)
|
||||||
|
|
||||||
|
# Get user from database
|
||||||
|
result = await db.execute(select(User).where(User.id == token_data.user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||||
|
)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
except TokenExpiredError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Token expired",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
except TokenInvalidError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|||||||
36
backend/app/api/dependencies/event_bus.py
Normal file
36
backend/app/api/dependencies/event_bus.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""
|
||||||
|
Event bus dependency for FastAPI routes.
|
||||||
|
|
||||||
|
This module provides the FastAPI dependency for injecting the EventBus
|
||||||
|
into route handlers. The event bus is a singleton that maintains
|
||||||
|
Redis pub/sub connections for real-time event streaming.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.services.event_bus import (
|
||||||
|
EventBus,
|
||||||
|
get_connected_event_bus as _get_connected_event_bus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_event_bus() -> EventBus:
|
||||||
|
"""
|
||||||
|
FastAPI dependency that provides a connected EventBus instance.
|
||||||
|
|
||||||
|
The EventBus is a singleton that maintains Redis pub/sub connections.
|
||||||
|
It's lazily initialized and connected on first access, and should be
|
||||||
|
closed during application shutdown via close_event_bus().
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@router.get("/events/stream")
|
||||||
|
async def stream_events(
|
||||||
|
event_bus: EventBus = Depends(get_event_bus)
|
||||||
|
):
|
||||||
|
...
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EventBus: The global connected event bus instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
EventBusConnectionError: If connection to Redis fails
|
||||||
|
"""
|
||||||
|
return await _get_connected_event_bus()
|
||||||
@@ -2,11 +2,18 @@ from fastapi import APIRouter
|
|||||||
|
|
||||||
from app.api.routes import (
|
from app.api.routes import (
|
||||||
admin,
|
admin,
|
||||||
|
agent_types,
|
||||||
|
agents,
|
||||||
auth,
|
auth,
|
||||||
|
events,
|
||||||
|
issues,
|
||||||
|
mcp,
|
||||||
oauth,
|
oauth,
|
||||||
oauth_provider,
|
oauth_provider,
|
||||||
organizations,
|
organizations,
|
||||||
|
projects,
|
||||||
sessions,
|
sessions,
|
||||||
|
sprints,
|
||||||
users,
|
users,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,3 +29,22 @@ api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
|
|||||||
api_router.include_router(
|
api_router.include_router(
|
||||||
organizations.router, prefix="/organizations", tags=["Organizations"]
|
organizations.router, prefix="/organizations", tags=["Organizations"]
|
||||||
)
|
)
|
||||||
|
# SSE events router - no prefix, routes define full paths
|
||||||
|
api_router.include_router(events.router, tags=["Events"])
|
||||||
|
|
||||||
|
# MCP (Model Context Protocol) router
|
||||||
|
api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"])
|
||||||
|
|
||||||
|
# Syndarix domain routers
|
||||||
|
api_router.include_router(projects.router, prefix="/projects", tags=["Projects"])
|
||||||
|
api_router.include_router(
|
||||||
|
agent_types.router, prefix="/agent-types", tags=["Agent Types"]
|
||||||
|
)
|
||||||
|
# Issues router - routes include /projects/{project_id}/issues paths
|
||||||
|
api_router.include_router(issues.router, tags=["Issues"])
|
||||||
|
# Agents router - routes include /projects/{project_id}/agents paths
|
||||||
|
api_router.include_router(agents.router, tags=["Agents"])
|
||||||
|
# Sprints router - routes need prefix as they use /projects/{project_id}/sprints paths
|
||||||
|
api_router.include_router(
|
||||||
|
sprints.router, prefix="/projects/{project_id}/sprints", tags=["Sprints"]
|
||||||
|
)
|
||||||
|
|||||||
462
backend/app/api/routes/agent_types.py
Normal file
462
backend/app/api/routes/agent_types.py
Normal file
@@ -0,0 +1,462 @@
|
|||||||
|
# app/api/routes/agent_types.py
|
||||||
|
"""
|
||||||
|
AgentType configuration API endpoints.
|
||||||
|
|
||||||
|
Provides CRUD operations for managing AI agent type templates.
|
||||||
|
Agent types define the base configuration (model, personality, expertise)
|
||||||
|
from which agent instances are spawned for projects.
|
||||||
|
|
||||||
|
Authorization:
|
||||||
|
- Read endpoints: Any authenticated user
|
||||||
|
- Write endpoints (create, update, delete): Superusers only
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query, Request, status
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
from app.api.dependencies.permissions import require_superuser
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.exceptions import (
|
||||||
|
DuplicateError,
|
||||||
|
ErrorCode,
|
||||||
|
NotFoundError,
|
||||||
|
)
|
||||||
|
from app.crud.syndarix.agent_type import agent_type as agent_type_crud
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.common import (
|
||||||
|
MessageResponse,
|
||||||
|
PaginatedResponse,
|
||||||
|
PaginationParams,
|
||||||
|
create_pagination_meta,
|
||||||
|
)
|
||||||
|
from app.schemas.syndarix import (
|
||||||
|
AgentTypeCreate,
|
||||||
|
AgentTypeResponse,
|
||||||
|
AgentTypeUpdate,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Initialize limiter for this router
|
||||||
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
|
|
||||||
|
# Use higher rate limits in test environment
|
||||||
|
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||||
|
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||||
|
|
||||||
|
|
||||||
|
def _build_agent_type_response(
|
||||||
|
agent_type: Any,
|
||||||
|
instance_count: int = 0,
|
||||||
|
) -> AgentTypeResponse:
|
||||||
|
"""
|
||||||
|
Build an AgentTypeResponse from a database model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_type: AgentType model instance
|
||||||
|
instance_count: Number of agent instances for this type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentTypeResponse schema
|
||||||
|
"""
|
||||||
|
return AgentTypeResponse(
|
||||||
|
id=agent_type.id,
|
||||||
|
name=agent_type.name,
|
||||||
|
slug=agent_type.slug,
|
||||||
|
description=agent_type.description,
|
||||||
|
expertise=agent_type.expertise,
|
||||||
|
personality_prompt=agent_type.personality_prompt,
|
||||||
|
primary_model=agent_type.primary_model,
|
||||||
|
fallback_models=agent_type.fallback_models,
|
||||||
|
model_params=agent_type.model_params,
|
||||||
|
mcp_servers=agent_type.mcp_servers,
|
||||||
|
tool_permissions=agent_type.tool_permissions,
|
||||||
|
is_active=agent_type.is_active,
|
||||||
|
created_at=agent_type.created_at,
|
||||||
|
updated_at=agent_type.updated_at,
|
||||||
|
instance_count=instance_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Write Endpoints (Admin Only) =====
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"",
|
||||||
|
response_model=AgentTypeResponse,
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
summary="Create Agent Type",
|
||||||
|
description="Create a new agent type configuration (admin only)",
|
||||||
|
operation_id="create_agent_type",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def create_agent_type(
|
||||||
|
request: Request,
|
||||||
|
agent_type_in: AgentTypeCreate,
|
||||||
|
admin: User = Depends(require_superuser),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Create a new agent type configuration.
|
||||||
|
|
||||||
|
Agent types define templates for AI agents including:
|
||||||
|
- Model configuration (primary model, fallback models, parameters)
|
||||||
|
- Personality and expertise areas
|
||||||
|
- MCP server integrations and tool permissions
|
||||||
|
|
||||||
|
Requires superuser privileges.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
agent_type_in: Agent type creation data
|
||||||
|
admin: Authenticated superuser
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created agent type configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DuplicateError: If slug already exists
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agent_type = await agent_type_crud.create(db, obj_in=agent_type_in)
|
||||||
|
logger.info(
|
||||||
|
f"Admin {admin.email} created agent type: {agent_type.name} "
|
||||||
|
f"(slug: {agent_type.slug})"
|
||||||
|
)
|
||||||
|
return _build_agent_type_response(agent_type, instance_count=0)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(f"Failed to create agent type: {e!s}")
|
||||||
|
raise DuplicateError(
|
||||||
|
message=str(e),
|
||||||
|
error_code=ErrorCode.ALREADY_EXISTS,
|
||||||
|
field="slug",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating agent type: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch(
|
||||||
|
"/{agent_type_id}",
|
||||||
|
response_model=AgentTypeResponse,
|
||||||
|
summary="Update Agent Type",
|
||||||
|
description="Update an existing agent type configuration (admin only)",
|
||||||
|
operation_id="update_agent_type",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def update_agent_type(
|
||||||
|
request: Request,
|
||||||
|
agent_type_id: UUID,
|
||||||
|
agent_type_in: AgentTypeUpdate,
|
||||||
|
admin: User = Depends(require_superuser),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Update an existing agent type configuration.
|
||||||
|
|
||||||
|
Partial updates are supported - only provided fields will be updated.
|
||||||
|
|
||||||
|
Requires superuser privileges.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
agent_type_id: UUID of the agent type to update
|
||||||
|
agent_type_in: Agent type update data
|
||||||
|
admin: Authenticated superuser
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated agent type configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If agent type not found
|
||||||
|
DuplicateError: If new slug already exists
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify agent type exists
|
||||||
|
result = await agent_type_crud.get_with_instance_count(
|
||||||
|
db, agent_type_id=agent_type_id
|
||||||
|
)
|
||||||
|
if not result:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent type {agent_type_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_type = result["agent_type"]
|
||||||
|
instance_count = result["instance_count"]
|
||||||
|
|
||||||
|
# Perform update
|
||||||
|
updated_type = await agent_type_crud.update(
|
||||||
|
db, db_obj=existing_type, obj_in=agent_type_in
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Admin {admin.email} updated agent type: {updated_type.name} "
|
||||||
|
f"(id: {agent_type_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return _build_agent_type_response(updated_type, instance_count=instance_count)
|
||||||
|
|
||||||
|
except NotFoundError:
|
||||||
|
raise
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(f"Failed to update agent type {agent_type_id}: {e!s}")
|
||||||
|
raise DuplicateError(
|
||||||
|
message=str(e),
|
||||||
|
error_code=ErrorCode.ALREADY_EXISTS,
|
||||||
|
field="slug",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating agent type {agent_type_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/{agent_type_id}",
|
||||||
|
response_model=MessageResponse,
|
||||||
|
summary="Deactivate Agent Type",
|
||||||
|
description="Deactivate an agent type (soft delete, admin only)",
|
||||||
|
operation_id="deactivate_agent_type",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def deactivate_agent_type(
|
||||||
|
request: Request,
|
||||||
|
agent_type_id: UUID,
|
||||||
|
admin: User = Depends(require_superuser),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Deactivate an agent type (soft delete).
|
||||||
|
|
||||||
|
This sets is_active=False rather than deleting the record,
|
||||||
|
preserving referential integrity with existing agent instances.
|
||||||
|
|
||||||
|
Requires superuser privileges.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
agent_type_id: UUID of the agent type to deactivate
|
||||||
|
admin: Authenticated superuser
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If agent type not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
deactivated = await agent_type_crud.deactivate(db, agent_type_id=agent_type_id)
|
||||||
|
|
||||||
|
if not deactivated:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent type {agent_type_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Admin {admin.email} deactivated agent type: {deactivated.name} "
|
||||||
|
f"(id: {agent_type_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return MessageResponse(
|
||||||
|
success=True,
|
||||||
|
message=f"Agent type '{deactivated.name}' has been deactivated",
|
||||||
|
)
|
||||||
|
|
||||||
|
except NotFoundError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error deactivating agent type {agent_type_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Read Endpoints (Authenticated Users) =====
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedResponse[AgentTypeResponse],
|
||||||
|
summary="List Agent Types",
|
||||||
|
description="Get paginated list of active agent types",
|
||||||
|
operation_id="list_agent_types",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def list_agent_types(
|
||||||
|
request: Request,
|
||||||
|
pagination: PaginationParams = Depends(),
|
||||||
|
is_active: bool = Query(True, description="Filter by active status"),
|
||||||
|
search: str | None = Query(None, description="Search by name, slug, description"),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
List all agent types with pagination and filtering.
|
||||||
|
|
||||||
|
By default, returns only active agent types. Set is_active=false
|
||||||
|
to include deactivated types (useful for admin views).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
pagination: Pagination parameters (page, limit)
|
||||||
|
is_active: Filter by active status (default: True)
|
||||||
|
search: Optional search term for name, slug, description
|
||||||
|
current_user: Authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Paginated list of agent types with instance counts
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get agent types with instance counts
|
||||||
|
results, total = await agent_type_crud.get_multi_with_instance_counts(
|
||||||
|
db,
|
||||||
|
skip=pagination.offset,
|
||||||
|
limit=pagination.limit,
|
||||||
|
is_active=is_active,
|
||||||
|
search=search,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build response objects
|
||||||
|
agent_types_response = [
|
||||||
|
_build_agent_type_response(
|
||||||
|
item["agent_type"],
|
||||||
|
instance_count=item["instance_count"],
|
||||||
|
)
|
||||||
|
for item in results
|
||||||
|
]
|
||||||
|
|
||||||
|
pagination_meta = create_pagination_meta(
|
||||||
|
total=total,
|
||||||
|
page=pagination.page,
|
||||||
|
limit=pagination.limit,
|
||||||
|
items_count=len(agent_types_response),
|
||||||
|
)
|
||||||
|
|
||||||
|
return PaginatedResponse(data=agent_types_response, pagination=pagination_meta)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error listing agent types: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/{agent_type_id}",
|
||||||
|
response_model=AgentTypeResponse,
|
||||||
|
summary="Get Agent Type",
|
||||||
|
description="Get agent type details by ID",
|
||||||
|
operation_id="get_agent_type",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{100 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def get_agent_type(
|
||||||
|
request: Request,
|
||||||
|
agent_type_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get detailed information about a specific agent type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
agent_type_id: UUID of the agent type
|
||||||
|
current_user: Authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent type details with instance count
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If agent type not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await agent_type_crud.get_with_instance_count(
|
||||||
|
db, agent_type_id=agent_type_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent type {agent_type_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _build_agent_type_response(
|
||||||
|
result["agent_type"],
|
||||||
|
instance_count=result["instance_count"],
|
||||||
|
)
|
||||||
|
|
||||||
|
except NotFoundError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting agent type {agent_type_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/slug/{slug}",
|
||||||
|
response_model=AgentTypeResponse,
|
||||||
|
summary="Get Agent Type by Slug",
|
||||||
|
description="Get agent type details by slug",
|
||||||
|
operation_id="get_agent_type_by_slug",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{100 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def get_agent_type_by_slug(
|
||||||
|
request: Request,
|
||||||
|
slug: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get detailed information about an agent type by its slug.
|
||||||
|
|
||||||
|
Slugs are human-readable identifiers like "product-owner" or "backend-engineer".
|
||||||
|
Useful for referencing agent types in configuration files or APIs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
slug: Slug identifier of the agent type
|
||||||
|
current_user: Authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent type details with instance count
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If agent type not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agent_type = await agent_type_crud.get_by_slug(db, slug=slug)
|
||||||
|
|
||||||
|
if not agent_type:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent type with slug '{slug}' not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get instance count separately
|
||||||
|
result = await agent_type_crud.get_with_instance_count(
|
||||||
|
db, agent_type_id=agent_type.id
|
||||||
|
)
|
||||||
|
instance_count = result["instance_count"] if result else 0
|
||||||
|
|
||||||
|
return _build_agent_type_response(agent_type, instance_count=instance_count)
|
||||||
|
|
||||||
|
except NotFoundError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting agent type by slug '{slug}': {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
984
backend/app/api/routes/agents.py
Normal file
984
backend/app/api/routes/agents.py
Normal file
@@ -0,0 +1,984 @@
|
|||||||
|
# app/api/routes/agents.py
|
||||||
|
"""
|
||||||
|
Agent Instance management endpoints for Syndarix projects.
|
||||||
|
|
||||||
|
These endpoints allow project owners and superusers to manage AI agent instances
|
||||||
|
within their projects, including spawning, pausing, resuming, and terminating agents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query, Request, status
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.exceptions import (
|
||||||
|
AuthorizationError,
|
||||||
|
NotFoundError,
|
||||||
|
ValidationException,
|
||||||
|
)
|
||||||
|
from app.crud.syndarix.agent_instance import agent_instance as agent_instance_crud
|
||||||
|
from app.crud.syndarix.agent_type import agent_type as agent_type_crud
|
||||||
|
from app.crud.syndarix.project import project as project_crud
|
||||||
|
from app.models.syndarix import AgentInstance, Project
|
||||||
|
from app.models.syndarix.enums import AgentStatus
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.common import (
|
||||||
|
MessageResponse,
|
||||||
|
PaginatedResponse,
|
||||||
|
PaginationParams,
|
||||||
|
create_pagination_meta,
|
||||||
|
)
|
||||||
|
from app.schemas.errors import ErrorCode
|
||||||
|
from app.schemas.syndarix.agent_instance import (
|
||||||
|
AgentInstanceCreate,
|
||||||
|
AgentInstanceMetrics,
|
||||||
|
AgentInstanceResponse,
|
||||||
|
AgentInstanceUpdate,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Initialize limiter for this router
|
||||||
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
|
|
||||||
|
# Use higher rate limits in test environment
|
||||||
|
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||||
|
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||||
|
|
||||||
|
|
||||||
|
# Valid status transitions for agent lifecycle management
|
||||||
|
VALID_STATUS_TRANSITIONS: dict[AgentStatus, set[AgentStatus]] = {
|
||||||
|
AgentStatus.IDLE: {AgentStatus.WORKING, AgentStatus.PAUSED, AgentStatus.TERMINATED},
|
||||||
|
AgentStatus.WORKING: {
|
||||||
|
AgentStatus.IDLE,
|
||||||
|
AgentStatus.WAITING,
|
||||||
|
AgentStatus.PAUSED,
|
||||||
|
AgentStatus.TERMINATED,
|
||||||
|
},
|
||||||
|
AgentStatus.WAITING: {
|
||||||
|
AgentStatus.IDLE,
|
||||||
|
AgentStatus.WORKING,
|
||||||
|
AgentStatus.PAUSED,
|
||||||
|
AgentStatus.TERMINATED,
|
||||||
|
},
|
||||||
|
AgentStatus.PAUSED: {AgentStatus.IDLE, AgentStatus.TERMINATED},
|
||||||
|
AgentStatus.TERMINATED: set(), # Terminal state, no transitions allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_project_access(
|
||||||
|
db: AsyncSession,
|
||||||
|
project_id: UUID,
|
||||||
|
user: User,
|
||||||
|
) -> Project:
|
||||||
|
"""
|
||||||
|
Verify user has access to a project.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
project_id: UUID of the project to verify
|
||||||
|
user: Current authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Project: The project if access is granted
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the project does not exist
|
||||||
|
AuthorizationError: If the user does not have access to the project
|
||||||
|
"""
|
||||||
|
project = await project_crud.get(db, id=project_id)
|
||||||
|
if not project:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Project {project_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
if not user.is_superuser and project.owner_id != user.id:
|
||||||
|
raise AuthorizationError(
|
||||||
|
message="You do not have access to this project",
|
||||||
|
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||||
|
)
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
|
def validate_status_transition(
|
||||||
|
current_status: AgentStatus,
|
||||||
|
target_status: AgentStatus,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Validate that a status transition is allowed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_status: The agent's current status
|
||||||
|
target_status: The desired target status
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationException: If the transition is not allowed
|
||||||
|
"""
|
||||||
|
valid_targets = VALID_STATUS_TRANSITIONS.get(current_status, set())
|
||||||
|
if target_status not in valid_targets:
|
||||||
|
raise ValidationException(
|
||||||
|
message=f"Cannot transition from {current_status.value} to {target_status.value}",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="status",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_agent_response(
|
||||||
|
agent: AgentInstance,
|
||||||
|
agent_type_name: str | None = None,
|
||||||
|
agent_type_slug: str | None = None,
|
||||||
|
project_name: str | None = None,
|
||||||
|
project_slug: str | None = None,
|
||||||
|
assigned_issues_count: int = 0,
|
||||||
|
) -> AgentInstanceResponse:
|
||||||
|
"""
|
||||||
|
Build an AgentInstanceResponse from an AgentInstance model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance model
|
||||||
|
agent_type_name: Name of the agent type
|
||||||
|
agent_type_slug: Slug of the agent type
|
||||||
|
project_name: Name of the project
|
||||||
|
project_slug: Slug of the project
|
||||||
|
assigned_issues_count: Number of issues assigned to this agent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentInstanceResponse: The response schema
|
||||||
|
"""
|
||||||
|
return AgentInstanceResponse(
|
||||||
|
id=agent.id,
|
||||||
|
agent_type_id=agent.agent_type_id,
|
||||||
|
project_id=agent.project_id,
|
||||||
|
name=agent.name,
|
||||||
|
status=agent.status,
|
||||||
|
current_task=agent.current_task,
|
||||||
|
short_term_memory=agent.short_term_memory or {},
|
||||||
|
long_term_memory_ref=agent.long_term_memory_ref,
|
||||||
|
session_id=agent.session_id,
|
||||||
|
last_activity_at=agent.last_activity_at,
|
||||||
|
terminated_at=agent.terminated_at,
|
||||||
|
tasks_completed=agent.tasks_completed,
|
||||||
|
tokens_used=agent.tokens_used,
|
||||||
|
cost_incurred=agent.cost_incurred,
|
||||||
|
created_at=agent.created_at,
|
||||||
|
updated_at=agent.updated_at,
|
||||||
|
agent_type_name=agent_type_name,
|
||||||
|
agent_type_slug=agent_type_slug,
|
||||||
|
project_name=project_name,
|
||||||
|
project_slug=project_slug,
|
||||||
|
assigned_issues_count=assigned_issues_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Agent Instance Management Endpoints =====
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/projects/{project_id}/agents",
|
||||||
|
response_model=AgentInstanceResponse,
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
summary="Spawn Agent Instance",
|
||||||
|
description="Spawn a new agent instance in a project. Requires project ownership or superuser.",
|
||||||
|
operation_id="spawn_agent",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def spawn_agent(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
agent_in: AgentInstanceCreate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Spawn a new agent instance in a project.
|
||||||
|
|
||||||
|
Creates a new agent instance from an agent type template and assigns it
|
||||||
|
to the specified project. The agent starts in IDLE status by default.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object (for rate limiting)
|
||||||
|
project_id: UUID of the project to spawn the agent in
|
||||||
|
agent_in: Agent instance creation data
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentInstanceResponse: The newly created agent instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the project is not found
|
||||||
|
AuthorizationError: If the user lacks access to the project
|
||||||
|
ValidationException: If the agent creation data is invalid
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify project access
|
||||||
|
project = await verify_project_access(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Ensure the agent is being created for the correct project
|
||||||
|
if agent_in.project_id != project_id:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Agent project_id must match the URL project_id",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="project_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate that the agent type exists and is active
|
||||||
|
agent_type = await agent_type_crud.get(db, id=agent_in.agent_type_id)
|
||||||
|
if not agent_type:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent type {agent_in.agent_type_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
if not agent_type.is_active:
|
||||||
|
raise ValidationException(
|
||||||
|
message=f"Agent type '{agent_type.name}' is inactive and cannot be used",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="agent_type_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the agent instance
|
||||||
|
agent = await agent_instance_crud.create(db, obj_in=agent_in)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"User {current_user.email} spawned agent '{agent.name}' "
|
||||||
|
f"(id={agent.id}) in project {project.slug}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get agent details for response
|
||||||
|
details = await agent_instance_crud.get_with_details(db, instance_id=agent.id)
|
||||||
|
if details:
|
||||||
|
return build_agent_response(
|
||||||
|
agent=details["instance"],
|
||||||
|
agent_type_name=details.get("agent_type_name"),
|
||||||
|
agent_type_slug=details.get("agent_type_slug"),
|
||||||
|
project_name=details.get("project_name"),
|
||||||
|
project_slug=details.get("project_slug"),
|
||||||
|
assigned_issues_count=details.get("assigned_issues_count", 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
return build_agent_response(agent)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError, ValidationException):
|
||||||
|
raise
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(f"Failed to spawn agent: {e!s}")
|
||||||
|
raise ValidationException(
|
||||||
|
message=str(e),
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error spawning agent: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/projects/{project_id}/agents",
|
||||||
|
response_model=PaginatedResponse[AgentInstanceResponse],
|
||||||
|
summary="List Project Agents",
|
||||||
|
description="List all agent instances in a project with optional filtering.",
|
||||||
|
operation_id="list_project_agents",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def list_project_agents(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
pagination: PaginationParams = Depends(),
|
||||||
|
status_filter: AgentStatus | None = Query(
|
||||||
|
None, alias="status", description="Filter by agent status"
|
||||||
|
),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
List all agent instances in a project.
|
||||||
|
|
||||||
|
Returns a paginated list of agents with optional status filtering.
|
||||||
|
Results are ordered by creation date (newest first).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object (for rate limiting)
|
||||||
|
project_id: UUID of the project
|
||||||
|
pagination: Pagination parameters
|
||||||
|
status_filter: Optional filter by agent status
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PaginatedResponse[AgentInstanceResponse]: Paginated list of agents
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the project is not found
|
||||||
|
AuthorizationError: If the user lacks access to the project
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify project access
|
||||||
|
project = await verify_project_access(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get agents for the project
|
||||||
|
agents, total = await agent_instance_crud.get_by_project(
|
||||||
|
db,
|
||||||
|
project_id=project_id,
|
||||||
|
status=status_filter,
|
||||||
|
skip=pagination.offset,
|
||||||
|
limit=pagination.limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build response objects
|
||||||
|
agent_responses = []
|
||||||
|
for agent in agents:
|
||||||
|
# Get details for each agent (could be optimized with bulk query)
|
||||||
|
details = await agent_instance_crud.get_with_details(
|
||||||
|
db, instance_id=agent.id
|
||||||
|
)
|
||||||
|
if details:
|
||||||
|
agent_responses.append(
|
||||||
|
build_agent_response(
|
||||||
|
agent=details["instance"],
|
||||||
|
agent_type_name=details.get("agent_type_name"),
|
||||||
|
agent_type_slug=details.get("agent_type_slug"),
|
||||||
|
project_name=details.get("project_name"),
|
||||||
|
project_slug=details.get("project_slug"),
|
||||||
|
assigned_issues_count=details.get("assigned_issues_count", 0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
agent_responses.append(build_agent_response(agent))
|
||||||
|
|
||||||
|
pagination_meta = create_pagination_meta(
|
||||||
|
total=total,
|
||||||
|
page=pagination.page,
|
||||||
|
limit=pagination.limit,
|
||||||
|
items_count=len(agent_responses),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"User {current_user.email} listed {len(agent_responses)} agents "
|
||||||
|
f"in project {project.slug}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return PaginatedResponse(data=agent_responses, pagination=pagination_meta)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error listing project agents: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Project Agent Metrics Endpoint =====
|
||||||
|
# NOTE: This endpoint MUST be defined before /{agent_id} routes
|
||||||
|
# to prevent FastAPI from trying to parse "metrics" as a UUID
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/projects/{project_id}/agents/metrics",
|
||||||
|
response_model=AgentInstanceMetrics,
|
||||||
|
summary="Get Project Agent Metrics",
|
||||||
|
description="Get aggregated usage metrics for all agents in a project.",
|
||||||
|
operation_id="get_project_agent_metrics",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def get_project_agent_metrics(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get aggregated usage metrics for all agents in a project.
|
||||||
|
|
||||||
|
Returns aggregated metrics across all agents including total
|
||||||
|
tasks completed, tokens used, and cost incurred.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object (for rate limiting)
|
||||||
|
project_id: UUID of the project
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentInstanceMetrics: Aggregated project agent metrics
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the project is not found
|
||||||
|
AuthorizationError: If the user lacks access to the project
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify project access
|
||||||
|
project = await verify_project_access(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get aggregated metrics for the project
|
||||||
|
metrics = await agent_instance_crud.get_project_metrics(
|
||||||
|
db, project_id=project_id
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"User {current_user.email} retrieved project metrics for {project.slug}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentInstanceMetrics(
|
||||||
|
total_instances=metrics["total_instances"],
|
||||||
|
active_instances=metrics["active_instances"],
|
||||||
|
idle_instances=metrics["idle_instances"],
|
||||||
|
total_tasks_completed=metrics["total_tasks_completed"],
|
||||||
|
total_tokens_used=metrics["total_tokens_used"],
|
||||||
|
total_cost_incurred=metrics["total_cost_incurred"],
|
||||||
|
)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting project agent metrics: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/projects/{project_id}/agents/{agent_id}",
|
||||||
|
response_model=AgentInstanceResponse,
|
||||||
|
summary="Get Agent Details",
|
||||||
|
description="Get detailed information about a specific agent instance.",
|
||||||
|
operation_id="get_agent",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def get_agent(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
agent_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get detailed information about a specific agent instance.
|
||||||
|
|
||||||
|
Returns full agent details including related entity information
|
||||||
|
(agent type name, project name) and assigned issues count.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object (for rate limiting)
|
||||||
|
project_id: UUID of the project
|
||||||
|
agent_id: UUID of the agent instance
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentInstanceResponse: The agent instance details
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the project or agent is not found
|
||||||
|
AuthorizationError: If the user lacks access to the project
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_access(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get agent with full details
|
||||||
|
details = await agent_instance_crud.get_with_details(db, instance_id=agent_id)
|
||||||
|
|
||||||
|
if not details:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = details["instance"]
|
||||||
|
|
||||||
|
# Verify agent belongs to the specified project
|
||||||
|
if agent.project_id != project_id:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found in project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"User {current_user.email} retrieved agent {agent.name} (id={agent_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return build_agent_response(
|
||||||
|
agent=agent,
|
||||||
|
agent_type_name=details.get("agent_type_name"),
|
||||||
|
agent_type_slug=details.get("agent_type_slug"),
|
||||||
|
project_name=details.get("project_name"),
|
||||||
|
project_slug=details.get("project_slug"),
|
||||||
|
assigned_issues_count=details.get("assigned_issues_count", 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting agent details: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch(
|
||||||
|
"/projects/{project_id}/agents/{agent_id}",
|
||||||
|
response_model=AgentInstanceResponse,
|
||||||
|
summary="Update Agent",
|
||||||
|
description="Update an agent instance's configuration and state.",
|
||||||
|
operation_id="update_agent",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def update_agent(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
agent_id: UUID,
|
||||||
|
agent_in: AgentInstanceUpdate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Update an agent instance's configuration and state.
|
||||||
|
|
||||||
|
Allows updating agent status, current task, memory, and other
|
||||||
|
configurable fields. Status transitions are validated according
|
||||||
|
to the agent lifecycle state machine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object (for rate limiting)
|
||||||
|
project_id: UUID of the project
|
||||||
|
agent_id: UUID of the agent instance
|
||||||
|
agent_in: Agent update data
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentInstanceResponse: The updated agent instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the project or agent is not found
|
||||||
|
AuthorizationError: If the user lacks access to the project
|
||||||
|
ValidationException: If the status transition is invalid
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_access(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get current agent
|
||||||
|
agent = await agent_instance_crud.get(db, id=agent_id)
|
||||||
|
if not agent:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify agent belongs to the specified project
|
||||||
|
if agent.project_id != project_id:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found in project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate status transition if status is being changed
|
||||||
|
if agent_in.status is not None and agent_in.status != agent.status:
|
||||||
|
validate_status_transition(agent.status, agent_in.status)
|
||||||
|
|
||||||
|
# Update the agent
|
||||||
|
updated_agent = await agent_instance_crud.update(
|
||||||
|
db, db_obj=agent, obj_in=agent_in
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"User {current_user.email} updated agent {updated_agent.name} "
|
||||||
|
f"(id={agent_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get updated details
|
||||||
|
details = await agent_instance_crud.get_with_details(
|
||||||
|
db, instance_id=updated_agent.id
|
||||||
|
)
|
||||||
|
if details:
|
||||||
|
return build_agent_response(
|
||||||
|
agent=details["instance"],
|
||||||
|
agent_type_name=details.get("agent_type_name"),
|
||||||
|
agent_type_slug=details.get("agent_type_slug"),
|
||||||
|
project_name=details.get("project_name"),
|
||||||
|
project_slug=details.get("project_slug"),
|
||||||
|
assigned_issues_count=details.get("assigned_issues_count", 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
return build_agent_response(updated_agent)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError, ValidationException):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating agent: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/projects/{project_id}/agents/{agent_id}/pause",
|
||||||
|
response_model=AgentInstanceResponse,
|
||||||
|
summary="Pause Agent",
|
||||||
|
description="Pause an agent instance, temporarily stopping its work.",
|
||||||
|
operation_id="pause_agent",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def pause_agent(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
agent_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Pause an agent instance.
|
||||||
|
|
||||||
|
Transitions the agent to PAUSED status, temporarily stopping
|
||||||
|
its work. The agent can be resumed later with the resume endpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object (for rate limiting)
|
||||||
|
project_id: UUID of the project
|
||||||
|
agent_id: UUID of the agent instance
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentInstanceResponse: The paused agent instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the project or agent is not found
|
||||||
|
AuthorizationError: If the user lacks access to the project
|
||||||
|
ValidationException: If the agent cannot be paused from its current state
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_access(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get current agent
|
||||||
|
agent = await agent_instance_crud.get(db, id=agent_id)
|
||||||
|
if not agent:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify agent belongs to the specified project
|
||||||
|
if agent.project_id != project_id:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found in project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate the transition to PAUSED
|
||||||
|
validate_status_transition(agent.status, AgentStatus.PAUSED)
|
||||||
|
|
||||||
|
# Update status to PAUSED
|
||||||
|
paused_agent = await agent_instance_crud.update_status(
|
||||||
|
db,
|
||||||
|
instance_id=agent_id,
|
||||||
|
status=AgentStatus.PAUSED,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not paused_agent:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"User {current_user.email} paused agent {paused_agent.name} "
|
||||||
|
f"(id={agent_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get updated details
|
||||||
|
details = await agent_instance_crud.get_with_details(
|
||||||
|
db, instance_id=paused_agent.id
|
||||||
|
)
|
||||||
|
if details:
|
||||||
|
return build_agent_response(
|
||||||
|
agent=details["instance"],
|
||||||
|
agent_type_name=details.get("agent_type_name"),
|
||||||
|
agent_type_slug=details.get("agent_type_slug"),
|
||||||
|
project_name=details.get("project_name"),
|
||||||
|
project_slug=details.get("project_slug"),
|
||||||
|
assigned_issues_count=details.get("assigned_issues_count", 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
return build_agent_response(paused_agent)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError, ValidationException):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error pausing agent: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/projects/{project_id}/agents/{agent_id}/resume",
|
||||||
|
response_model=AgentInstanceResponse,
|
||||||
|
summary="Resume Agent",
|
||||||
|
description="Resume a paused agent instance.",
|
||||||
|
operation_id="resume_agent",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def resume_agent(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
agent_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Resume a paused agent instance.
|
||||||
|
|
||||||
|
Transitions the agent from PAUSED back to IDLE status,
|
||||||
|
allowing it to accept new work.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object (for rate limiting)
|
||||||
|
project_id: UUID of the project
|
||||||
|
agent_id: UUID of the agent instance
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentInstanceResponse: The resumed agent instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the project or agent is not found
|
||||||
|
AuthorizationError: If the user lacks access to the project
|
||||||
|
ValidationException: If the agent cannot be resumed from its current state
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_access(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get current agent
|
||||||
|
agent = await agent_instance_crud.get(db, id=agent_id)
|
||||||
|
if not agent:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify agent belongs to the specified project
|
||||||
|
if agent.project_id != project_id:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found in project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate the transition to IDLE (resume)
|
||||||
|
validate_status_transition(agent.status, AgentStatus.IDLE)
|
||||||
|
|
||||||
|
# Update status to IDLE
|
||||||
|
resumed_agent = await agent_instance_crud.update_status(
|
||||||
|
db,
|
||||||
|
instance_id=agent_id,
|
||||||
|
status=AgentStatus.IDLE,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not resumed_agent:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"User {current_user.email} resumed agent {resumed_agent.name} "
|
||||||
|
f"(id={agent_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get updated details
|
||||||
|
details = await agent_instance_crud.get_with_details(
|
||||||
|
db, instance_id=resumed_agent.id
|
||||||
|
)
|
||||||
|
if details:
|
||||||
|
return build_agent_response(
|
||||||
|
agent=details["instance"],
|
||||||
|
agent_type_name=details.get("agent_type_name"),
|
||||||
|
agent_type_slug=details.get("agent_type_slug"),
|
||||||
|
project_name=details.get("project_name"),
|
||||||
|
project_slug=details.get("project_slug"),
|
||||||
|
assigned_issues_count=details.get("assigned_issues_count", 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
return build_agent_response(resumed_agent)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError, ValidationException):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error resuming agent: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/projects/{project_id}/agents/{agent_id}",
|
||||||
|
response_model=MessageResponse,
|
||||||
|
summary="Terminate Agent",
|
||||||
|
description="Terminate an agent instance, permanently stopping it.",
|
||||||
|
operation_id="terminate_agent",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def terminate_agent(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
agent_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Terminate an agent instance.
|
||||||
|
|
||||||
|
Permanently terminates the agent, setting its status to TERMINATED.
|
||||||
|
This action cannot be undone - a new agent must be spawned if needed.
|
||||||
|
The agent's session and current task are cleared.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object (for rate limiting)
|
||||||
|
project_id: UUID of the project
|
||||||
|
agent_id: UUID of the agent instance
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MessageResponse: Confirmation message
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the project or agent is not found
|
||||||
|
AuthorizationError: If the user lacks access to the project
|
||||||
|
ValidationException: If the agent is already terminated
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_access(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get current agent
|
||||||
|
agent = await agent_instance_crud.get(db, id=agent_id)
|
||||||
|
if not agent:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify agent belongs to the specified project
|
||||||
|
if agent.project_id != project_id:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found in project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if already terminated
|
||||||
|
if agent.status == AgentStatus.TERMINATED:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Agent is already terminated",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="status",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate the transition to TERMINATED
|
||||||
|
validate_status_transition(agent.status, AgentStatus.TERMINATED)
|
||||||
|
|
||||||
|
agent_name = agent.name
|
||||||
|
|
||||||
|
# Terminate the agent
|
||||||
|
terminated_agent = await agent_instance_crud.terminate(db, instance_id=agent_id)
|
||||||
|
|
||||||
|
if not terminated_agent:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"User {current_user.email} terminated agent {agent_name} (id={agent_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return MessageResponse(
|
||||||
|
success=True,
|
||||||
|
message=f"Agent '{agent_name}' has been terminated",
|
||||||
|
)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError, ValidationException):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error terminating agent: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/projects/{project_id}/agents/{agent_id}/metrics",
|
||||||
|
response_model=AgentInstanceMetrics,
|
||||||
|
summary="Get Agent Metrics",
|
||||||
|
description="Get usage metrics for a specific agent instance.",
|
||||||
|
operation_id="get_agent_metrics",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def get_agent_metrics(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
agent_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get usage metrics for a specific agent instance.
|
||||||
|
|
||||||
|
Returns metrics including tasks completed, tokens used,
|
||||||
|
and cost incurred for the specified agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object (for rate limiting)
|
||||||
|
project_id: UUID of the project
|
||||||
|
agent_id: UUID of the agent instance
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentInstanceMetrics: Agent usage metrics
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the project or agent is not found
|
||||||
|
AuthorizationError: If the user lacks access to the project
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_access(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get agent
|
||||||
|
agent = await agent_instance_crud.get(db, id=agent_id)
|
||||||
|
if not agent:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify agent belongs to the specified project
|
||||||
|
if agent.project_id != project_id:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent {agent_id} not found in project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate metrics for this single agent
|
||||||
|
# For a single agent, we report its individual metrics
|
||||||
|
is_active = agent.status == AgentStatus.WORKING
|
||||||
|
is_idle = agent.status == AgentStatus.IDLE
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"User {current_user.email} retrieved metrics for agent {agent.name} "
|
||||||
|
f"(id={agent_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentInstanceMetrics(
|
||||||
|
total_instances=1,
|
||||||
|
active_instances=1 if is_active else 0,
|
||||||
|
idle_instances=1 if is_idle else 0,
|
||||||
|
total_tasks_completed=agent.tasks_completed,
|
||||||
|
total_tokens_used=agent.tokens_used,
|
||||||
|
total_cost_incurred=agent.cost_incurred,
|
||||||
|
)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting agent metrics: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
316
backend/app/api/routes/events.py
Normal file
316
backend/app/api/routes/events.py
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
"""
|
||||||
|
SSE endpoint for real-time project event streaming.
|
||||||
|
|
||||||
|
This module provides Server-Sent Events (SSE) endpoints for streaming
|
||||||
|
project events to connected clients. Events are scoped to projects,
|
||||||
|
with authorization checks to ensure clients only receive events
|
||||||
|
for projects they have access to.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Real-time event streaming via SSE
|
||||||
|
- Project-scoped authorization
|
||||||
|
- Automatic reconnection support (Last-Event-ID)
|
||||||
|
- Keepalive messages every 30 seconds
|
||||||
|
- Graceful connection cleanup
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, Query, Request
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
|
from app.api.dependencies.auth import get_current_user, get_current_user_sse
|
||||||
|
from app.api.dependencies.event_bus import get_event_bus
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.exceptions import AuthorizationError
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.errors import ErrorCode
|
||||||
|
from app.schemas.events import EventType
|
||||||
|
from app.services.event_bus import EventBus
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
|
|
||||||
|
# Keepalive interval in seconds
|
||||||
|
KEEPALIVE_INTERVAL = 30
|
||||||
|
|
||||||
|
|
||||||
|
async def check_project_access(
|
||||||
|
project_id: UUID,
|
||||||
|
user: User,
|
||||||
|
db: "AsyncSession",
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a user has access to a project's events.
|
||||||
|
|
||||||
|
Authorization rules:
|
||||||
|
- Superusers can access all projects
|
||||||
|
- Project owners can access their own projects
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: The project to check access for
|
||||||
|
user: The authenticated user
|
||||||
|
db: Database session for project lookup
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if user has access, False otherwise
|
||||||
|
"""
|
||||||
|
# Superusers can access all projects
|
||||||
|
if user.is_superuser:
|
||||||
|
logger.debug(
|
||||||
|
f"Project access granted for superuser {user.id} on project {project_id}"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check if user owns the project
|
||||||
|
from app.crud.syndarix import project as project_crud
|
||||||
|
|
||||||
|
project = await project_crud.get(db, id=project_id)
|
||||||
|
if not project:
|
||||||
|
logger.debug(f"Project {project_id} not found for access check")
|
||||||
|
return False
|
||||||
|
|
||||||
|
has_access = bool(project.owner_id == user.id)
|
||||||
|
logger.debug(
|
||||||
|
f"Project access {'granted' if has_access else 'denied'} "
|
||||||
|
f"for user {user.id} on project {project_id} (owner: {project.owner_id})"
|
||||||
|
)
|
||||||
|
return has_access
|
||||||
|
|
||||||
|
|
||||||
|
async def event_generator(
|
||||||
|
project_id: UUID,
|
||||||
|
event_bus: EventBus,
|
||||||
|
last_event_id: str | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate SSE events for a project.
|
||||||
|
|
||||||
|
This async generator yields SSE-formatted events from the event bus,
|
||||||
|
including keepalive comments to maintain the connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: The project to stream events for
|
||||||
|
event_bus: The EventBus instance
|
||||||
|
last_event_id: Optional last received event ID for reconnection
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
dict: SSE event data with 'event', 'data', and optional 'id' fields
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async for event_data in event_bus.subscribe_sse(
|
||||||
|
project_id=project_id,
|
||||||
|
last_event_id=last_event_id,
|
||||||
|
keepalive_interval=KEEPALIVE_INTERVAL,
|
||||||
|
):
|
||||||
|
if event_data == "":
|
||||||
|
# Keepalive - yield SSE comment
|
||||||
|
yield {"comment": "keepalive"}
|
||||||
|
else:
|
||||||
|
# Parse event to extract type and id
|
||||||
|
try:
|
||||||
|
event_dict = json.loads(event_data)
|
||||||
|
event_type = event_dict.get("type", "message")
|
||||||
|
event_id = event_dict.get("id")
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"event": event_type,
|
||||||
|
"data": event_data,
|
||||||
|
"id": event_id,
|
||||||
|
}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# If we can't parse, send as generic message
|
||||||
|
yield {
|
||||||
|
"event": "message",
|
||||||
|
"data": event_data,
|
||||||
|
}
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info(f"Event stream cancelled for project {project_id}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in event stream for project {project_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/projects/{project_id}/events/stream",
|
||||||
|
summary="Stream Project Events",
|
||||||
|
description="""
|
||||||
|
Stream real-time events for a project via Server-Sent Events (SSE).
|
||||||
|
|
||||||
|
**Authentication**: Required (Bearer token OR query parameter)
|
||||||
|
**Authorization**: Must have access to the project
|
||||||
|
|
||||||
|
**Authentication Methods**:
|
||||||
|
- Bearer token in Authorization header (preferred)
|
||||||
|
- Query parameter `token` (for EventSource compatibility)
|
||||||
|
|
||||||
|
Note: EventSource API doesn't support custom headers, so the query parameter
|
||||||
|
option is provided for browser-based SSE clients.
|
||||||
|
|
||||||
|
**SSE Event Format**:
|
||||||
|
```
|
||||||
|
event: agent.status_changed
|
||||||
|
id: 550e8400-e29b-41d4-a716-446655440000
|
||||||
|
data: {"id": "...", "type": "agent.status_changed", "project_id": "...", ...}
|
||||||
|
|
||||||
|
: keepalive
|
||||||
|
|
||||||
|
event: issue.created
|
||||||
|
id: 550e8400-e29b-41d4-a716-446655440001
|
||||||
|
data: {...}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Reconnection**: Include the `Last-Event-ID` header with the last received
|
||||||
|
event ID to resume from where you left off.
|
||||||
|
|
||||||
|
**Keepalive**: The server sends a comment (`: keepalive`) every 30 seconds
|
||||||
|
to keep the connection alive.
|
||||||
|
|
||||||
|
**Rate Limit**: 10 connections/minute per IP
|
||||||
|
""",
|
||||||
|
response_class=EventSourceResponse,
|
||||||
|
responses={
|
||||||
|
200: {
|
||||||
|
"description": "SSE stream established",
|
||||||
|
"content": {"text/event-stream": {}},
|
||||||
|
},
|
||||||
|
401: {"description": "Not authenticated"},
|
||||||
|
403: {"description": "Not authorized to access this project"},
|
||||||
|
404: {"description": "Project not found"},
|
||||||
|
},
|
||||||
|
operation_id="stream_project_events",
|
||||||
|
)
|
||||||
|
@limiter.limit("10/minute")
|
||||||
|
async def stream_project_events(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
db: "AsyncSession" = Depends(get_db),
|
||||||
|
event_bus: EventBus = Depends(get_event_bus),
|
||||||
|
token: str | None = Query(
|
||||||
|
None, description="Auth token (for EventSource compatibility)"
|
||||||
|
),
|
||||||
|
authorization: str | None = Header(None, alias="Authorization"),
|
||||||
|
last_event_id: str | None = Header(None, alias="Last-Event-ID"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Stream real-time events for a project via SSE.
|
||||||
|
|
||||||
|
This endpoint establishes a persistent SSE connection that streams
|
||||||
|
project events to the client in real-time. The connection includes:
|
||||||
|
|
||||||
|
- Event streaming: All project events (agent updates, issues, etc.)
|
||||||
|
- Keepalive: Comment every 30 seconds to maintain connection
|
||||||
|
- Reconnection: Use Last-Event-ID header to resume after disconnect
|
||||||
|
|
||||||
|
The connection is automatically cleaned up when the client disconnects.
|
||||||
|
"""
|
||||||
|
# Authenticate user (supports both header and query param tokens)
|
||||||
|
current_user = await get_current_user_sse(
|
||||||
|
db=db, authorization=authorization, token=token
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"SSE connection request for project {project_id} "
|
||||||
|
f"by user {current_user.id} "
|
||||||
|
f"(last_event_id={last_event_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check project access
|
||||||
|
has_access = await check_project_access(project_id, current_user, db)
|
||||||
|
if not has_access:
|
||||||
|
raise AuthorizationError(
|
||||||
|
message=f"You don't have access to project {project_id}",
|
||||||
|
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return SSE response
|
||||||
|
return EventSourceResponse(
|
||||||
|
event_generator(
|
||||||
|
project_id=project_id,
|
||||||
|
event_bus=event_bus,
|
||||||
|
last_event_id=last_event_id,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/projects/{project_id}/events/test",
|
||||||
|
summary="Send Test Event (Development Only)",
|
||||||
|
description="""
|
||||||
|
Send a test event to a project's event stream. This endpoint is
|
||||||
|
intended for development and testing purposes.
|
||||||
|
|
||||||
|
**Authentication**: Required (Bearer token)
|
||||||
|
**Authorization**: Must have access to the project
|
||||||
|
|
||||||
|
**Note**: This endpoint should be disabled or restricted in production.
|
||||||
|
""",
|
||||||
|
response_model=dict,
|
||||||
|
responses={
|
||||||
|
200: {"description": "Test event sent"},
|
||||||
|
401: {"description": "Not authenticated"},
|
||||||
|
403: {"description": "Not authorized to access this project"},
|
||||||
|
},
|
||||||
|
operation_id="send_test_event",
|
||||||
|
)
|
||||||
|
async def send_test_event(
|
||||||
|
project_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
event_bus: EventBus = Depends(get_event_bus),
|
||||||
|
db: "AsyncSession" = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Send a test event to the project's event stream.
|
||||||
|
|
||||||
|
This is useful for testing SSE connections during development.
|
||||||
|
"""
|
||||||
|
# Check project access
|
||||||
|
has_access = await check_project_access(project_id, current_user, db)
|
||||||
|
if not has_access:
|
||||||
|
raise AuthorizationError(
|
||||||
|
message=f"You don't have access to project {project_id}",
|
||||||
|
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create and publish test event using the Event schema
|
||||||
|
event = EventBus.create_event(
|
||||||
|
event_type=EventType.AGENT_MESSAGE,
|
||||||
|
project_id=project_id,
|
||||||
|
actor_type="user",
|
||||||
|
actor_id=current_user.id,
|
||||||
|
payload={
|
||||||
|
"message": "Test event from SSE endpoint",
|
||||||
|
"message_type": "info",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
channel = event_bus.get_project_channel(project_id)
|
||||||
|
await event_bus.publish(channel, event)
|
||||||
|
|
||||||
|
logger.info(f"Test event sent to project {project_id}: {event.id}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"event_id": event.id,
|
||||||
|
"event_type": event.type.value,
|
||||||
|
"message": "Test event sent successfully",
|
||||||
|
}
|
||||||
968
backend/app/api/routes/issues.py
Normal file
968
backend/app/api/routes/issues.py
Normal file
@@ -0,0 +1,968 @@
|
|||||||
|
# app/api/routes/issues.py
|
||||||
|
"""
|
||||||
|
Issue CRUD API endpoints for Syndarix projects.
|
||||||
|
|
||||||
|
Provides endpoints for managing issues within projects, including:
|
||||||
|
- Create, read, update, delete operations
|
||||||
|
- Filtering by status, priority, labels, sprint, assigned agent
|
||||||
|
- Search across title and body
|
||||||
|
- Assignment to agents
|
||||||
|
- External issue tracker sync triggers
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query, Request, status
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.exceptions import (
|
||||||
|
AuthorizationError,
|
||||||
|
NotFoundError,
|
||||||
|
ValidationException,
|
||||||
|
)
|
||||||
|
from app.crud.syndarix.agent_instance import agent_instance as agent_instance_crud
|
||||||
|
from app.crud.syndarix.issue import issue as issue_crud
|
||||||
|
from app.crud.syndarix.project import project as project_crud
|
||||||
|
from app.crud.syndarix.sprint import sprint as sprint_crud
|
||||||
|
from app.models.syndarix.enums import (
|
||||||
|
AgentStatus,
|
||||||
|
IssuePriority,
|
||||||
|
IssueStatus,
|
||||||
|
SprintStatus,
|
||||||
|
SyncStatus,
|
||||||
|
)
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.common import (
|
||||||
|
MessageResponse,
|
||||||
|
PaginatedResponse,
|
||||||
|
PaginationParams,
|
||||||
|
SortOrder,
|
||||||
|
create_pagination_meta,
|
||||||
|
)
|
||||||
|
from app.schemas.errors import ErrorCode
|
||||||
|
from app.schemas.syndarix.issue import (
|
||||||
|
IssueAssign,
|
||||||
|
IssueCreate,
|
||||||
|
IssueResponse,
|
||||||
|
IssueStats,
|
||||||
|
IssueUpdate,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Initialize limiter for this router
|
||||||
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
|
|
||||||
|
# Use higher rate limits in test environment
|
||||||
|
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||||
|
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_project_ownership(
|
||||||
|
db: AsyncSession,
|
||||||
|
project_id: UUID,
|
||||||
|
user: User,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Verify that the user owns the project or is a superuser.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
project_id: Project UUID to verify
|
||||||
|
user: Current authenticated user
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If project does not exist
|
||||||
|
AuthorizationError: If user does not own the project
|
||||||
|
"""
|
||||||
|
project = await project_crud.get(db, id=project_id)
|
||||||
|
if not project:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Project {project_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user.is_superuser and project.owner_id != user.id:
|
||||||
|
raise AuthorizationError(
|
||||||
|
message="You do not have access to this project",
|
||||||
|
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_issue_response(
|
||||||
|
issue: Any,
|
||||||
|
project_name: str | None = None,
|
||||||
|
project_slug: str | None = None,
|
||||||
|
sprint_name: str | None = None,
|
||||||
|
assigned_agent_type_name: str | None = None,
|
||||||
|
) -> IssueResponse:
|
||||||
|
"""
|
||||||
|
Build an IssueResponse from an Issue model instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
issue: Issue model instance
|
||||||
|
project_name: Optional project name from relationship
|
||||||
|
project_slug: Optional project slug from relationship
|
||||||
|
sprint_name: Optional sprint name from relationship
|
||||||
|
assigned_agent_type_name: Optional agent type name from relationship
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IssueResponse schema instance
|
||||||
|
"""
|
||||||
|
return IssueResponse(
|
||||||
|
id=issue.id,
|
||||||
|
project_id=issue.project_id,
|
||||||
|
title=issue.title,
|
||||||
|
body=issue.body,
|
||||||
|
status=issue.status,
|
||||||
|
priority=issue.priority,
|
||||||
|
labels=issue.labels or [],
|
||||||
|
assigned_agent_id=issue.assigned_agent_id,
|
||||||
|
human_assignee=issue.human_assignee,
|
||||||
|
sprint_id=issue.sprint_id,
|
||||||
|
story_points=issue.story_points,
|
||||||
|
external_tracker_type=issue.external_tracker_type,
|
||||||
|
external_issue_id=issue.external_issue_id,
|
||||||
|
remote_url=issue.remote_url,
|
||||||
|
external_issue_number=issue.external_issue_number,
|
||||||
|
sync_status=issue.sync_status,
|
||||||
|
last_synced_at=issue.last_synced_at,
|
||||||
|
external_updated_at=issue.external_updated_at,
|
||||||
|
closed_at=issue.closed_at,
|
||||||
|
created_at=issue.created_at,
|
||||||
|
updated_at=issue.updated_at,
|
||||||
|
project_name=project_name,
|
||||||
|
project_slug=project_slug,
|
||||||
|
sprint_name=sprint_name,
|
||||||
|
assigned_agent_type_name=assigned_agent_type_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Issue CRUD Endpoints =====
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/projects/{project_id}/issues",
|
||||||
|
response_model=IssueResponse,
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
summary="Create Issue",
|
||||||
|
description="Create a new issue in a project",
|
||||||
|
operation_id="create_issue",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def create_issue(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
issue_in: IssueCreate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Create a new issue within a project.
|
||||||
|
|
||||||
|
The user must own the project or be a superuser.
|
||||||
|
The project_id in the path takes precedence over any project_id in the body.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object (for rate limiting)
|
||||||
|
project_id: UUID of the project to create the issue in
|
||||||
|
issue_in: Issue creation data
|
||||||
|
current_user: Authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created issue with full details
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If project not found
|
||||||
|
AuthorizationError: If user lacks access
|
||||||
|
ValidationException: If assigned agent not in project
|
||||||
|
"""
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_ownership(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Override project_id from path
|
||||||
|
issue_in.project_id = project_id
|
||||||
|
|
||||||
|
# Validate assigned agent if provided
|
||||||
|
if issue_in.assigned_agent_id:
|
||||||
|
agent = await agent_instance_crud.get(db, id=issue_in.assigned_agent_id)
|
||||||
|
if not agent:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent instance {issue_in.assigned_agent_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
if agent.project_id != project_id:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Agent instance does not belong to this project",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="assigned_agent_id",
|
||||||
|
)
|
||||||
|
if agent.status == AgentStatus.TERMINATED:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Cannot assign issue to a terminated agent",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="assigned_agent_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate sprint if provided (IDOR prevention)
|
||||||
|
if issue_in.sprint_id:
|
||||||
|
sprint = await sprint_crud.get(db, id=issue_in.sprint_id)
|
||||||
|
if not sprint:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Sprint {issue_in.sprint_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
if sprint.project_id != project_id:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Sprint does not belong to this project",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="sprint_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
issue = await issue_crud.create(db, obj_in=issue_in)
|
||||||
|
logger.info(
|
||||||
|
f"User {current_user.email} created issue '{issue.title}' "
|
||||||
|
f"in project {project_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get project details for response
|
||||||
|
project = await project_crud.get(db, id=project_id)
|
||||||
|
|
||||||
|
return _build_issue_response(
|
||||||
|
issue,
|
||||||
|
project_name=project.name if project else None,
|
||||||
|
project_slug=project.slug if project else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(f"Failed to create issue: {e!s}")
|
||||||
|
raise ValidationException(
|
||||||
|
message=str(e),
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating issue: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/projects/{project_id}/issues",
|
||||||
|
response_model=PaginatedResponse[IssueResponse],
|
||||||
|
summary="List Issues",
|
||||||
|
description="Get paginated list of issues in a project with filtering",
|
||||||
|
operation_id="list_issues",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{120 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def list_issues(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
pagination: PaginationParams = Depends(),
|
||||||
|
status_filter: IssueStatus | None = Query(
|
||||||
|
None, alias="status", description="Filter by issue status"
|
||||||
|
),
|
||||||
|
priority: IssuePriority | None = Query(None, description="Filter by priority"),
|
||||||
|
labels: list[str] | None = Query(
|
||||||
|
None, description="Filter by labels (comma-separated)"
|
||||||
|
),
|
||||||
|
sprint_id: UUID | None = Query(None, description="Filter by sprint ID"),
|
||||||
|
assigned_agent_id: UUID | None = Query(
|
||||||
|
None, description="Filter by assigned agent ID"
|
||||||
|
),
|
||||||
|
sync_status: SyncStatus | None = Query(None, description="Filter by sync status"),
|
||||||
|
search: str | None = Query(
|
||||||
|
None, min_length=1, max_length=100, description="Search in title and body"
|
||||||
|
),
|
||||||
|
sort_by: str = Query(
|
||||||
|
"created_at",
|
||||||
|
description="Field to sort by (created_at, updated_at, priority, status, title)",
|
||||||
|
),
|
||||||
|
sort_order: SortOrder = Query(SortOrder.DESC, description="Sort order"),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
List issues in a project with comprehensive filtering options.
|
||||||
|
|
||||||
|
Supports filtering by:
|
||||||
|
- status: Issue status (open, in_progress, in_review, blocked, closed)
|
||||||
|
- priority: Issue priority (low, medium, high, critical)
|
||||||
|
- labels: Match issues containing any of the provided labels
|
||||||
|
- sprint_id: Issues in a specific sprint
|
||||||
|
- assigned_agent_id: Issues assigned to a specific agent
|
||||||
|
- sync_status: External tracker sync status
|
||||||
|
- search: Full-text search in title and body
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
project_id: Project UUID
|
||||||
|
pagination: Pagination parameters
|
||||||
|
status_filter: Optional status filter
|
||||||
|
priority: Optional priority filter
|
||||||
|
labels: Optional labels filter
|
||||||
|
sprint_id: Optional sprint filter
|
||||||
|
assigned_agent_id: Optional agent assignment filter
|
||||||
|
sync_status: Optional sync status filter
|
||||||
|
search: Optional search query
|
||||||
|
sort_by: Field to sort by
|
||||||
|
sort_order: Sort direction
|
||||||
|
current_user: Authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Paginated list of issues matching filters
|
||||||
|
"""
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_ownership(db, project_id, current_user)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get filtered issues
|
||||||
|
issues, total = await issue_crud.get_by_project(
|
||||||
|
db,
|
||||||
|
project_id=project_id,
|
||||||
|
status=status_filter,
|
||||||
|
priority=priority,
|
||||||
|
sprint_id=sprint_id,
|
||||||
|
assigned_agent_id=assigned_agent_id,
|
||||||
|
labels=labels,
|
||||||
|
search=search,
|
||||||
|
skip=pagination.offset,
|
||||||
|
limit=pagination.limit,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build response objects
|
||||||
|
issue_responses = [_build_issue_response(issue) for issue in issues]
|
||||||
|
|
||||||
|
pagination_meta = create_pagination_meta(
|
||||||
|
total=total,
|
||||||
|
page=pagination.page,
|
||||||
|
limit=pagination.limit,
|
||||||
|
items_count=len(issue_responses),
|
||||||
|
)
|
||||||
|
|
||||||
|
return PaginatedResponse(data=issue_responses, pagination=pagination_meta)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error listing issues for project {project_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Issue Statistics Endpoint =====
|
||||||
|
# NOTE: This endpoint MUST be defined before /{issue_id} routes
|
||||||
|
# to prevent FastAPI from trying to parse "stats" as a UUID
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/projects/{project_id}/issues/stats",
|
||||||
|
response_model=IssueStats,
|
||||||
|
summary="Get Issue Statistics",
|
||||||
|
description="Get aggregated issue statistics for a project",
|
||||||
|
operation_id="get_issue_stats",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def get_issue_stats(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get aggregated statistics for issues in a project.
|
||||||
|
|
||||||
|
Returns counts by status and priority, along with story point totals.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
project_id: Project UUID
|
||||||
|
current_user: Authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Issue statistics including counts by status/priority and story points
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If project not found
|
||||||
|
AuthorizationError: If user lacks access
|
||||||
|
"""
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_ownership(db, project_id, current_user)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stats = await issue_crud.get_project_stats(db, project_id=project_id)
|
||||||
|
return IssueStats(**stats)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting issue stats for project {project_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/projects/{project_id}/issues/{issue_id}",
|
||||||
|
response_model=IssueResponse,
|
||||||
|
summary="Get Issue",
|
||||||
|
description="Get detailed information about a specific issue",
|
||||||
|
operation_id="get_issue",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{120 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def get_issue(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
issue_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get detailed information about a specific issue.
|
||||||
|
|
||||||
|
Returns the issue with expanded relationship data including
|
||||||
|
project name, sprint name, and assigned agent type name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
project_id: Project UUID
|
||||||
|
issue_id: Issue UUID
|
||||||
|
current_user: Authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Issue details with relationship data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If project or issue not found
|
||||||
|
AuthorizationError: If user lacks access
|
||||||
|
"""
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_ownership(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get issue with details
|
||||||
|
issue_data = await issue_crud.get_with_details(db, issue_id=issue_id)
|
||||||
|
|
||||||
|
if not issue_data:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
issue = issue_data["issue"]
|
||||||
|
|
||||||
|
# Verify issue belongs to the project
|
||||||
|
if issue.project_id != project_id:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found in project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _build_issue_response(
|
||||||
|
issue,
|
||||||
|
project_name=issue_data.get("project_name"),
|
||||||
|
project_slug=issue_data.get("project_slug"),
|
||||||
|
sprint_name=issue_data.get("sprint_name"),
|
||||||
|
assigned_agent_type_name=issue_data.get("assigned_agent_type_name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch(
|
||||||
|
"/projects/{project_id}/issues/{issue_id}",
|
||||||
|
response_model=IssueResponse,
|
||||||
|
summary="Update Issue",
|
||||||
|
description="Update an existing issue",
|
||||||
|
operation_id="update_issue",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def update_issue(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
issue_id: UUID,
|
||||||
|
issue_in: IssueUpdate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Update an existing issue.
|
||||||
|
|
||||||
|
All fields are optional - only provided fields will be updated.
|
||||||
|
Validates that assigned agent belongs to the same project.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
project_id: Project UUID
|
||||||
|
issue_id: Issue UUID
|
||||||
|
issue_in: Fields to update
|
||||||
|
current_user: Authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated issue details
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If project or issue not found
|
||||||
|
AuthorizationError: If user lacks access
|
||||||
|
ValidationException: If validation fails
|
||||||
|
"""
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_ownership(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get existing issue
|
||||||
|
issue = await issue_crud.get(db, id=issue_id)
|
||||||
|
if not issue:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify issue belongs to the project
|
||||||
|
if issue.project_id != project_id:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found in project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate assigned agent if being updated
|
||||||
|
if issue_in.assigned_agent_id is not None:
|
||||||
|
agent = await agent_instance_crud.get(db, id=issue_in.assigned_agent_id)
|
||||||
|
if not agent:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent instance {issue_in.assigned_agent_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
if agent.project_id != project_id:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Agent instance does not belong to this project",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="assigned_agent_id",
|
||||||
|
)
|
||||||
|
if agent.status == AgentStatus.TERMINATED:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Cannot assign issue to a terminated agent",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="assigned_agent_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate sprint if being updated (IDOR prevention and status validation)
|
||||||
|
if issue_in.sprint_id is not None:
|
||||||
|
sprint = await sprint_crud.get(db, id=issue_in.sprint_id)
|
||||||
|
if not sprint:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Sprint {issue_in.sprint_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
if sprint.project_id != project_id:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Sprint does not belong to this project",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="sprint_id",
|
||||||
|
)
|
||||||
|
# Cannot add issues to completed or cancelled sprints
|
||||||
|
if sprint.status in [SprintStatus.COMPLETED, SprintStatus.CANCELLED]:
|
||||||
|
raise ValidationException(
|
||||||
|
message=f"Cannot add issues to sprint with status '{sprint.status.value}'",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="sprint_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
updated_issue = await issue_crud.update(db, db_obj=issue, obj_in=issue_in)
|
||||||
|
logger.info(
|
||||||
|
f"User {current_user.email} updated issue {issue_id} in project {project_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get full details for response
|
||||||
|
issue_data = await issue_crud.get_with_details(db, issue_id=issue_id)
|
||||||
|
|
||||||
|
return _build_issue_response(
|
||||||
|
updated_issue,
|
||||||
|
project_name=issue_data.get("project_name") if issue_data else None,
|
||||||
|
project_slug=issue_data.get("project_slug") if issue_data else None,
|
||||||
|
sprint_name=issue_data.get("sprint_name") if issue_data else None,
|
||||||
|
assigned_agent_type_name=issue_data.get("assigned_agent_type_name")
|
||||||
|
if issue_data
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(f"Failed to update issue {issue_id}: {e!s}")
|
||||||
|
raise ValidationException(
|
||||||
|
message=str(e),
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating issue {issue_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/projects/{project_id}/issues/{issue_id}",
|
||||||
|
response_model=MessageResponse,
|
||||||
|
summary="Delete Issue",
|
||||||
|
description="Delete an issue permanently",
|
||||||
|
operation_id="delete_issue",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def delete_issue(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
issue_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Delete an issue permanently.
|
||||||
|
|
||||||
|
The issue will be permanently removed from the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
project_id: Project UUID
|
||||||
|
issue_id: Issue UUID
|
||||||
|
current_user: Authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If project or issue not found
|
||||||
|
AuthorizationError: If user lacks access
|
||||||
|
"""
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_ownership(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get existing issue
|
||||||
|
issue = await issue_crud.get(db, id=issue_id)
|
||||||
|
if not issue:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify issue belongs to the project
|
||||||
|
if issue.project_id != project_id:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found in project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
issue_title = issue.title
|
||||||
|
await issue_crud.remove(db, id=issue_id)
|
||||||
|
logger.info(
|
||||||
|
f"User {current_user.email} deleted issue {issue_id} "
|
||||||
|
f"('{issue_title}') from project {project_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return MessageResponse(
|
||||||
|
success=True,
|
||||||
|
message=f"Issue '{issue_title}' has been deleted",
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting issue {issue_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Issue Assignment Endpoint =====
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/projects/{project_id}/issues/{issue_id}/assign",
|
||||||
|
response_model=IssueResponse,
|
||||||
|
summary="Assign Issue",
|
||||||
|
description="Assign an issue to an agent or human",
|
||||||
|
operation_id="assign_issue",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def assign_issue(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
issue_id: UUID,
|
||||||
|
assignment: IssueAssign,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Assign an issue to an agent or human.
|
||||||
|
|
||||||
|
Only one type of assignment is allowed at a time:
|
||||||
|
- assigned_agent_id: Assign to an AI agent instance
|
||||||
|
- human_assignee: Assign to a human (name/email string)
|
||||||
|
|
||||||
|
To unassign, pass both as null/None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
project_id: Project UUID
|
||||||
|
issue_id: Issue UUID
|
||||||
|
assignment: Assignment data
|
||||||
|
current_user: Authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated issue with assignment
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If project, issue, or agent not found
|
||||||
|
AuthorizationError: If user lacks access
|
||||||
|
ValidationException: If agent not in project
|
||||||
|
"""
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_ownership(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get existing issue
|
||||||
|
issue = await issue_crud.get(db, id=issue_id)
|
||||||
|
if not issue:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify issue belongs to the project
|
||||||
|
if issue.project_id != project_id:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found in project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process assignment based on type
|
||||||
|
if assignment.assigned_agent_id:
|
||||||
|
# Validate agent exists and belongs to project
|
||||||
|
agent = await agent_instance_crud.get(db, id=assignment.assigned_agent_id)
|
||||||
|
if not agent:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Agent instance {assignment.assigned_agent_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
if agent.project_id != project_id:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Agent instance does not belong to this project",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="assigned_agent_id",
|
||||||
|
)
|
||||||
|
if agent.status == AgentStatus.TERMINATED:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Cannot assign issue to a terminated agent",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="assigned_agent_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_issue = await issue_crud.assign_to_agent(
|
||||||
|
db, issue_id=issue_id, agent_id=assignment.assigned_agent_id
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"User {current_user.email} assigned issue {issue_id} to agent {agent.name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif assignment.human_assignee:
|
||||||
|
updated_issue = await issue_crud.assign_to_human(
|
||||||
|
db, issue_id=issue_id, human_assignee=assignment.human_assignee
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"User {current_user.email} assigned issue {issue_id} "
|
||||||
|
f"to human '{assignment.human_assignee}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Unassign - clear both agent and human
|
||||||
|
updated_issue = await issue_crud.assign_to_agent(
|
||||||
|
db, issue_id=issue_id, agent_id=None
|
||||||
|
)
|
||||||
|
logger.info(f"User {current_user.email} unassigned issue {issue_id}")
|
||||||
|
|
||||||
|
if not updated_issue:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get full details for response
|
||||||
|
issue_data = await issue_crud.get_with_details(db, issue_id=issue_id)
|
||||||
|
|
||||||
|
return _build_issue_response(
|
||||||
|
updated_issue,
|
||||||
|
project_name=issue_data.get("project_name") if issue_data else None,
|
||||||
|
project_slug=issue_data.get("project_slug") if issue_data else None,
|
||||||
|
sprint_name=issue_data.get("sprint_name") if issue_data else None,
|
||||||
|
assigned_agent_type_name=issue_data.get("assigned_agent_type_name")
|
||||||
|
if issue_data
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/projects/{project_id}/issues/{issue_id}/assignment",
|
||||||
|
response_model=IssueResponse,
|
||||||
|
summary="Unassign Issue",
|
||||||
|
description="""
|
||||||
|
Remove agent/human assignment from an issue.
|
||||||
|
|
||||||
|
**Authentication**: Required (Bearer token)
|
||||||
|
**Authorization**: Project owner or superuser
|
||||||
|
|
||||||
|
This clears both agent and human assignee fields.
|
||||||
|
|
||||||
|
**Rate Limit**: 60 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="unassign_issue",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def unassign_issue(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
issue_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Remove assignment from an issue.
|
||||||
|
|
||||||
|
Clears both assigned_agent_id and human_assignee fields.
|
||||||
|
"""
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_ownership(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get existing issue
|
||||||
|
issue = await issue_crud.get(db, id=issue_id)
|
||||||
|
if not issue:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify issue belongs to project (IDOR prevention)
|
||||||
|
if issue.project_id != project_id:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found in project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unassign the issue
|
||||||
|
updated_issue = await issue_crud.unassign(db, issue_id=issue_id)
|
||||||
|
|
||||||
|
if not updated_issue:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"User {current_user.email} unassigned issue {issue_id}")
|
||||||
|
|
||||||
|
# Get full details for response
|
||||||
|
issue_data = await issue_crud.get_with_details(db, issue_id=issue_id)
|
||||||
|
|
||||||
|
return _build_issue_response(
|
||||||
|
updated_issue,
|
||||||
|
project_name=issue_data.get("project_name") if issue_data else None,
|
||||||
|
project_slug=issue_data.get("project_slug") if issue_data else None,
|
||||||
|
sprint_name=issue_data.get("sprint_name") if issue_data else None,
|
||||||
|
assigned_agent_type_name=issue_data.get("assigned_agent_type_name")
|
||||||
|
if issue_data
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Issue Sync Endpoint =====
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/projects/{project_id}/issues/{issue_id}/sync",
|
||||||
|
response_model=MessageResponse,
|
||||||
|
summary="Trigger Issue Sync",
|
||||||
|
description="Trigger synchronization with external issue tracker",
|
||||||
|
operation_id="sync_issue",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def sync_issue(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
issue_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Trigger synchronization of an issue with its external tracker.
|
||||||
|
|
||||||
|
This endpoint queues a sync task for the issue. The actual synchronization
|
||||||
|
happens asynchronously via Celery.
|
||||||
|
|
||||||
|
Prerequisites:
|
||||||
|
- Issue must have external_tracker_type configured
|
||||||
|
- Project must have integration settings for the tracker
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
project_id: Project UUID
|
||||||
|
issue_id: Issue UUID
|
||||||
|
current_user: Authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Message indicating sync has been triggered
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If project or issue not found
|
||||||
|
AuthorizationError: If user lacks access
|
||||||
|
ValidationException: If issue has no external tracker
|
||||||
|
"""
|
||||||
|
# Verify project access
|
||||||
|
await verify_project_ownership(db, project_id, current_user)
|
||||||
|
|
||||||
|
# Get existing issue
|
||||||
|
issue = await issue_crud.get(db, id=issue_id)
|
||||||
|
if not issue:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify issue belongs to the project
|
||||||
|
if issue.project_id != project_id:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Issue {issue_id} not found in project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if issue has external tracker configured
|
||||||
|
if not issue.external_tracker_type:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Issue does not have an external tracker configured",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="external_tracker_type",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update sync status to pending
|
||||||
|
await issue_crud.update_sync_status(
|
||||||
|
db,
|
||||||
|
issue_id=issue_id,
|
||||||
|
sync_status=SyncStatus.PENDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Queue Celery task for actual sync
|
||||||
|
# When Celery is set up, this will be:
|
||||||
|
# from app.tasks.sync import sync_issue_task
|
||||||
|
# sync_issue_task.delay(str(issue_id))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"User {current_user.email} triggered sync for issue {issue_id} "
|
||||||
|
f"(tracker: {issue.external_tracker_type})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return MessageResponse(
|
||||||
|
success=True,
|
||||||
|
message=f"Sync triggered for issue '{issue.title}'. "
|
||||||
|
f"Status will update when complete.",
|
||||||
|
)
|
||||||
444
backend/app/api/routes/mcp.py
Normal file
444
backend/app/api/routes/mcp.py
Normal file
@@ -0,0 +1,444 @@
|
|||||||
|
"""
|
||||||
|
MCP (Model Context Protocol) API Endpoints
|
||||||
|
|
||||||
|
Provides REST endpoints for managing MCP server connections
|
||||||
|
and executing tool calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Path, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.api.dependencies.permissions import require_superuser
|
||||||
|
from app.models.user import User
|
||||||
|
from app.services.mcp import (
|
||||||
|
MCPCircuitOpenError,
|
||||||
|
MCPClientManager,
|
||||||
|
MCPConnectionError,
|
||||||
|
MCPError,
|
||||||
|
MCPServerNotFoundError,
|
||||||
|
MCPTimeoutError,
|
||||||
|
MCPToolError,
|
||||||
|
MCPToolNotFoundError,
|
||||||
|
get_mcp_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# Server name validation pattern: alphanumeric, hyphens, underscores, 1-64 chars
|
||||||
|
SERVER_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
|
||||||
|
|
||||||
|
# Type alias for validated server name path parameter
|
||||||
|
ServerNamePath = Annotated[
|
||||||
|
str,
|
||||||
|
Path(
|
||||||
|
description="MCP server name",
|
||||||
|
min_length=1,
|
||||||
|
max_length=64,
|
||||||
|
pattern=r"^[a-zA-Z0-9_-]+$",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Request/Response Schemas
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ServerInfo(BaseModel):
|
||||||
|
"""Information about an MCP server."""
|
||||||
|
|
||||||
|
name: str = Field(..., description="Server name")
|
||||||
|
url: str = Field(..., description="Server URL")
|
||||||
|
enabled: bool = Field(..., description="Whether server is enabled")
|
||||||
|
timeout: int = Field(..., description="Request timeout in seconds")
|
||||||
|
transport: str = Field(..., description="Transport type (http, stdio, sse)")
|
||||||
|
description: str | None = Field(None, description="Server description")
|
||||||
|
|
||||||
|
|
||||||
|
class ServerListResponse(BaseModel):
|
||||||
|
"""Response containing list of MCP servers."""
|
||||||
|
|
||||||
|
servers: list[ServerInfo]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
class ToolInfoResponse(BaseModel):
|
||||||
|
"""Information about an MCP tool."""
|
||||||
|
|
||||||
|
name: str = Field(..., description="Tool name")
|
||||||
|
description: str | None = Field(None, description="Tool description")
|
||||||
|
server_name: str | None = Field(None, description="Server providing the tool")
|
||||||
|
input_schema: dict[str, Any] | None = Field(None, description="JSON schema for input")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolListResponse(BaseModel):
|
||||||
|
"""Response containing list of tools."""
|
||||||
|
|
||||||
|
tools: list[ToolInfoResponse]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
class ServerHealthStatus(BaseModel):
|
||||||
|
"""Health status for a server."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
healthy: bool
|
||||||
|
state: str
|
||||||
|
url: str
|
||||||
|
error: str | None = None
|
||||||
|
tools_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class HealthCheckResponse(BaseModel):
|
||||||
|
"""Response containing health status of all servers."""
|
||||||
|
|
||||||
|
servers: dict[str, ServerHealthStatus]
|
||||||
|
healthy_count: int
|
||||||
|
unhealthy_count: int
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallRequest(BaseModel):
|
||||||
|
"""Request to execute a tool."""
|
||||||
|
|
||||||
|
server: str = Field(..., description="MCP server name")
|
||||||
|
tool: str = Field(..., description="Tool name to execute")
|
||||||
|
arguments: dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Tool arguments",
|
||||||
|
)
|
||||||
|
timeout: float | None = Field(
|
||||||
|
None,
|
||||||
|
description="Optional timeout override in seconds",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallResponse(BaseModel):
|
||||||
|
"""Response from tool execution."""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
data: Any | None = None
|
||||||
|
error: str | None = None
|
||||||
|
error_code: str | None = None
|
||||||
|
tool_name: str | None = None
|
||||||
|
server_name: str | None = None
|
||||||
|
execution_time_ms: float = 0.0
|
||||||
|
request_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitBreakerStatus(BaseModel):
|
||||||
|
"""Status of a circuit breaker."""
|
||||||
|
|
||||||
|
server_name: str
|
||||||
|
state: str
|
||||||
|
failure_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitBreakerListResponse(BaseModel):
|
||||||
|
"""Response containing circuit breaker statuses."""
|
||||||
|
|
||||||
|
circuit_breakers: list[CircuitBreakerStatus]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Endpoints
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/servers",
|
||||||
|
response_model=ServerListResponse,
|
||||||
|
summary="List MCP Servers",
|
||||||
|
description="Get list of all registered MCP servers with their configurations.",
|
||||||
|
)
|
||||||
|
async def list_servers(
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> ServerListResponse:
|
||||||
|
"""List all registered MCP servers."""
|
||||||
|
servers = []
|
||||||
|
|
||||||
|
for name in mcp.list_servers():
|
||||||
|
try:
|
||||||
|
config = mcp.get_server_config(name)
|
||||||
|
servers.append(
|
||||||
|
ServerInfo(
|
||||||
|
name=name,
|
||||||
|
url=config.url,
|
||||||
|
enabled=config.enabled,
|
||||||
|
timeout=config.timeout,
|
||||||
|
transport=config.transport.value,
|
||||||
|
description=config.description,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except MCPServerNotFoundError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return ServerListResponse(
|
||||||
|
servers=servers,
|
||||||
|
total=len(servers),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/servers/{server_name}/tools",
|
||||||
|
response_model=ToolListResponse,
|
||||||
|
summary="List Server Tools",
|
||||||
|
description="Get list of tools available on a specific MCP server.",
|
||||||
|
)
|
||||||
|
async def list_server_tools(
|
||||||
|
server_name: ServerNamePath,
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> ToolListResponse:
|
||||||
|
"""List all tools available on a specific server."""
|
||||||
|
try:
|
||||||
|
tools = await mcp.list_tools(server_name)
|
||||||
|
return ToolListResponse(
|
||||||
|
tools=[
|
||||||
|
ToolInfoResponse(
|
||||||
|
name=t.name,
|
||||||
|
description=t.description,
|
||||||
|
server_name=t.server_name,
|
||||||
|
input_schema=t.input_schema,
|
||||||
|
)
|
||||||
|
for t in tools
|
||||||
|
],
|
||||||
|
total=len(tools),
|
||||||
|
)
|
||||||
|
except MCPServerNotFoundError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Server not found: {server_name}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/tools",
|
||||||
|
response_model=ToolListResponse,
|
||||||
|
summary="List All Tools",
|
||||||
|
description="Get list of all tools from all MCP servers.",
|
||||||
|
)
|
||||||
|
async def list_all_tools(
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> ToolListResponse:
|
||||||
|
"""List all tools from all servers."""
|
||||||
|
tools = await mcp.list_all_tools()
|
||||||
|
return ToolListResponse(
|
||||||
|
tools=[
|
||||||
|
ToolInfoResponse(
|
||||||
|
name=t.name,
|
||||||
|
description=t.description,
|
||||||
|
server_name=t.server_name,
|
||||||
|
input_schema=t.input_schema,
|
||||||
|
)
|
||||||
|
for t in tools
|
||||||
|
],
|
||||||
|
total=len(tools),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/health",
|
||||||
|
response_model=HealthCheckResponse,
|
||||||
|
summary="Health Check",
|
||||||
|
description="Check health status of all MCP servers.",
|
||||||
|
)
|
||||||
|
async def health_check(
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> HealthCheckResponse:
|
||||||
|
"""Perform health check on all MCP servers."""
|
||||||
|
health_results = await mcp.health_check()
|
||||||
|
|
||||||
|
servers = {
|
||||||
|
name: ServerHealthStatus(
|
||||||
|
name=status.name,
|
||||||
|
healthy=status.healthy,
|
||||||
|
state=status.state,
|
||||||
|
url=status.url,
|
||||||
|
error=status.error,
|
||||||
|
tools_count=status.tools_count,
|
||||||
|
)
|
||||||
|
for name, status in health_results.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
healthy_count = sum(1 for s in servers.values() if s.healthy)
|
||||||
|
unhealthy_count = len(servers) - healthy_count
|
||||||
|
|
||||||
|
return HealthCheckResponse(
|
||||||
|
servers=servers,
|
||||||
|
healthy_count=healthy_count,
|
||||||
|
unhealthy_count=unhealthy_count,
|
||||||
|
total=len(servers),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/call",
|
||||||
|
response_model=ToolCallResponse,
|
||||||
|
summary="Execute Tool (Admin Only)",
|
||||||
|
description="Execute a tool on an MCP server. Requires superuser privileges.",
|
||||||
|
)
|
||||||
|
async def call_tool(
|
||||||
|
request: ToolCallRequest,
|
||||||
|
current_user: User = Depends(require_superuser),
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> ToolCallResponse:
|
||||||
|
"""
|
||||||
|
Execute a tool on an MCP server.
|
||||||
|
|
||||||
|
This endpoint is restricted to superusers for direct tool execution.
|
||||||
|
Normal tool execution should go through agent workflows.
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"Tool call by user %s: %s.%s",
|
||||||
|
current_user.id,
|
||||||
|
request.server,
|
||||||
|
request.tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await mcp.call_tool(
|
||||||
|
server=request.server,
|
||||||
|
tool=request.tool,
|
||||||
|
args=request.arguments,
|
||||||
|
timeout=request.timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ToolCallResponse(
|
||||||
|
success=result.success,
|
||||||
|
data=result.data,
|
||||||
|
error=result.error,
|
||||||
|
error_code=result.error_code,
|
||||||
|
tool_name=result.tool_name,
|
||||||
|
server_name=result.server_name,
|
||||||
|
execution_time_ms=result.execution_time_ms,
|
||||||
|
request_id=result.request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except MCPCircuitOpenError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail=f"Server temporarily unavailable: {e.server_name}",
|
||||||
|
) from e
|
||||||
|
except MCPToolNotFoundError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Tool not found: {e.tool_name}",
|
||||||
|
) from e
|
||||||
|
except MCPServerNotFoundError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Server not found: {e.server_name}",
|
||||||
|
) from e
|
||||||
|
except MCPTimeoutError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
except MCPConnectionError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
except MCPToolError as e:
|
||||||
|
# Tool errors are returned in the response, not as HTTP errors
|
||||||
|
return ToolCallResponse(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
error_code=e.error_code,
|
||||||
|
tool_name=e.tool_name,
|
||||||
|
server_name=e.server_name,
|
||||||
|
)
|
||||||
|
except MCPError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/circuit-breakers",
|
||||||
|
response_model=CircuitBreakerListResponse,
|
||||||
|
summary="List Circuit Breakers",
|
||||||
|
description="Get status of all circuit breakers.",
|
||||||
|
)
|
||||||
|
async def list_circuit_breakers(
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> CircuitBreakerListResponse:
|
||||||
|
"""Get status of all circuit breakers."""
|
||||||
|
status_dict = mcp.get_circuit_breaker_status()
|
||||||
|
|
||||||
|
return CircuitBreakerListResponse(
|
||||||
|
circuit_breakers=[
|
||||||
|
CircuitBreakerStatus(
|
||||||
|
server_name=name,
|
||||||
|
state=info.get("state", "unknown"),
|
||||||
|
failure_count=info.get("failure_count", 0),
|
||||||
|
)
|
||||||
|
for name, info in status_dict.items()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/circuit-breakers/{server_name}/reset",
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
summary="Reset Circuit Breaker (Admin Only)",
|
||||||
|
description="Manually reset a circuit breaker for a server.",
|
||||||
|
)
|
||||||
|
async def reset_circuit_breaker(
|
||||||
|
server_name: ServerNamePath,
|
||||||
|
current_user: User = Depends(require_superuser),
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> None:
|
||||||
|
"""Manually reset a circuit breaker."""
|
||||||
|
logger.info(
|
||||||
|
"Circuit breaker reset by user %s for server %s",
|
||||||
|
current_user.id,
|
||||||
|
server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
success = await mcp.reset_circuit_breaker(server_name)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"No circuit breaker found for server: {server_name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/servers/{server_name}/reconnect",
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
summary="Reconnect to Server (Admin Only)",
|
||||||
|
description="Force reconnection to an MCP server.",
|
||||||
|
)
|
||||||
|
async def reconnect_server(
|
||||||
|
server_name: ServerNamePath,
|
||||||
|
current_user: User = Depends(require_superuser),
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> None:
|
||||||
|
"""Force reconnection to an MCP server."""
|
||||||
|
logger.info(
|
||||||
|
"Reconnect requested by user %s for server %s",
|
||||||
|
current_user.id,
|
||||||
|
server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await mcp.disconnect(server_name)
|
||||||
|
await mcp.connect(server_name)
|
||||||
|
except MCPServerNotFoundError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Server not found: {server_name}",
|
||||||
|
) from e
|
||||||
|
except MCPConnectionError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"Failed to reconnect: {e}",
|
||||||
|
) from e
|
||||||
659
backend/app/api/routes/projects.py
Normal file
659
backend/app/api/routes/projects.py
Normal file
@@ -0,0 +1,659 @@
|
|||||||
|
# app/api/routes/projects.py
|
||||||
|
"""
|
||||||
|
Project management API endpoints for Syndarix.
|
||||||
|
|
||||||
|
These endpoints allow users to manage their AI-powered software consulting projects.
|
||||||
|
Users can create, read, update, and manage the lifecycle of their projects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query, Request, status
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.exceptions import (
|
||||||
|
AuthorizationError,
|
||||||
|
DuplicateError,
|
||||||
|
ErrorCode,
|
||||||
|
NotFoundError,
|
||||||
|
ValidationException,
|
||||||
|
)
|
||||||
|
from app.crud.syndarix.project import project as project_crud
|
||||||
|
from app.models.syndarix.enums import ProjectStatus
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.common import (
|
||||||
|
MessageResponse,
|
||||||
|
PaginatedResponse,
|
||||||
|
PaginationParams,
|
||||||
|
create_pagination_meta,
|
||||||
|
)
|
||||||
|
from app.schemas.syndarix.project import (
|
||||||
|
ProjectCreate,
|
||||||
|
ProjectResponse,
|
||||||
|
ProjectUpdate,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Initialize rate limiter
|
||||||
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
|
|
||||||
|
# Use higher rate limits in test environment
|
||||||
|
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||||
|
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||||
|
|
||||||
|
|
||||||
|
def _build_project_response(project_data: dict[str, Any]) -> ProjectResponse:
|
||||||
|
"""
|
||||||
|
Build a ProjectResponse from project data dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_data: Dictionary containing project and related counts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ProjectResponse with all fields populated
|
||||||
|
"""
|
||||||
|
project = project_data["project"]
|
||||||
|
return ProjectResponse(
|
||||||
|
id=project.id,
|
||||||
|
name=project.name,
|
||||||
|
slug=project.slug,
|
||||||
|
description=project.description,
|
||||||
|
autonomy_level=project.autonomy_level,
|
||||||
|
status=project.status,
|
||||||
|
settings=project.settings,
|
||||||
|
owner_id=project.owner_id,
|
||||||
|
created_at=project.created_at,
|
||||||
|
updated_at=project.updated_at,
|
||||||
|
agent_count=project_data.get("agent_count", 0),
|
||||||
|
issue_count=project_data.get("issue_count", 0),
|
||||||
|
active_sprint_name=project_data.get("active_sprint_name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_project_ownership(project: Any, current_user: User) -> None:
|
||||||
|
"""
|
||||||
|
Check if the current user owns the project or is a superuser.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project: The project to check ownership of
|
||||||
|
current_user: The authenticated user
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthorizationError: If user doesn't own the project and isn't a superuser
|
||||||
|
"""
|
||||||
|
if not current_user.is_superuser and project.owner_id != current_user.id:
|
||||||
|
raise AuthorizationError(
|
||||||
|
message="You do not have permission to access this project",
|
||||||
|
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Project CRUD Endpoints
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"",
|
||||||
|
response_model=ProjectResponse,
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
summary="Create Project",
|
||||||
|
description="""
|
||||||
|
Create a new project for the current user.
|
||||||
|
|
||||||
|
The project will be owned by the authenticated user.
|
||||||
|
A unique slug is required for URL-friendly project identification.
|
||||||
|
|
||||||
|
**Rate Limit**: 10 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="create_project",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def create_project(
|
||||||
|
request: Request,
|
||||||
|
project_in: ProjectCreate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Create a new project.
|
||||||
|
|
||||||
|
The authenticated user becomes the owner of the project.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Set the owner to the current user
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name=project_in.name,
|
||||||
|
slug=project_in.slug,
|
||||||
|
description=project_in.description,
|
||||||
|
autonomy_level=project_in.autonomy_level,
|
||||||
|
status=project_in.status,
|
||||||
|
settings=project_in.settings,
|
||||||
|
owner_id=current_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
project = await project_crud.create(db, obj_in=project_data)
|
||||||
|
logger.info(f"User {current_user.email} created project {project.slug}")
|
||||||
|
|
||||||
|
return ProjectResponse(
|
||||||
|
id=project.id,
|
||||||
|
name=project.name,
|
||||||
|
slug=project.slug,
|
||||||
|
description=project.description,
|
||||||
|
autonomy_level=project.autonomy_level,
|
||||||
|
status=project.status,
|
||||||
|
settings=project.settings,
|
||||||
|
owner_id=project.owner_id,
|
||||||
|
created_at=project.created_at,
|
||||||
|
updated_at=project.updated_at,
|
||||||
|
agent_count=0,
|
||||||
|
issue_count=0,
|
||||||
|
active_sprint_name=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
if "already exists" in error_msg.lower():
|
||||||
|
logger.warning(f"Duplicate project slug attempted: {project_in.slug}")
|
||||||
|
raise DuplicateError(
|
||||||
|
message=error_msg,
|
||||||
|
error_code=ErrorCode.DUPLICATE_ENTRY,
|
||||||
|
field="slug",
|
||||||
|
)
|
||||||
|
logger.error(f"Error creating project: {error_msg}", exc_info=True)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error creating project: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedResponse[ProjectResponse],
|
||||||
|
summary="List Projects",
|
||||||
|
description="""
|
||||||
|
List projects for the current user with filtering and pagination.
|
||||||
|
|
||||||
|
Regular users see only their own projects.
|
||||||
|
Superusers can see all projects by setting `all_projects=true`.
|
||||||
|
|
||||||
|
**Rate Limit**: 30 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="list_projects",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def list_projects(
|
||||||
|
request: Request,
|
||||||
|
pagination: PaginationParams = Depends(),
|
||||||
|
status_filter: ProjectStatus | None = Query(
|
||||||
|
None, alias="status", description="Filter by project status"
|
||||||
|
),
|
||||||
|
search: str | None = Query(
|
||||||
|
None, description="Search by name, slug, or description"
|
||||||
|
),
|
||||||
|
all_projects: bool = Query(False, description="Show all projects (superuser only)"),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
List projects with filtering, search, and pagination.
|
||||||
|
|
||||||
|
Regular users only see their own projects.
|
||||||
|
Superusers can view all projects if all_projects is true.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Determine owner filter based on user role and request
|
||||||
|
owner_id = (
|
||||||
|
None if (current_user.is_superuser and all_projects) else current_user.id
|
||||||
|
)
|
||||||
|
|
||||||
|
projects_data, total = await project_crud.get_multi_with_counts(
|
||||||
|
db,
|
||||||
|
skip=pagination.offset,
|
||||||
|
limit=pagination.limit,
|
||||||
|
status=status_filter,
|
||||||
|
owner_id=owner_id,
|
||||||
|
search=search,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build response objects
|
||||||
|
project_responses = [_build_project_response(data) for data in projects_data]
|
||||||
|
|
||||||
|
pagination_meta = create_pagination_meta(
|
||||||
|
total=total,
|
||||||
|
page=pagination.page,
|
||||||
|
limit=pagination.limit,
|
||||||
|
items_count=len(project_responses),
|
||||||
|
)
|
||||||
|
|
||||||
|
return PaginatedResponse(data=project_responses, pagination=pagination_meta)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error listing projects: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/{project_id}",
|
||||||
|
response_model=ProjectResponse,
|
||||||
|
summary="Get Project",
|
||||||
|
description="""
|
||||||
|
Get detailed information about a specific project.
|
||||||
|
|
||||||
|
Users can only access their own projects unless they are superusers.
|
||||||
|
|
||||||
|
**Rate Limit**: 60 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="get_project",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def get_project(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get detailed information about a project by ID.
|
||||||
|
|
||||||
|
Includes agent count, issue count, and active sprint name.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
project_data = await project_crud.get_with_counts(db, project_id=project_id)
|
||||||
|
|
||||||
|
if not project_data:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Project {project_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
project = project_data["project"]
|
||||||
|
_check_project_ownership(project, current_user)
|
||||||
|
|
||||||
|
return _build_project_response(project_data)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting project {project_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/slug/{slug}",
|
||||||
|
response_model=ProjectResponse,
|
||||||
|
summary="Get Project by Slug",
|
||||||
|
description="""
|
||||||
|
Get detailed information about a project by its slug.
|
||||||
|
|
||||||
|
Users can only access their own projects unless they are superusers.
|
||||||
|
|
||||||
|
**Rate Limit**: 60 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="get_project_by_slug",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def get_project_by_slug(
|
||||||
|
request: Request,
|
||||||
|
slug: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get detailed information about a project by slug.
|
||||||
|
|
||||||
|
Includes agent count, issue count, and active sprint name.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
project = await project_crud.get_by_slug(db, slug=slug)
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Project with slug '{slug}' not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_project_ownership(project, current_user)
|
||||||
|
|
||||||
|
# Get project with counts
|
||||||
|
project_data = await project_crud.get_with_counts(db, project_id=project.id)
|
||||||
|
|
||||||
|
if not project_data:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Project with slug '{slug}' not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _build_project_response(project_data)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting project by slug {slug}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch(
|
||||||
|
"/{project_id}",
|
||||||
|
response_model=ProjectResponse,
|
||||||
|
summary="Update Project",
|
||||||
|
description="""
|
||||||
|
Update an existing project.
|
||||||
|
|
||||||
|
Only the project owner or a superuser can update a project.
|
||||||
|
Only provided fields will be updated.
|
||||||
|
|
||||||
|
**Rate Limit**: 20 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="update_project",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def update_project(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
project_in: ProjectUpdate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Update a project's information.
|
||||||
|
|
||||||
|
Only the project owner or superusers can perform updates.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
project = await project_crud.get(db, id=project_id)
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Project {project_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_project_ownership(project, current_user)
|
||||||
|
|
||||||
|
# Update the project
|
||||||
|
updated_project = await project_crud.update(
|
||||||
|
db, db_obj=project, obj_in=project_in
|
||||||
|
)
|
||||||
|
logger.info(f"User {current_user.email} updated project {updated_project.slug}")
|
||||||
|
|
||||||
|
# Get updated project with counts
|
||||||
|
project_data = await project_crud.get_with_counts(
|
||||||
|
db, project_id=updated_project.id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not project_data:
|
||||||
|
# This shouldn't happen, but handle gracefully
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Project {project_id} not found after update",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _build_project_response(project_data)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError):
|
||||||
|
raise
|
||||||
|
except ValueError as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
if "already exists" in error_msg.lower():
|
||||||
|
logger.warning(f"Duplicate project slug attempted: {project_in.slug}")
|
||||||
|
raise DuplicateError(
|
||||||
|
message=error_msg,
|
||||||
|
error_code=ErrorCode.DUPLICATE_ENTRY,
|
||||||
|
field="slug",
|
||||||
|
)
|
||||||
|
logger.error(f"Error updating project: {error_msg}", exc_info=True)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating project {project_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/{project_id}",
|
||||||
|
response_model=MessageResponse,
|
||||||
|
summary="Archive Project",
|
||||||
|
description="""
|
||||||
|
Archive a project (soft delete).
|
||||||
|
|
||||||
|
Only the project owner or a superuser can archive a project.
|
||||||
|
Archived projects are not deleted but are no longer accessible for active work.
|
||||||
|
|
||||||
|
**Rate Limit**: 10 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="archive_project",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def archive_project(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Archive a project by setting its status to ARCHIVED.
|
||||||
|
|
||||||
|
This is a soft delete operation. The project data is preserved.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
project = await project_crud.get(db, id=project_id)
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Project {project_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_project_ownership(project, current_user)
|
||||||
|
|
||||||
|
# Check if project is already archived
|
||||||
|
if project.status == ProjectStatus.ARCHIVED:
|
||||||
|
return MessageResponse(
|
||||||
|
success=True,
|
||||||
|
message=f"Project '{project.name}' is already archived",
|
||||||
|
)
|
||||||
|
|
||||||
|
archived_project = await project_crud.archive_project(db, project_id=project_id)
|
||||||
|
|
||||||
|
if not archived_project:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Failed to archive project {project_id}",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"User {current_user.email} archived project {project.slug}")
|
||||||
|
|
||||||
|
return MessageResponse(
|
||||||
|
success=True,
|
||||||
|
message=f"Project '{archived_project.name}' has been archived",
|
||||||
|
)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error archiving project {project_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Project Lifecycle Endpoints
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/{project_id}/pause",
|
||||||
|
response_model=ProjectResponse,
|
||||||
|
summary="Pause Project",
|
||||||
|
description="""
|
||||||
|
Pause an active project.
|
||||||
|
|
||||||
|
Only ACTIVE projects can be paused.
|
||||||
|
Only the project owner or a superuser can pause a project.
|
||||||
|
|
||||||
|
**Rate Limit**: 10 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="pause_project",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def pause_project(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Pause an active project.
|
||||||
|
|
||||||
|
Sets the project status to PAUSED. Only ACTIVE projects can be paused.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
project = await project_crud.get(db, id=project_id)
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Project {project_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_project_ownership(project, current_user)
|
||||||
|
|
||||||
|
# Validate current status (business logic validation, not authorization)
|
||||||
|
if project.status == ProjectStatus.PAUSED:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Project is already paused",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="status",
|
||||||
|
)
|
||||||
|
|
||||||
|
if project.status == ProjectStatus.ARCHIVED:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Cannot pause an archived project",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="status",
|
||||||
|
)
|
||||||
|
|
||||||
|
if project.status == ProjectStatus.COMPLETED:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Cannot pause a completed project",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="status",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update status to PAUSED
|
||||||
|
updated_project = await project_crud.update(
|
||||||
|
db, db_obj=project, obj_in=ProjectUpdate(status=ProjectStatus.PAUSED)
|
||||||
|
)
|
||||||
|
logger.info(f"User {current_user.email} paused project {project.slug}")
|
||||||
|
|
||||||
|
# Get project with counts
|
||||||
|
project_data = await project_crud.get_with_counts(
|
||||||
|
db, project_id=updated_project.id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not project_data:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Project {project_id} not found after update",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _build_project_response(project_data)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError, ValidationException):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error pausing project {project_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/{project_id}/resume",
|
||||||
|
response_model=ProjectResponse,
|
||||||
|
summary="Resume Project",
|
||||||
|
description="""
|
||||||
|
Resume a paused project.
|
||||||
|
|
||||||
|
Only PAUSED projects can be resumed.
|
||||||
|
Only the project owner or a superuser can resume a project.
|
||||||
|
|
||||||
|
**Rate Limit**: 10 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="resume_project",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def resume_project(
|
||||||
|
request: Request,
|
||||||
|
project_id: UUID,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Resume a paused project.
|
||||||
|
|
||||||
|
Sets the project status back to ACTIVE. Only PAUSED projects can be resumed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
project = await project_crud.get(db, id=project_id)
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Project {project_id} not found",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_project_ownership(project, current_user)
|
||||||
|
|
||||||
|
# Validate current status (business logic validation, not authorization)
|
||||||
|
if project.status == ProjectStatus.ACTIVE:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Project is already active",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="status",
|
||||||
|
)
|
||||||
|
|
||||||
|
if project.status == ProjectStatus.ARCHIVED:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Cannot resume an archived project",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="status",
|
||||||
|
)
|
||||||
|
|
||||||
|
if project.status == ProjectStatus.COMPLETED:
|
||||||
|
raise ValidationException(
|
||||||
|
message="Cannot resume a completed project",
|
||||||
|
error_code=ErrorCode.VALIDATION_ERROR,
|
||||||
|
field="status",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update status to ACTIVE
|
||||||
|
updated_project = await project_crud.update(
|
||||||
|
db, db_obj=project, obj_in=ProjectUpdate(status=ProjectStatus.ACTIVE)
|
||||||
|
)
|
||||||
|
logger.info(f"User {current_user.email} resumed project {project.slug}")
|
||||||
|
|
||||||
|
# Get project with counts
|
||||||
|
project_data = await project_crud.get_with_counts(
|
||||||
|
db, project_id=updated_project.id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not project_data:
|
||||||
|
raise NotFoundError(
|
||||||
|
message=f"Project {project_id} not found after update",
|
||||||
|
error_code=ErrorCode.NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _build_project_response(project_data)
|
||||||
|
|
||||||
|
except (NotFoundError, AuthorizationError, ValidationException):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error resuming project {project_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
1186
backend/app/api/routes/sprints.py
Normal file
1186
backend/app/api/routes/sprints.py
Normal file
File diff suppressed because it is too large
Load Diff
116
backend/app/celery_app.py
Normal file
116
backend/app/celery_app.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
# app/celery_app.py
|
||||||
|
"""
|
||||||
|
Celery application configuration for Syndarix.
|
||||||
|
|
||||||
|
This module configures the Celery app for background task processing:
|
||||||
|
- Agent execution tasks (LLM calls, tool execution)
|
||||||
|
- Git operations (clone, commit, push, PR creation)
|
||||||
|
- Issue synchronization with external trackers
|
||||||
|
- Workflow state management
|
||||||
|
- Cost tracking and budget monitoring
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
- Redis as message broker and result backend
|
||||||
|
- Queue routing for task isolation
|
||||||
|
- JSON serialization for cross-language compatibility
|
||||||
|
- Beat scheduler for periodic tasks
|
||||||
|
"""
|
||||||
|
|
||||||
|
from celery import Celery
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
# Create Celery application instance
|
||||||
|
celery_app = Celery(
|
||||||
|
"syndarix",
|
||||||
|
broker=settings.celery_broker_url,
|
||||||
|
backend=settings.celery_result_backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Define task queues with their own exchanges and routing keys
|
||||||
|
TASK_QUEUES = {
|
||||||
|
"agent": {"exchange": "agent", "routing_key": "agent"},
|
||||||
|
"git": {"exchange": "git", "routing_key": "git"},
|
||||||
|
"sync": {"exchange": "sync", "routing_key": "sync"},
|
||||||
|
"default": {"exchange": "default", "routing_key": "default"},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Configure Celery
|
||||||
|
celery_app.conf.update(
|
||||||
|
# Serialization
|
||||||
|
task_serializer="json",
|
||||||
|
accept_content=["json"],
|
||||||
|
result_serializer="json",
|
||||||
|
# Timezone
|
||||||
|
timezone="UTC",
|
||||||
|
enable_utc=True,
|
||||||
|
# Task imports for auto-discovery
|
||||||
|
imports=("app.tasks",),
|
||||||
|
# Default queue
|
||||||
|
task_default_queue="default",
|
||||||
|
# Task queues configuration
|
||||||
|
task_queues=TASK_QUEUES,
|
||||||
|
# Task routing - route tasks to appropriate queues
|
||||||
|
task_routes={
|
||||||
|
"app.tasks.agent.*": {"queue": "agent"},
|
||||||
|
"app.tasks.git.*": {"queue": "git"},
|
||||||
|
"app.tasks.sync.*": {"queue": "sync"},
|
||||||
|
"app.tasks.*": {"queue": "default"},
|
||||||
|
},
|
||||||
|
# Time limits per ADR-003
|
||||||
|
task_soft_time_limit=300, # 5 minutes soft limit
|
||||||
|
task_time_limit=600, # 10 minutes hard limit
|
||||||
|
# Result expiration - 24 hours
|
||||||
|
result_expires=86400,
|
||||||
|
# Broker connection retry
|
||||||
|
broker_connection_retry_on_startup=True,
|
||||||
|
# Retry configuration per ADR-003 (built-in retry with backoff)
|
||||||
|
task_autoretry_for=(Exception,), # Retry on all exceptions
|
||||||
|
task_retry_kwargs={"max_retries": 3, "countdown": 5}, # Initial 5s delay
|
||||||
|
task_retry_backoff=True, # Enable exponential backoff
|
||||||
|
task_retry_backoff_max=600, # Max 10 minutes between retries
|
||||||
|
task_retry_jitter=True, # Add jitter to prevent thundering herd
|
||||||
|
# Beat schedule for periodic tasks
|
||||||
|
beat_schedule={
|
||||||
|
# Cost aggregation every hour per ADR-012
|
||||||
|
"aggregate-daily-costs": {
|
||||||
|
"task": "app.tasks.cost.aggregate_daily_costs",
|
||||||
|
"schedule": 3600.0, # 1 hour in seconds
|
||||||
|
},
|
||||||
|
# Reset daily budget counters at midnight UTC
|
||||||
|
"reset-daily-budget-counters": {
|
||||||
|
"task": "app.tasks.cost.reset_daily_budget_counters",
|
||||||
|
"schedule": 86400.0, # 24 hours in seconds
|
||||||
|
},
|
||||||
|
# Check for stale workflows every 5 minutes
|
||||||
|
"recover-stale-workflows": {
|
||||||
|
"task": "app.tasks.workflow.recover_stale_workflows",
|
||||||
|
"schedule": 300.0, # 5 minutes in seconds
|
||||||
|
},
|
||||||
|
# Incremental issue sync every minute per ADR-011
|
||||||
|
"sync-issues-incremental": {
|
||||||
|
"task": "app.tasks.sync.sync_issues_incremental",
|
||||||
|
"schedule": 60.0, # 1 minute in seconds
|
||||||
|
},
|
||||||
|
# Full issue reconciliation every 15 minutes per ADR-011
|
||||||
|
"sync-issues-full": {
|
||||||
|
"task": "app.tasks.sync.sync_issues_full",
|
||||||
|
"schedule": 900.0, # 15 minutes in seconds
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# Task execution settings
|
||||||
|
task_acks_late=True, # Acknowledge tasks after execution
|
||||||
|
task_reject_on_worker_lost=True, # Reject tasks if worker dies
|
||||||
|
worker_prefetch_multiplier=1, # Fair task distribution
|
||||||
|
)
|
||||||
|
|
||||||
|
# Auto-discover tasks from task modules
|
||||||
|
celery_app.autodiscover_tasks(
|
||||||
|
[
|
||||||
|
"app.tasks.agent",
|
||||||
|
"app.tasks.git",
|
||||||
|
"app.tasks.sync",
|
||||||
|
"app.tasks.workflow",
|
||||||
|
"app.tasks.cost",
|
||||||
|
]
|
||||||
|
)
|
||||||
@@ -5,7 +5,7 @@ from pydantic_settings import BaseSettings
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
PROJECT_NAME: str = "PragmaStack"
|
PROJECT_NAME: str = "Syndarix"
|
||||||
VERSION: str = "1.0.0"
|
VERSION: str = "1.0.0"
|
||||||
API_V1_STR: str = "/api/v1"
|
API_V1_STR: str = "/api/v1"
|
||||||
|
|
||||||
@@ -39,6 +39,32 @@ class Settings(BaseSettings):
|
|||||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||||
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
|
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
|
||||||
|
|
||||||
|
# Redis configuration (Syndarix: cache, pub/sub, Celery broker)
|
||||||
|
REDIS_URL: str = Field(
|
||||||
|
default="redis://localhost:6379/0",
|
||||||
|
description="Redis URL for cache, pub/sub, and Celery broker",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Celery configuration (Syndarix: background task processing)
|
||||||
|
CELERY_BROKER_URL: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Celery broker URL (defaults to REDIS_URL if not set)",
|
||||||
|
)
|
||||||
|
CELERY_RESULT_BACKEND: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Celery result backend URL (defaults to REDIS_URL if not set)",
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def celery_broker_url(self) -> str:
|
||||||
|
"""Get Celery broker URL, defaulting to Redis."""
|
||||||
|
return self.CELERY_BROKER_URL or self.REDIS_URL
|
||||||
|
|
||||||
|
@property
|
||||||
|
def celery_result_backend(self) -> str:
|
||||||
|
"""Get Celery result backend URL, defaulting to Redis."""
|
||||||
|
return self.CELERY_RESULT_BACKEND or self.REDIS_URL
|
||||||
|
|
||||||
# SQL debugging (disable in production)
|
# SQL debugging (disable in production)
|
||||||
sql_echo: bool = False # Log SQL statements
|
sql_echo: bool = False # Log SQL statements
|
||||||
sql_echo_pool: bool = False # Log connection pool events
|
sql_echo_pool: bool = False # Log connection pool events
|
||||||
|
|||||||
474
backend/app/core/redis.py
Normal file
474
backend/app/core/redis.py
Normal file
@@ -0,0 +1,474 @@
|
|||||||
|
# app/core/redis.py
|
||||||
|
"""
|
||||||
|
Redis client configuration for caching and pub/sub.
|
||||||
|
|
||||||
|
This module provides async Redis connectivity with connection pooling
|
||||||
|
for FastAPI endpoints and background tasks.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Connection pooling for efficient resource usage
|
||||||
|
- Cache operations (get, set, delete, expire)
|
||||||
|
- Pub/sub operations (publish, subscribe)
|
||||||
|
- Health check for monitoring
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from redis.asyncio import ConnectionPool, Redis
|
||||||
|
from redis.asyncio.client import PubSub
|
||||||
|
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Default TTL for cache entries (1 hour)
|
||||||
|
DEFAULT_CACHE_TTL = 3600
|
||||||
|
|
||||||
|
# Connection pool settings
|
||||||
|
POOL_MAX_CONNECTIONS = 50
|
||||||
|
POOL_TIMEOUT = 10 # seconds
|
||||||
|
|
||||||
|
|
||||||
|
class RedisClient:
|
||||||
|
"""
|
||||||
|
Async Redis client with connection pooling.
|
||||||
|
|
||||||
|
Provides high-level operations for caching and pub/sub
|
||||||
|
with proper error handling and connection management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, url: str | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Initialize Redis client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Redis connection URL. Defaults to settings.REDIS_URL.
|
||||||
|
"""
|
||||||
|
self._url = url or settings.REDIS_URL
|
||||||
|
self._pool: ConnectionPool | None = None
|
||||||
|
self._client: Redis | None = None
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def _ensure_pool(self) -> ConnectionPool:
|
||||||
|
"""Ensure connection pool is initialized (thread-safe)."""
|
||||||
|
if self._pool is None:
|
||||||
|
async with self._lock:
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if self._pool is None:
|
||||||
|
self._pool = ConnectionPool.from_url(
|
||||||
|
self._url,
|
||||||
|
max_connections=POOL_MAX_CONNECTIONS,
|
||||||
|
socket_timeout=POOL_TIMEOUT,
|
||||||
|
socket_connect_timeout=POOL_TIMEOUT,
|
||||||
|
decode_responses=True,
|
||||||
|
health_check_interval=30,
|
||||||
|
)
|
||||||
|
logger.info("Redis connection pool initialized")
|
||||||
|
return self._pool
|
||||||
|
|
||||||
|
async def _get_client(self) -> Redis:
|
||||||
|
"""Get Redis client instance from pool."""
|
||||||
|
pool = await self._ensure_pool()
|
||||||
|
if self._client is None:
|
||||||
|
self._client = Redis(connection_pool=pool)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Cache Operations
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def cache_get(self, key: str) -> str | None:
|
||||||
|
"""
|
||||||
|
Get a value from cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached value or None if not found.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = await self._get_client()
|
||||||
|
value = await client.get(key)
|
||||||
|
if value is not None:
|
||||||
|
logger.debug(f"Cache hit for key: {key}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Cache miss for key: {key}")
|
||||||
|
return value
|
||||||
|
except (ConnectionError, TimeoutError) as e:
|
||||||
|
logger.error(f"Redis cache_get failed for key '{key}': {e}")
|
||||||
|
return None
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(f"Redis error in cache_get for key '{key}': {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def cache_get_json(self, key: str) -> Any | None:
|
||||||
|
"""
|
||||||
|
Get a JSON-serialized value from cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deserialized value or None if not found.
|
||||||
|
"""
|
||||||
|
value = await self.cache_get(key)
|
||||||
|
if value is not None:
|
||||||
|
try:
|
||||||
|
return json.loads(value)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Failed to decode JSON for key '{key}': {e}")
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def cache_set(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: str,
|
||||||
|
ttl: int | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Set a value in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key.
|
||||||
|
value: Value to cache.
|
||||||
|
ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = await self._get_client()
|
||||||
|
ttl = ttl if ttl is not None else DEFAULT_CACHE_TTL
|
||||||
|
await client.set(key, value, ex=ttl)
|
||||||
|
logger.debug(f"Cache set for key: {key} (TTL: {ttl}s)")
|
||||||
|
return True
|
||||||
|
except (ConnectionError, TimeoutError) as e:
|
||||||
|
logger.error(f"Redis cache_set failed for key '{key}': {e}")
|
||||||
|
return False
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(f"Redis error in cache_set for key '{key}': {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def cache_set_json(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: Any,
|
||||||
|
ttl: int | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Set a JSON-serialized value in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key.
|
||||||
|
value: Value to serialize and cache.
|
||||||
|
ttl: Time-to-live in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
serialized = json.dumps(value)
|
||||||
|
return await self.cache_set(key, serialized, ttl)
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
logger.error(f"Failed to serialize value for key '{key}': {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def cache_delete(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a key from cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key to delete.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if key was deleted, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = await self._get_client()
|
||||||
|
result = await client.delete(key)
|
||||||
|
logger.debug(f"Cache delete for key: {key} (deleted: {result > 0})")
|
||||||
|
return result > 0
|
||||||
|
except (ConnectionError, TimeoutError) as e:
|
||||||
|
logger.error(f"Redis cache_delete failed for key '{key}': {e}")
|
||||||
|
return False
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(f"Redis error in cache_delete for key '{key}': {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def cache_delete_pattern(self, pattern: str) -> int:
|
||||||
|
"""
|
||||||
|
Delete all keys matching a pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: Glob-style pattern (e.g., "user:*").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of keys deleted.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = await self._get_client()
|
||||||
|
deleted = 0
|
||||||
|
async for key in client.scan_iter(pattern):
|
||||||
|
await client.delete(key)
|
||||||
|
deleted += 1
|
||||||
|
logger.debug(f"Cache delete pattern '{pattern}': {deleted} keys deleted")
|
||||||
|
return deleted
|
||||||
|
except (ConnectionError, TimeoutError) as e:
|
||||||
|
logger.error(f"Redis cache_delete_pattern failed for '{pattern}': {e}")
|
||||||
|
return 0
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(f"Redis error in cache_delete_pattern for '{pattern}': {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def cache_expire(self, key: str, ttl: int) -> bool:
|
||||||
|
"""
|
||||||
|
Set or update TTL for a key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key.
|
||||||
|
ttl: New TTL in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if TTL was set, False if key doesn't exist.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = await self._get_client()
|
||||||
|
result = await client.expire(key, ttl)
|
||||||
|
logger.debug(
|
||||||
|
f"Cache expire for key: {key} (TTL: {ttl}s, success: {result})"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
except (ConnectionError, TimeoutError) as e:
|
||||||
|
logger.error(f"Redis cache_expire failed for key '{key}': {e}")
|
||||||
|
return False
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(f"Redis error in cache_expire for key '{key}': {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def cache_exists(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a key exists in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if key exists, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = await self._get_client()
|
||||||
|
result = await client.exists(key)
|
||||||
|
return result > 0
|
||||||
|
except (ConnectionError, TimeoutError) as e:
|
||||||
|
logger.error(f"Redis cache_exists failed for key '{key}': {e}")
|
||||||
|
return False
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(f"Redis error in cache_exists for key '{key}': {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def cache_ttl(self, key: str) -> int:
|
||||||
|
"""
|
||||||
|
Get remaining TTL for a key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TTL in seconds, -1 if no TTL, -2 if key doesn't exist.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = await self._get_client()
|
||||||
|
return await client.ttl(key)
|
||||||
|
except (ConnectionError, TimeoutError) as e:
|
||||||
|
logger.error(f"Redis cache_ttl failed for key '{key}': {e}")
|
||||||
|
return -2
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(f"Redis error in cache_ttl for key '{key}': {e}")
|
||||||
|
return -2
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Pub/Sub Operations
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def publish(self, channel: str, message: str | dict) -> int:
|
||||||
|
"""
|
||||||
|
Publish a message to a channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: Channel name.
|
||||||
|
message: Message to publish (string or dict for JSON serialization).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of subscribers that received the message.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = await self._get_client()
|
||||||
|
if isinstance(message, dict):
|
||||||
|
message = json.dumps(message)
|
||||||
|
result = await client.publish(channel, message)
|
||||||
|
logger.debug(f"Published to channel '{channel}': {result} subscribers")
|
||||||
|
return result
|
||||||
|
except (ConnectionError, TimeoutError) as e:
|
||||||
|
logger.error(f"Redis publish failed for channel '{channel}': {e}")
|
||||||
|
return 0
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(f"Redis error in publish for channel '{channel}': {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def subscribe(self, *channels: str) -> AsyncGenerator[PubSub, None]:
|
||||||
|
"""
|
||||||
|
Subscribe to one or more channels.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
async with redis_client.subscribe("channel1", "channel2") as pubsub:
|
||||||
|
async for message in pubsub.listen():
|
||||||
|
if message["type"] == "message":
|
||||||
|
print(message["data"])
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels: Channel names to subscribe to.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
PubSub instance for receiving messages.
|
||||||
|
"""
|
||||||
|
client = await self._get_client()
|
||||||
|
pubsub = client.pubsub()
|
||||||
|
try:
|
||||||
|
await pubsub.subscribe(*channels)
|
||||||
|
logger.debug(f"Subscribed to channels: {channels}")
|
||||||
|
yield pubsub
|
||||||
|
finally:
|
||||||
|
await pubsub.unsubscribe(*channels)
|
||||||
|
await pubsub.close()
|
||||||
|
logger.debug(f"Unsubscribed from channels: {channels}")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def psubscribe(self, *patterns: str) -> AsyncGenerator[PubSub, None]:
|
||||||
|
"""
|
||||||
|
Subscribe to channels matching patterns.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
async with redis_client.psubscribe("user:*") as pubsub:
|
||||||
|
async for message in pubsub.listen():
|
||||||
|
if message["type"] == "pmessage":
|
||||||
|
print(message["pattern"], message["channel"], message["data"])
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patterns: Glob-style patterns to subscribe to.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
PubSub instance for receiving messages.
|
||||||
|
"""
|
||||||
|
client = await self._get_client()
|
||||||
|
pubsub = client.pubsub()
|
||||||
|
try:
|
||||||
|
await pubsub.psubscribe(*patterns)
|
||||||
|
logger.debug(f"Pattern subscribed: {patterns}")
|
||||||
|
yield pubsub
|
||||||
|
finally:
|
||||||
|
await pubsub.punsubscribe(*patterns)
|
||||||
|
await pubsub.close()
|
||||||
|
logger.debug(f"Pattern unsubscribed: {patterns}")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Health & Connection Management
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def health_check(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if Redis connection is healthy.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection is successful, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = await self._get_client()
|
||||||
|
result = await client.ping()
|
||||||
|
return result is True
|
||||||
|
except (ConnectionError, TimeoutError) as e:
|
||||||
|
logger.error(f"Redis health check failed: {e}")
|
||||||
|
return False
|
||||||
|
except RedisError as e:
|
||||||
|
logger.error(f"Redis health check error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""
|
||||||
|
Close Redis connections and cleanup resources.
|
||||||
|
|
||||||
|
Should be called during application shutdown.
|
||||||
|
"""
|
||||||
|
if self._client:
|
||||||
|
await self._client.close()
|
||||||
|
self._client = None
|
||||||
|
logger.debug("Redis client closed")
|
||||||
|
|
||||||
|
if self._pool:
|
||||||
|
await self._pool.disconnect()
|
||||||
|
self._pool = None
|
||||||
|
logger.info("Redis connection pool closed")
|
||||||
|
|
||||||
|
async def get_pool_info(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get connection pool statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with pool information.
|
||||||
|
"""
|
||||||
|
if self._pool is None:
|
||||||
|
return {"status": "not_initialized"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "active",
|
||||||
|
"max_connections": POOL_MAX_CONNECTIONS,
|
||||||
|
"url": self._url.split("@")[-1] if "@" in self._url else self._url,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Global Redis client instance
|
||||||
|
redis_client = RedisClient()
|
||||||
|
|
||||||
|
|
||||||
|
# FastAPI dependency for Redis client
|
||||||
|
async def get_redis() -> AsyncGenerator[RedisClient, None]:
|
||||||
|
"""
|
||||||
|
FastAPI dependency that provides the Redis client.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@router.get("/cached-data")
|
||||||
|
async def get_data(redis: RedisClient = Depends(get_redis)):
|
||||||
|
cached = await redis.cache_get("my-key")
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
yield redis_client
|
||||||
|
|
||||||
|
|
||||||
|
# Health check function for use in /health endpoint
|
||||||
|
async def check_redis_health() -> bool:
|
||||||
|
"""
|
||||||
|
Check if Redis connection is healthy.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection is successful, False otherwise.
|
||||||
|
"""
|
||||||
|
return await redis_client.health_check()
|
||||||
|
|
||||||
|
|
||||||
|
# Cleanup function for application shutdown
|
||||||
|
async def close_redis() -> None:
|
||||||
|
"""
|
||||||
|
Close Redis connections.
|
||||||
|
|
||||||
|
Should be called during application shutdown.
|
||||||
|
"""
|
||||||
|
await redis_client.close()
|
||||||
20
backend/app/crud/syndarix/__init__.py
Normal file
20
backend/app/crud/syndarix/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# app/crud/syndarix/__init__.py
|
||||||
|
"""
|
||||||
|
Syndarix CRUD operations.
|
||||||
|
|
||||||
|
This package contains CRUD operations for all Syndarix domain entities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .agent_instance import agent_instance
|
||||||
|
from .agent_type import agent_type
|
||||||
|
from .issue import issue
|
||||||
|
from .project import project
|
||||||
|
from .sprint import sprint
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"agent_instance",
|
||||||
|
"agent_type",
|
||||||
|
"issue",
|
||||||
|
"project",
|
||||||
|
"sprint",
|
||||||
|
]
|
||||||
394
backend/app/crud/syndarix/agent_instance.py
Normal file
394
backend/app/crud/syndarix/agent_instance.py
Normal file
@@ -0,0 +1,394 @@
|
|||||||
|
# app/crud/syndarix/agent_instance.py
|
||||||
|
"""Async CRUD operations for AgentInstance model using SQLAlchemy 2.0 patterns."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import func, select, update
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
|
from app.crud.base import CRUDBase
|
||||||
|
from app.models.syndarix import AgentInstance, Issue
|
||||||
|
from app.models.syndarix.enums import AgentStatus
|
||||||
|
from app.schemas.syndarix import AgentInstanceCreate, AgentInstanceUpdate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDAgentInstance(
|
||||||
|
CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstanceUpdate]
|
||||||
|
):
|
||||||
|
"""Async CRUD operations for AgentInstance model."""
|
||||||
|
|
||||||
|
async def create(
|
||||||
|
self, db: AsyncSession, *, obj_in: AgentInstanceCreate
|
||||||
|
) -> AgentInstance:
|
||||||
|
"""Create a new agent instance with error handling."""
|
||||||
|
try:
|
||||||
|
db_obj = AgentInstance(
|
||||||
|
agent_type_id=obj_in.agent_type_id,
|
||||||
|
project_id=obj_in.project_id,
|
||||||
|
name=obj_in.name,
|
||||||
|
status=obj_in.status,
|
||||||
|
current_task=obj_in.current_task,
|
||||||
|
short_term_memory=obj_in.short_term_memory,
|
||||||
|
long_term_memory_ref=obj_in.long_term_memory_ref,
|
||||||
|
session_id=obj_in.session_id,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e:
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
logger.error(f"Integrity error creating agent instance: {error_msg}")
|
||||||
|
raise ValueError(f"Database integrity error: {error_msg}")
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected error creating agent instance: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_with_details(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
instance_id: UUID,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""
|
||||||
|
Get an agent instance with full details including related entities.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with instance and related entity details
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get instance with joined relationships
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentInstance)
|
||||||
|
.options(
|
||||||
|
joinedload(AgentInstance.agent_type),
|
||||||
|
joinedload(AgentInstance.project),
|
||||||
|
)
|
||||||
|
.where(AgentInstance.id == instance_id)
|
||||||
|
)
|
||||||
|
instance = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not instance:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get assigned issues count
|
||||||
|
issues_count_result = await db.execute(
|
||||||
|
select(func.count(Issue.id)).where(
|
||||||
|
Issue.assigned_agent_id == instance_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assigned_issues_count = issues_count_result.scalar_one()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"instance": instance,
|
||||||
|
"agent_type_name": instance.agent_type.name
|
||||||
|
if instance.agent_type
|
||||||
|
else None,
|
||||||
|
"agent_type_slug": instance.agent_type.slug
|
||||||
|
if instance.agent_type
|
||||||
|
else None,
|
||||||
|
"project_name": instance.project.name if instance.project else None,
|
||||||
|
"project_slug": instance.project.slug if instance.project else None,
|
||||||
|
"assigned_issues_count": assigned_issues_count,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting agent instance with details {instance_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_project(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID,
|
||||||
|
status: AgentStatus | None = None,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> tuple[list[AgentInstance], int]:
|
||||||
|
"""Get agent instances for a specific project."""
|
||||||
|
try:
|
||||||
|
query = select(AgentInstance).where(AgentInstance.project_id == project_id)
|
||||||
|
|
||||||
|
if status is not None:
|
||||||
|
query = query.where(AgentInstance.status == status)
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
|
count_result = await db.execute(count_query)
|
||||||
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
query = query.order_by(AgentInstance.created_at.desc())
|
||||||
|
query = query.offset(skip).limit(limit)
|
||||||
|
result = await db.execute(query)
|
||||||
|
instances = list(result.scalars().all())
|
||||||
|
|
||||||
|
return instances, total
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting instances by project {project_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_agent_type(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
agent_type_id: UUID,
|
||||||
|
status: AgentStatus | None = None,
|
||||||
|
) -> list[AgentInstance]:
|
||||||
|
"""Get all instances of a specific agent type."""
|
||||||
|
try:
|
||||||
|
query = select(AgentInstance).where(
|
||||||
|
AgentInstance.agent_type_id == agent_type_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if status is not None:
|
||||||
|
query = query.where(AgentInstance.status == status)
|
||||||
|
|
||||||
|
query = query.order_by(AgentInstance.created_at.desc())
|
||||||
|
result = await db.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting instances by agent type {agent_type_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update_status(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
instance_id: UUID,
|
||||||
|
status: AgentStatus,
|
||||||
|
current_task: str | None = None,
|
||||||
|
) -> AgentInstance | None:
|
||||||
|
"""Update the status of an agent instance."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentInstance).where(AgentInstance.id == instance_id)
|
||||||
|
)
|
||||||
|
instance = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not instance:
|
||||||
|
return None
|
||||||
|
|
||||||
|
instance.status = status
|
||||||
|
instance.last_activity_at = datetime.now(UTC)
|
||||||
|
if current_task is not None:
|
||||||
|
instance.current_task = current_task
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(instance)
|
||||||
|
return instance
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error updating instance status {instance_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def terminate(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
instance_id: UUID,
|
||||||
|
) -> AgentInstance | None:
|
||||||
|
"""Terminate an agent instance.
|
||||||
|
|
||||||
|
Also unassigns all issues from this agent to prevent orphaned assignments.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentInstance).where(AgentInstance.id == instance_id)
|
||||||
|
)
|
||||||
|
instance = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not instance:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Unassign all issues from this agent before terminating
|
||||||
|
await db.execute(
|
||||||
|
update(Issue)
|
||||||
|
.where(Issue.assigned_agent_id == instance_id)
|
||||||
|
.values(assigned_agent_id=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
instance.status = AgentStatus.TERMINATED
|
||||||
|
instance.terminated_at = datetime.now(UTC)
|
||||||
|
instance.current_task = None
|
||||||
|
instance.session_id = None
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(instance)
|
||||||
|
return instance
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error terminating instance {instance_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def record_task_completion(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
instance_id: UUID,
|
||||||
|
tokens_used: int,
|
||||||
|
cost_incurred: Decimal,
|
||||||
|
) -> AgentInstance | None:
|
||||||
|
"""Record a completed task and update metrics.
|
||||||
|
|
||||||
|
Uses atomic SQL UPDATE to prevent lost updates under concurrent load.
|
||||||
|
This avoids the read-modify-write race condition that occurs when
|
||||||
|
multiple task completions happen simultaneously.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
# Use atomic SQL UPDATE to increment counters without race conditions
|
||||||
|
# This is safe for concurrent updates - no read-modify-write pattern
|
||||||
|
result = await db.execute(
|
||||||
|
update(AgentInstance)
|
||||||
|
.where(AgentInstance.id == instance_id)
|
||||||
|
.values(
|
||||||
|
tasks_completed=AgentInstance.tasks_completed + 1,
|
||||||
|
tokens_used=AgentInstance.tokens_used + tokens_used,
|
||||||
|
cost_incurred=AgentInstance.cost_incurred + cost_incurred,
|
||||||
|
last_activity_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
.returning(AgentInstance)
|
||||||
|
)
|
||||||
|
instance = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not instance:
|
||||||
|
return None
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
return instance
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error recording task completion {instance_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_project_metrics(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Get aggregated metrics for all agents in a project."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(
|
||||||
|
func.count(AgentInstance.id).label("total_instances"),
|
||||||
|
func.count(AgentInstance.id)
|
||||||
|
.filter(AgentInstance.status == AgentStatus.WORKING)
|
||||||
|
.label("active_instances"),
|
||||||
|
func.count(AgentInstance.id)
|
||||||
|
.filter(AgentInstance.status == AgentStatus.IDLE)
|
||||||
|
.label("idle_instances"),
|
||||||
|
func.sum(AgentInstance.tasks_completed).label("total_tasks"),
|
||||||
|
func.sum(AgentInstance.tokens_used).label("total_tokens"),
|
||||||
|
func.sum(AgentInstance.cost_incurred).label("total_cost"),
|
||||||
|
).where(AgentInstance.project_id == project_id)
|
||||||
|
)
|
||||||
|
row = result.one()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_instances": row.total_instances or 0,
|
||||||
|
"active_instances": row.active_instances or 0,
|
||||||
|
"idle_instances": row.idle_instances or 0,
|
||||||
|
"total_tasks_completed": row.total_tasks or 0,
|
||||||
|
"total_tokens_used": row.total_tokens or 0,
|
||||||
|
"total_cost_incurred": row.total_cost or Decimal("0.0000"),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting project metrics {project_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def bulk_terminate_by_project(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID,
|
||||||
|
) -> int:
|
||||||
|
"""Terminate all active instances in a project.
|
||||||
|
|
||||||
|
Also unassigns all issues from these agents to prevent orphaned assignments.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# First, unassign all issues from agents in this project
|
||||||
|
# Get all agent IDs that will be terminated
|
||||||
|
agents_to_terminate = await db.execute(
|
||||||
|
select(AgentInstance.id).where(
|
||||||
|
AgentInstance.project_id == project_id,
|
||||||
|
AgentInstance.status != AgentStatus.TERMINATED,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
agent_ids = [row[0] for row in agents_to_terminate.fetchall()]
|
||||||
|
|
||||||
|
# Unassign issues from these agents
|
||||||
|
if agent_ids:
|
||||||
|
await db.execute(
|
||||||
|
update(Issue)
|
||||||
|
.where(Issue.assigned_agent_id.in_(agent_ids))
|
||||||
|
.values(assigned_agent_id=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
stmt = (
|
||||||
|
update(AgentInstance)
|
||||||
|
.where(
|
||||||
|
AgentInstance.project_id == project_id,
|
||||||
|
AgentInstance.status != AgentStatus.TERMINATED,
|
||||||
|
)
|
||||||
|
.values(
|
||||||
|
status=AgentStatus.TERMINATED,
|
||||||
|
terminated_at=now,
|
||||||
|
current_task=None,
|
||||||
|
session_id=None,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
terminated_count = result.rowcount
|
||||||
|
logger.info(
|
||||||
|
f"Bulk terminated {terminated_count} instances in project {project_id}"
|
||||||
|
)
|
||||||
|
return terminated_count
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error bulk terminating instances for project {project_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Create a singleton instance for use across the application
|
||||||
|
agent_instance = CRUDAgentInstance(AgentInstance)
|
||||||
265
backend/app/crud/syndarix/agent_type.py
Normal file
265
backend/app/crud/syndarix/agent_type.py
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
# app/crud/syndarix/agent_type.py
|
||||||
|
"""Async CRUD operations for AgentType model using SQLAlchemy 2.0 patterns."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import func, or_, select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.crud.base import CRUDBase
|
||||||
|
from app.models.syndarix import AgentInstance, AgentType
|
||||||
|
from app.schemas.syndarix import AgentTypeCreate, AgentTypeUpdate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||||
|
"""Async CRUD operations for AgentType model."""
|
||||||
|
|
||||||
|
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> AgentType | None:
|
||||||
|
"""Get agent type by slug."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(select(AgentType).where(AgentType.slug == slug))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting agent type by slug {slug}: {e!s}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def create(self, db: AsyncSession, *, obj_in: AgentTypeCreate) -> AgentType:
|
||||||
|
"""Create a new agent type with error handling."""
|
||||||
|
try:
|
||||||
|
db_obj = AgentType(
|
||||||
|
name=obj_in.name,
|
||||||
|
slug=obj_in.slug,
|
||||||
|
description=obj_in.description,
|
||||||
|
expertise=obj_in.expertise,
|
||||||
|
personality_prompt=obj_in.personality_prompt,
|
||||||
|
primary_model=obj_in.primary_model,
|
||||||
|
fallback_models=obj_in.fallback_models,
|
||||||
|
model_params=obj_in.model_params,
|
||||||
|
mcp_servers=obj_in.mcp_servers,
|
||||||
|
tool_permissions=obj_in.tool_permissions,
|
||||||
|
is_active=obj_in.is_active,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e:
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
if "slug" in error_msg.lower():
|
||||||
|
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||||
|
raise ValueError(f"Agent type with slug '{obj_in.slug}' already exists")
|
||||||
|
logger.error(f"Integrity error creating agent type: {error_msg}")
|
||||||
|
raise ValueError(f"Database integrity error: {error_msg}")
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Unexpected error creating agent type: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_multi_with_filters(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
is_active: bool | None = None,
|
||||||
|
search: str | None = None,
|
||||||
|
sort_by: str = "created_at",
|
||||||
|
sort_order: str = "desc",
|
||||||
|
) -> tuple[list[AgentType], int]:
|
||||||
|
"""
|
||||||
|
Get multiple agent types with filtering, searching, and sorting.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (agent types list, total count)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query = select(AgentType)
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
if is_active is not None:
|
||||||
|
query = query.where(AgentType.is_active == is_active)
|
||||||
|
|
||||||
|
if search:
|
||||||
|
search_filter = or_(
|
||||||
|
AgentType.name.ilike(f"%{search}%"),
|
||||||
|
AgentType.slug.ilike(f"%{search}%"),
|
||||||
|
AgentType.description.ilike(f"%{search}%"),
|
||||||
|
)
|
||||||
|
query = query.where(search_filter)
|
||||||
|
|
||||||
|
# Get total count before pagination
|
||||||
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
|
count_result = await db.execute(count_query)
|
||||||
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
|
# Apply sorting
|
||||||
|
sort_column = getattr(AgentType, sort_by, AgentType.created_at)
|
||||||
|
if sort_order == "desc":
|
||||||
|
query = query.order_by(sort_column.desc())
|
||||||
|
else:
|
||||||
|
query = query.order_by(sort_column.asc())
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
query = query.offset(skip).limit(limit)
|
||||||
|
result = await db.execute(query)
|
||||||
|
agent_types = list(result.scalars().all())
|
||||||
|
|
||||||
|
return agent_types, total
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting agent types with filters: {e!s}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_with_instance_count(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
agent_type_id: UUID,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""
|
||||||
|
Get a single agent type with its instance count.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with agent_type and instance_count
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentType).where(AgentType.id == agent_type_id)
|
||||||
|
)
|
||||||
|
agent_type = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not agent_type:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get instance count
|
||||||
|
count_result = await db.execute(
|
||||||
|
select(func.count(AgentInstance.id)).where(
|
||||||
|
AgentInstance.agent_type_id == agent_type_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
instance_count = count_result.scalar_one()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"agent_type": agent_type,
|
||||||
|
"instance_count": instance_count,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting agent type with count {agent_type_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_multi_with_instance_counts(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
is_active: bool | None = None,
|
||||||
|
search: str | None = None,
|
||||||
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
|
"""
|
||||||
|
Get agent types with instance counts in optimized queries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (list of dicts with agent_type and instance_count, total count)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get filtered agent types
|
||||||
|
agent_types, total = await self.get_multi_with_filters(
|
||||||
|
db,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit,
|
||||||
|
is_active=is_active,
|
||||||
|
search=search,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not agent_types:
|
||||||
|
return [], 0
|
||||||
|
|
||||||
|
agent_type_ids = [at.id for at in agent_types]
|
||||||
|
|
||||||
|
# Get instance counts in bulk
|
||||||
|
counts_result = await db.execute(
|
||||||
|
select(
|
||||||
|
AgentInstance.agent_type_id,
|
||||||
|
func.count(AgentInstance.id).label("count"),
|
||||||
|
)
|
||||||
|
.where(AgentInstance.agent_type_id.in_(agent_type_ids))
|
||||||
|
.group_by(AgentInstance.agent_type_id)
|
||||||
|
)
|
||||||
|
counts = {row.agent_type_id: row.count for row in counts_result}
|
||||||
|
|
||||||
|
# Combine results
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
"agent_type": agent_type,
|
||||||
|
"instance_count": counts.get(agent_type.id, 0),
|
||||||
|
}
|
||||||
|
for agent_type in agent_types
|
||||||
|
]
|
||||||
|
|
||||||
|
return results, total
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting agent types with counts: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_expertise(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
expertise: str,
|
||||||
|
is_active: bool = True,
|
||||||
|
) -> list[AgentType]:
|
||||||
|
"""Get agent types that have a specific expertise."""
|
||||||
|
try:
|
||||||
|
# Use PostgreSQL JSONB contains operator
|
||||||
|
query = select(AgentType).where(
|
||||||
|
AgentType.expertise.contains([expertise.lower()]),
|
||||||
|
AgentType.is_active == is_active,
|
||||||
|
)
|
||||||
|
result = await db.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting agent types by expertise {expertise}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def deactivate(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
agent_type_id: UUID,
|
||||||
|
) -> AgentType | None:
|
||||||
|
"""Deactivate an agent type (soft delete)."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentType).where(AgentType.id == agent_type_id)
|
||||||
|
)
|
||||||
|
agent_type = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not agent_type:
|
||||||
|
return None
|
||||||
|
|
||||||
|
agent_type.is_active = False
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent_type)
|
||||||
|
return agent_type
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error deactivating agent type {agent_type_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Create a singleton instance for use across the application
|
||||||
|
agent_type = CRUDAgentType(AgentType)
|
||||||
525
backend/app/crud/syndarix/issue.py
Normal file
525
backend/app/crud/syndarix/issue.py
Normal file
@@ -0,0 +1,525 @@
|
|||||||
|
# app/crud/syndarix/issue.py
|
||||||
|
"""Async CRUD operations for Issue model using SQLAlchemy 2.0 patterns."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import func, or_, select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
|
from app.crud.base import CRUDBase
|
||||||
|
from app.models.syndarix import AgentInstance, Issue
|
||||||
|
from app.models.syndarix.enums import IssuePriority, IssueStatus, SyncStatus
|
||||||
|
from app.schemas.syndarix import IssueCreate, IssueUpdate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDIssue(CRUDBase[Issue, IssueCreate, IssueUpdate]):
|
||||||
|
"""Async CRUD operations for Issue model."""
|
||||||
|
|
||||||
|
async def create(self, db: AsyncSession, *, obj_in: IssueCreate) -> Issue:
|
||||||
|
"""Create a new issue with error handling."""
|
||||||
|
try:
|
||||||
|
db_obj = Issue(
|
||||||
|
project_id=obj_in.project_id,
|
||||||
|
title=obj_in.title,
|
||||||
|
body=obj_in.body,
|
||||||
|
status=obj_in.status,
|
||||||
|
priority=obj_in.priority,
|
||||||
|
labels=obj_in.labels,
|
||||||
|
assigned_agent_id=obj_in.assigned_agent_id,
|
||||||
|
human_assignee=obj_in.human_assignee,
|
||||||
|
sprint_id=obj_in.sprint_id,
|
||||||
|
story_points=obj_in.story_points,
|
||||||
|
external_tracker_type=obj_in.external_tracker_type,
|
||||||
|
external_issue_id=obj_in.external_issue_id,
|
||||||
|
remote_url=obj_in.remote_url,
|
||||||
|
external_issue_number=obj_in.external_issue_number,
|
||||||
|
sync_status=SyncStatus.SYNCED,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e:
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
logger.error(f"Integrity error creating issue: {error_msg}")
|
||||||
|
raise ValueError(f"Database integrity error: {error_msg}")
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Unexpected error creating issue: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_with_details(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
issue_id: UUID,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""
|
||||||
|
Get an issue with full details including related entity names.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with issue and related entity details
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get issue with joined relationships
|
||||||
|
result = await db.execute(
|
||||||
|
select(Issue)
|
||||||
|
.options(
|
||||||
|
joinedload(Issue.project),
|
||||||
|
joinedload(Issue.sprint),
|
||||||
|
joinedload(Issue.assigned_agent).joinedload(
|
||||||
|
AgentInstance.agent_type
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.where(Issue.id == issue_id)
|
||||||
|
)
|
||||||
|
issue = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not issue:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"issue": issue,
|
||||||
|
"project_name": issue.project.name if issue.project else None,
|
||||||
|
"project_slug": issue.project.slug if issue.project else None,
|
||||||
|
"sprint_name": issue.sprint.name if issue.sprint else None,
|
||||||
|
"assigned_agent_type_name": (
|
||||||
|
issue.assigned_agent.agent_type.name
|
||||||
|
if issue.assigned_agent and issue.assigned_agent.agent_type
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting issue with details {issue_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_project(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID,
|
||||||
|
status: IssueStatus | None = None,
|
||||||
|
priority: IssuePriority | None = None,
|
||||||
|
sprint_id: UUID | None = None,
|
||||||
|
assigned_agent_id: UUID | None = None,
|
||||||
|
labels: list[str] | None = None,
|
||||||
|
search: str | None = None,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: str = "created_at",
|
||||||
|
sort_order: str = "desc",
|
||||||
|
) -> tuple[list[Issue], int]:
|
||||||
|
"""Get issues for a specific project with filters."""
|
||||||
|
try:
|
||||||
|
query = select(Issue).where(Issue.project_id == project_id)
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
if status is not None:
|
||||||
|
query = query.where(Issue.status == status)
|
||||||
|
|
||||||
|
if priority is not None:
|
||||||
|
query = query.where(Issue.priority == priority)
|
||||||
|
|
||||||
|
if sprint_id is not None:
|
||||||
|
query = query.where(Issue.sprint_id == sprint_id)
|
||||||
|
|
||||||
|
if assigned_agent_id is not None:
|
||||||
|
query = query.where(Issue.assigned_agent_id == assigned_agent_id)
|
||||||
|
|
||||||
|
if labels:
|
||||||
|
# Match any of the provided labels
|
||||||
|
for label in labels:
|
||||||
|
query = query.where(Issue.labels.contains([label.lower()]))
|
||||||
|
|
||||||
|
if search:
|
||||||
|
search_filter = or_(
|
||||||
|
Issue.title.ilike(f"%{search}%"),
|
||||||
|
Issue.body.ilike(f"%{search}%"),
|
||||||
|
)
|
||||||
|
query = query.where(search_filter)
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
|
count_result = await db.execute(count_query)
|
||||||
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
|
# Apply sorting
|
||||||
|
sort_column = getattr(Issue, sort_by, Issue.created_at)
|
||||||
|
if sort_order == "desc":
|
||||||
|
query = query.order_by(sort_column.desc())
|
||||||
|
else:
|
||||||
|
query = query.order_by(sort_column.asc())
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
query = query.offset(skip).limit(limit)
|
||||||
|
result = await db.execute(query)
|
||||||
|
issues = list(result.scalars().all())
|
||||||
|
|
||||||
|
return issues, total
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting issues by project {project_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_sprint(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
sprint_id: UUID,
|
||||||
|
status: IssueStatus | None = None,
|
||||||
|
) -> list[Issue]:
|
||||||
|
"""Get all issues in a sprint."""
|
||||||
|
try:
|
||||||
|
query = select(Issue).where(Issue.sprint_id == sprint_id)
|
||||||
|
|
||||||
|
if status is not None:
|
||||||
|
query = query.where(Issue.status == status)
|
||||||
|
|
||||||
|
query = query.order_by(Issue.priority.desc(), Issue.created_at.asc())
|
||||||
|
result = await db.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting issues by sprint {sprint_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def assign_to_agent(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
issue_id: UUID,
|
||||||
|
agent_id: UUID | None,
|
||||||
|
) -> Issue | None:
|
||||||
|
"""Assign an issue to an agent (or unassign if agent_id is None)."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||||
|
issue = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not issue:
|
||||||
|
return None
|
||||||
|
|
||||||
|
issue.assigned_agent_id = agent_id
|
||||||
|
issue.human_assignee = None # Clear human assignee when assigning to agent
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(issue)
|
||||||
|
return issue
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error assigning issue {issue_id} to agent {agent_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def assign_to_human(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
issue_id: UUID,
|
||||||
|
human_assignee: str | None,
|
||||||
|
) -> Issue | None:
|
||||||
|
"""Assign an issue to a human (or unassign if human_assignee is None)."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||||
|
issue = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not issue:
|
||||||
|
return None
|
||||||
|
|
||||||
|
issue.human_assignee = human_assignee
|
||||||
|
issue.assigned_agent_id = None # Clear agent when assigning to human
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(issue)
|
||||||
|
return issue
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error assigning issue {issue_id} to human {human_assignee}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def close_issue(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
issue_id: UUID,
|
||||||
|
) -> Issue | None:
|
||||||
|
"""Close an issue by setting status and closed_at timestamp."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||||
|
issue = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not issue:
|
||||||
|
return None
|
||||||
|
|
||||||
|
issue.status = IssueStatus.CLOSED
|
||||||
|
issue.closed_at = datetime.now(UTC)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(issue)
|
||||||
|
return issue
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error closing issue {issue_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def reopen_issue(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
issue_id: UUID,
|
||||||
|
) -> Issue | None:
|
||||||
|
"""Reopen a closed issue."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||||
|
issue = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not issue:
|
||||||
|
return None
|
||||||
|
|
||||||
|
issue.status = IssueStatus.OPEN
|
||||||
|
issue.closed_at = None
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(issue)
|
||||||
|
return issue
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error reopening issue {issue_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update_sync_status(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
issue_id: UUID,
|
||||||
|
sync_status: SyncStatus,
|
||||||
|
last_synced_at: datetime | None = None,
|
||||||
|
external_updated_at: datetime | None = None,
|
||||||
|
) -> Issue | None:
|
||||||
|
"""Update the sync status of an issue."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||||
|
issue = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not issue:
|
||||||
|
return None
|
||||||
|
|
||||||
|
issue.sync_status = sync_status
|
||||||
|
if last_synced_at:
|
||||||
|
issue.last_synced_at = last_synced_at
|
||||||
|
if external_updated_at:
|
||||||
|
issue.external_updated_at = external_updated_at
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(issue)
|
||||||
|
return issue
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error updating sync status for issue {issue_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_project_stats(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Get issue statistics for a project."""
|
||||||
|
try:
|
||||||
|
# Get counts by status
|
||||||
|
status_counts = await db.execute(
|
||||||
|
select(Issue.status, func.count(Issue.id).label("count"))
|
||||||
|
.where(Issue.project_id == project_id)
|
||||||
|
.group_by(Issue.status)
|
||||||
|
)
|
||||||
|
by_status = {row.status.value: row.count for row in status_counts}
|
||||||
|
|
||||||
|
# Get counts by priority
|
||||||
|
priority_counts = await db.execute(
|
||||||
|
select(Issue.priority, func.count(Issue.id).label("count"))
|
||||||
|
.where(Issue.project_id == project_id)
|
||||||
|
.group_by(Issue.priority)
|
||||||
|
)
|
||||||
|
by_priority = {row.priority.value: row.count for row in priority_counts}
|
||||||
|
|
||||||
|
# Get story points
|
||||||
|
points_result = await db.execute(
|
||||||
|
select(
|
||||||
|
func.sum(Issue.story_points).label("total"),
|
||||||
|
func.sum(Issue.story_points)
|
||||||
|
.filter(Issue.status == IssueStatus.CLOSED)
|
||||||
|
.label("completed"),
|
||||||
|
).where(Issue.project_id == project_id)
|
||||||
|
)
|
||||||
|
points_row = points_result.one()
|
||||||
|
|
||||||
|
total_issues = sum(by_status.values())
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total_issues,
|
||||||
|
"open": by_status.get("open", 0),
|
||||||
|
"in_progress": by_status.get("in_progress", 0),
|
||||||
|
"in_review": by_status.get("in_review", 0),
|
||||||
|
"blocked": by_status.get("blocked", 0),
|
||||||
|
"closed": by_status.get("closed", 0),
|
||||||
|
"by_priority": by_priority,
|
||||||
|
"total_story_points": points_row.total,
|
||||||
|
"completed_story_points": points_row.completed,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting issue stats for project {project_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_external_id(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
external_tracker_type: str,
|
||||||
|
external_issue_id: str,
|
||||||
|
) -> Issue | None:
|
||||||
|
"""Get an issue by its external tracker ID."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Issue).where(
|
||||||
|
Issue.external_tracker_type == external_tracker_type,
|
||||||
|
Issue.external_issue_id == external_issue_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting issue by external ID {external_tracker_type}:{external_issue_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_pending_sync(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> list[Issue]:
|
||||||
|
"""Get issues that need to be synced with external tracker."""
|
||||||
|
try:
|
||||||
|
query = select(Issue).where(
|
||||||
|
Issue.external_tracker_type.isnot(None),
|
||||||
|
Issue.sync_status.in_([SyncStatus.PENDING, SyncStatus.ERROR]),
|
||||||
|
)
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
query = query.where(Issue.project_id == project_id)
|
||||||
|
|
||||||
|
query = query.order_by(Issue.updated_at.asc()).limit(limit)
|
||||||
|
result = await db.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting pending sync issues: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def remove_sprint_from_issues(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
sprint_id: UUID,
|
||||||
|
) -> int:
|
||||||
|
"""Remove sprint assignment from all issues in a sprint.
|
||||||
|
|
||||||
|
Used when deleting a sprint to clean up references.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of issues updated
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
update(Issue).where(Issue.sprint_id == sprint_id).values(sprint_id=None)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return result.rowcount
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error removing sprint {sprint_id} from issues: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def unassign(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
issue_id: UUID,
|
||||||
|
) -> Issue | None:
|
||||||
|
"""Remove agent assignment from an issue.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated issue or None if not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||||
|
issue = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not issue:
|
||||||
|
return None
|
||||||
|
|
||||||
|
issue.assigned_agent_id = None
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(issue)
|
||||||
|
return issue
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error unassigning issue {issue_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def remove_from_sprint(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
issue_id: UUID,
|
||||||
|
) -> Issue | None:
|
||||||
|
"""Remove an issue from its current sprint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated issue or None if not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||||
|
issue = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not issue:
|
||||||
|
return None
|
||||||
|
|
||||||
|
issue.sprint_id = None
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(issue)
|
||||||
|
return issue
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
f"Error removing issue {issue_id} from sprint: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Create a singleton instance for use across the application
|
||||||
|
issue = CRUDIssue(Issue)
|
||||||
362
backend/app/crud/syndarix/project.py
Normal file
362
backend/app/crud/syndarix/project.py
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
# app/crud/syndarix/project.py
|
||||||
|
"""Async CRUD operations for Project model using SQLAlchemy 2.0 patterns."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import func, or_, select, update
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.crud.base import CRUDBase
|
||||||
|
from app.models.syndarix import AgentInstance, Issue, Project, Sprint
|
||||||
|
from app.models.syndarix.enums import AgentStatus, ProjectStatus, SprintStatus
|
||||||
|
from app.schemas.syndarix import ProjectCreate, ProjectUpdate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDProject(CRUDBase[Project, ProjectCreate, ProjectUpdate]):
|
||||||
|
"""Async CRUD operations for Project model."""
|
||||||
|
|
||||||
|
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Project | None:
|
||||||
|
"""Get project by slug."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(select(Project).where(Project.slug == slug))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting project by slug {slug}: {e!s}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def create(self, db: AsyncSession, *, obj_in: ProjectCreate) -> Project:
|
||||||
|
"""Create a new project with error handling."""
|
||||||
|
try:
|
||||||
|
db_obj = Project(
|
||||||
|
name=obj_in.name,
|
||||||
|
slug=obj_in.slug,
|
||||||
|
description=obj_in.description,
|
||||||
|
autonomy_level=obj_in.autonomy_level,
|
||||||
|
status=obj_in.status,
|
||||||
|
settings=obj_in.settings or {},
|
||||||
|
owner_id=obj_in.owner_id,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e:
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
if "slug" in error_msg.lower():
|
||||||
|
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||||
|
raise ValueError(f"Project with slug '{obj_in.slug}' already exists")
|
||||||
|
logger.error(f"Integrity error creating project: {error_msg}")
|
||||||
|
raise ValueError(f"Database integrity error: {error_msg}")
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Unexpected error creating project: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_multi_with_filters(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
status: ProjectStatus | None = None,
|
||||||
|
owner_id: UUID | None = None,
|
||||||
|
search: str | None = None,
|
||||||
|
sort_by: str = "created_at",
|
||||||
|
sort_order: str = "desc",
|
||||||
|
) -> tuple[list[Project], int]:
|
||||||
|
"""
|
||||||
|
Get multiple projects with filtering, searching, and sorting.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (projects list, total count)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query = select(Project)
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
if status is not None:
|
||||||
|
query = query.where(Project.status == status)
|
||||||
|
|
||||||
|
if owner_id is not None:
|
||||||
|
query = query.where(Project.owner_id == owner_id)
|
||||||
|
|
||||||
|
if search:
|
||||||
|
search_filter = or_(
|
||||||
|
Project.name.ilike(f"%{search}%"),
|
||||||
|
Project.slug.ilike(f"%{search}%"),
|
||||||
|
Project.description.ilike(f"%{search}%"),
|
||||||
|
)
|
||||||
|
query = query.where(search_filter)
|
||||||
|
|
||||||
|
# Get total count before pagination
|
||||||
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
|
count_result = await db.execute(count_query)
|
||||||
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
|
# Apply sorting
|
||||||
|
sort_column = getattr(Project, sort_by, Project.created_at)
|
||||||
|
if sort_order == "desc":
|
||||||
|
query = query.order_by(sort_column.desc())
|
||||||
|
else:
|
||||||
|
query = query.order_by(sort_column.asc())
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
query = query.offset(skip).limit(limit)
|
||||||
|
result = await db.execute(query)
|
||||||
|
projects = list(result.scalars().all())
|
||||||
|
|
||||||
|
return projects, total
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting projects with filters: {e!s}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_with_counts(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""
|
||||||
|
Get a single project with agent and issue counts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with project, agent_count, issue_count, active_sprint_name
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get project
|
||||||
|
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||||
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get agent count
|
||||||
|
agent_count_result = await db.execute(
|
||||||
|
select(func.count(AgentInstance.id)).where(
|
||||||
|
AgentInstance.project_id == project_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
agent_count = agent_count_result.scalar_one()
|
||||||
|
|
||||||
|
# Get issue count
|
||||||
|
issue_count_result = await db.execute(
|
||||||
|
select(func.count(Issue.id)).where(Issue.project_id == project_id)
|
||||||
|
)
|
||||||
|
issue_count = issue_count_result.scalar_one()
|
||||||
|
|
||||||
|
# Get active sprint name
|
||||||
|
active_sprint_result = await db.execute(
|
||||||
|
select(Sprint.name).where(
|
||||||
|
Sprint.project_id == project_id,
|
||||||
|
Sprint.status == SprintStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
active_sprint_name = active_sprint_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"project": project,
|
||||||
|
"agent_count": agent_count,
|
||||||
|
"issue_count": issue_count,
|
||||||
|
"active_sprint_name": active_sprint_name,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting project with counts {project_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_multi_with_counts(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
status: ProjectStatus | None = None,
|
||||||
|
owner_id: UUID | None = None,
|
||||||
|
search: str | None = None,
|
||||||
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
|
"""
|
||||||
|
Get projects with agent/issue counts in optimized queries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (list of dicts with project and counts, total count)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get filtered projects
|
||||||
|
projects, total = await self.get_multi_with_filters(
|
||||||
|
db,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit,
|
||||||
|
status=status,
|
||||||
|
owner_id=owner_id,
|
||||||
|
search=search,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not projects:
|
||||||
|
return [], 0
|
||||||
|
|
||||||
|
project_ids = [p.id for p in projects]
|
||||||
|
|
||||||
|
# Get agent counts in bulk
|
||||||
|
agent_counts_result = await db.execute(
|
||||||
|
select(
|
||||||
|
AgentInstance.project_id,
|
||||||
|
func.count(AgentInstance.id).label("count"),
|
||||||
|
)
|
||||||
|
.where(AgentInstance.project_id.in_(project_ids))
|
||||||
|
.group_by(AgentInstance.project_id)
|
||||||
|
)
|
||||||
|
agent_counts = {row.project_id: row.count for row in agent_counts_result}
|
||||||
|
|
||||||
|
# Get issue counts in bulk
|
||||||
|
issue_counts_result = await db.execute(
|
||||||
|
select(
|
||||||
|
Issue.project_id,
|
||||||
|
func.count(Issue.id).label("count"),
|
||||||
|
)
|
||||||
|
.where(Issue.project_id.in_(project_ids))
|
||||||
|
.group_by(Issue.project_id)
|
||||||
|
)
|
||||||
|
issue_counts = {row.project_id: row.count for row in issue_counts_result}
|
||||||
|
|
||||||
|
# Get active sprint names
|
||||||
|
active_sprints_result = await db.execute(
|
||||||
|
select(Sprint.project_id, Sprint.name).where(
|
||||||
|
Sprint.project_id.in_(project_ids),
|
||||||
|
Sprint.status == SprintStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
active_sprints = {row.project_id: row.name for row in active_sprints_result}
|
||||||
|
|
||||||
|
# Combine results
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
"project": project,
|
||||||
|
"agent_count": agent_counts.get(project.id, 0),
|
||||||
|
"issue_count": issue_counts.get(project.id, 0),
|
||||||
|
"active_sprint_name": active_sprints.get(project.id),
|
||||||
|
}
|
||||||
|
for project in projects
|
||||||
|
]
|
||||||
|
|
||||||
|
return results, total
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting projects with counts: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_projects_by_owner(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
owner_id: UUID,
|
||||||
|
status: ProjectStatus | None = None,
|
||||||
|
) -> list[Project]:
|
||||||
|
"""Get all projects owned by a specific user."""
|
||||||
|
try:
|
||||||
|
query = select(Project).where(Project.owner_id == owner_id)
|
||||||
|
|
||||||
|
if status is not None:
|
||||||
|
query = query.where(Project.status == status)
|
||||||
|
|
||||||
|
query = query.order_by(Project.created_at.desc())
|
||||||
|
result = await db.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting projects by owner {owner_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def archive_project(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID,
|
||||||
|
) -> Project | None:
|
||||||
|
"""Archive a project by setting status to ARCHIVED.
|
||||||
|
|
||||||
|
This also performs cascading cleanup:
|
||||||
|
- Terminates all active agent instances
|
||||||
|
- Cancels all planned/active sprints
|
||||||
|
- Unassigns issues from terminated agents
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||||
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not project:
|
||||||
|
return None
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
# 1. Get all agent IDs that will be terminated
|
||||||
|
agents_to_terminate = await db.execute(
|
||||||
|
select(AgentInstance.id).where(
|
||||||
|
AgentInstance.project_id == project_id,
|
||||||
|
AgentInstance.status != AgentStatus.TERMINATED,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
agent_ids = [row[0] for row in agents_to_terminate.fetchall()]
|
||||||
|
|
||||||
|
# 2. Unassign issues from these agents to prevent orphaned assignments
|
||||||
|
if agent_ids:
|
||||||
|
await db.execute(
|
||||||
|
update(Issue)
|
||||||
|
.where(Issue.assigned_agent_id.in_(agent_ids))
|
||||||
|
.values(assigned_agent_id=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Terminate all active agents
|
||||||
|
await db.execute(
|
||||||
|
update(AgentInstance)
|
||||||
|
.where(
|
||||||
|
AgentInstance.project_id == project_id,
|
||||||
|
AgentInstance.status != AgentStatus.TERMINATED,
|
||||||
|
)
|
||||||
|
.values(
|
||||||
|
status=AgentStatus.TERMINATED,
|
||||||
|
terminated_at=now,
|
||||||
|
current_task=None,
|
||||||
|
session_id=None,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Cancel all planned/active sprints
|
||||||
|
await db.execute(
|
||||||
|
update(Sprint)
|
||||||
|
.where(
|
||||||
|
Sprint.project_id == project_id,
|
||||||
|
Sprint.status.in_([SprintStatus.PLANNED, SprintStatus.ACTIVE]),
|
||||||
|
)
|
||||||
|
.values(
|
||||||
|
status=SprintStatus.CANCELLED,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Archive the project
|
||||||
|
project.status = ProjectStatus.ARCHIVED
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(project)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Archived project {project_id}: terminated agents={len(agent_ids)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return project
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error archiving project {project_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Create a singleton instance for use across the application
|
||||||
|
project = CRUDProject(Project)
|
||||||
439
backend/app/crud/syndarix/sprint.py
Normal file
439
backend/app/crud/syndarix/sprint.py
Normal file
@@ -0,0 +1,439 @@
|
|||||||
|
# app/crud/syndarix/sprint.py
|
||||||
|
"""Async CRUD operations for Sprint model using SQLAlchemy 2.0 patterns."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import date
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
|
from app.crud.base import CRUDBase
|
||||||
|
from app.models.syndarix import Issue, Sprint
|
||||||
|
from app.models.syndarix.enums import IssueStatus, SprintStatus
|
||||||
|
from app.schemas.syndarix import SprintCreate, SprintUpdate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDSprint(CRUDBase[Sprint, SprintCreate, SprintUpdate]):
|
||||||
|
"""Async CRUD operations for Sprint model."""
|
||||||
|
|
||||||
|
async def create(self, db: AsyncSession, *, obj_in: SprintCreate) -> Sprint:
|
||||||
|
"""Create a new sprint with error handling."""
|
||||||
|
try:
|
||||||
|
db_obj = Sprint(
|
||||||
|
project_id=obj_in.project_id,
|
||||||
|
name=obj_in.name,
|
||||||
|
number=obj_in.number,
|
||||||
|
goal=obj_in.goal,
|
||||||
|
start_date=obj_in.start_date,
|
||||||
|
end_date=obj_in.end_date,
|
||||||
|
status=obj_in.status,
|
||||||
|
planned_points=obj_in.planned_points,
|
||||||
|
velocity=obj_in.velocity,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e:
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
logger.error(f"Integrity error creating sprint: {error_msg}")
|
||||||
|
raise ValueError(f"Database integrity error: {error_msg}")
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Unexpected error creating sprint: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_with_details(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
sprint_id: UUID,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""
|
||||||
|
Get a sprint with full details including issue counts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with sprint and related details
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get sprint with joined project
|
||||||
|
result = await db.execute(
|
||||||
|
select(Sprint)
|
||||||
|
.options(joinedload(Sprint.project))
|
||||||
|
.where(Sprint.id == sprint_id)
|
||||||
|
)
|
||||||
|
sprint = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not sprint:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get issue counts
|
||||||
|
issue_counts = await db.execute(
|
||||||
|
select(
|
||||||
|
func.count(Issue.id).label("total"),
|
||||||
|
func.count(Issue.id)
|
||||||
|
.filter(Issue.status == IssueStatus.OPEN)
|
||||||
|
.label("open"),
|
||||||
|
func.count(Issue.id)
|
||||||
|
.filter(Issue.status == IssueStatus.CLOSED)
|
||||||
|
.label("completed"),
|
||||||
|
).where(Issue.sprint_id == sprint_id)
|
||||||
|
)
|
||||||
|
counts = issue_counts.one()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"sprint": sprint,
|
||||||
|
"project_name": sprint.project.name if sprint.project else None,
|
||||||
|
"project_slug": sprint.project.slug if sprint.project else None,
|
||||||
|
"issue_count": counts.total,
|
||||||
|
"open_issues": counts.open,
|
||||||
|
"completed_issues": counts.completed,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting sprint with details {sprint_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_project(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID,
|
||||||
|
status: SprintStatus | None = None,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> tuple[list[Sprint], int]:
|
||||||
|
"""Get sprints for a specific project."""
|
||||||
|
try:
|
||||||
|
query = select(Sprint).where(Sprint.project_id == project_id)
|
||||||
|
|
||||||
|
if status is not None:
|
||||||
|
query = query.where(Sprint.status == status)
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
|
count_result = await db.execute(count_query)
|
||||||
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
|
# Apply sorting (by number descending - newest first)
|
||||||
|
query = query.order_by(Sprint.number.desc())
|
||||||
|
query = query.offset(skip).limit(limit)
|
||||||
|
result = await db.execute(query)
|
||||||
|
sprints = list(result.scalars().all())
|
||||||
|
|
||||||
|
return sprints, total
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting sprints by project {project_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_active_sprint(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID,
|
||||||
|
) -> Sprint | None:
|
||||||
|
"""Get the currently active sprint for a project."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Sprint).where(
|
||||||
|
Sprint.project_id == project_id,
|
||||||
|
Sprint.status == SprintStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting active sprint for project {project_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_next_sprint_number(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID,
|
||||||
|
) -> int:
|
||||||
|
"""Get the next sprint number for a project."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.max(Sprint.number)).where(Sprint.project_id == project_id)
|
||||||
|
)
|
||||||
|
max_number = result.scalar_one_or_none()
|
||||||
|
return (max_number or 0) + 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting next sprint number for project {project_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def start_sprint(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
sprint_id: UUID,
|
||||||
|
start_date: date | None = None,
|
||||||
|
) -> Sprint | None:
|
||||||
|
"""Start a planned sprint.
|
||||||
|
|
||||||
|
Uses row-level locking (SELECT FOR UPDATE) to prevent race conditions
|
||||||
|
when multiple requests try to start sprints concurrently.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Lock the sprint row to prevent concurrent modifications
|
||||||
|
result = await db.execute(
|
||||||
|
select(Sprint).where(Sprint.id == sprint_id).with_for_update()
|
||||||
|
)
|
||||||
|
sprint = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not sprint:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if sprint.status != SprintStatus.PLANNED:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot start sprint with status {sprint.status.value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for existing active sprint with lock to prevent race condition
|
||||||
|
# Lock all sprints for this project to ensure atomic check-and-update
|
||||||
|
active_check = await db.execute(
|
||||||
|
select(Sprint)
|
||||||
|
.where(
|
||||||
|
Sprint.project_id == sprint.project_id,
|
||||||
|
Sprint.status == SprintStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
.with_for_update()
|
||||||
|
)
|
||||||
|
active_sprint = active_check.scalar_one_or_none()
|
||||||
|
if active_sprint:
|
||||||
|
raise ValueError(
|
||||||
|
f"Project already has an active sprint: {active_sprint.name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
sprint.status = SprintStatus.ACTIVE
|
||||||
|
if start_date:
|
||||||
|
sprint.start_date = start_date
|
||||||
|
|
||||||
|
# Calculate planned points from issues
|
||||||
|
points_result = await db.execute(
|
||||||
|
select(func.sum(Issue.story_points)).where(Issue.sprint_id == sprint_id)
|
||||||
|
)
|
||||||
|
sprint.planned_points = points_result.scalar_one_or_none() or 0
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(sprint)
|
||||||
|
return sprint
|
||||||
|
except ValueError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error starting sprint {sprint_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def complete_sprint(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
sprint_id: UUID,
|
||||||
|
) -> Sprint | None:
|
||||||
|
"""Complete an active sprint and calculate completed points.
|
||||||
|
|
||||||
|
Uses row-level locking (SELECT FOR UPDATE) to prevent race conditions
|
||||||
|
when velocity is being calculated and other operations might modify issues.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Lock the sprint row to prevent concurrent modifications
|
||||||
|
result = await db.execute(
|
||||||
|
select(Sprint).where(Sprint.id == sprint_id).with_for_update()
|
||||||
|
)
|
||||||
|
sprint = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not sprint:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if sprint.status != SprintStatus.ACTIVE:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot complete sprint with status {sprint.status.value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
sprint.status = SprintStatus.COMPLETED
|
||||||
|
|
||||||
|
# Calculate velocity (completed points) from closed issues
|
||||||
|
# Note: Issues are not locked, but sprint lock ensures this sprint's
|
||||||
|
# completion is atomic and prevents concurrent completion attempts
|
||||||
|
points_result = await db.execute(
|
||||||
|
select(func.sum(Issue.story_points)).where(
|
||||||
|
Issue.sprint_id == sprint_id,
|
||||||
|
Issue.status == IssueStatus.CLOSED,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sprint.velocity = points_result.scalar_one_or_none() or 0
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(sprint)
|
||||||
|
return sprint
|
||||||
|
except ValueError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error completing sprint {sprint_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def cancel_sprint(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
sprint_id: UUID,
|
||||||
|
) -> Sprint | None:
|
||||||
|
"""Cancel a sprint (only PLANNED or ACTIVE sprints can be cancelled).
|
||||||
|
|
||||||
|
Uses row-level locking to prevent race conditions with concurrent
|
||||||
|
sprint status modifications.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Lock the sprint row to prevent concurrent modifications
|
||||||
|
result = await db.execute(
|
||||||
|
select(Sprint).where(Sprint.id == sprint_id).with_for_update()
|
||||||
|
)
|
||||||
|
sprint = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not sprint:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if sprint.status not in [SprintStatus.PLANNED, SprintStatus.ACTIVE]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot cancel sprint with status {sprint.status.value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
sprint.status = SprintStatus.CANCELLED
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(sprint)
|
||||||
|
return sprint
|
||||||
|
except ValueError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Error cancelling sprint {sprint_id}: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_velocity(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID,
|
||||||
|
limit: int = 5,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Get velocity data for completed sprints."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Sprint)
|
||||||
|
.where(
|
||||||
|
Sprint.project_id == project_id,
|
||||||
|
Sprint.status == SprintStatus.COMPLETED,
|
||||||
|
)
|
||||||
|
.order_by(Sprint.number.desc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
sprints = list(result.scalars().all())
|
||||||
|
|
||||||
|
velocity_data = []
|
||||||
|
for sprint in reversed(sprints): # Return in chronological order
|
||||||
|
velocity_ratio = None
|
||||||
|
if sprint.planned_points and sprint.planned_points > 0:
|
||||||
|
velocity_ratio = (sprint.velocity or 0) / sprint.planned_points
|
||||||
|
velocity_data.append(
|
||||||
|
{
|
||||||
|
"sprint_number": sprint.number,
|
||||||
|
"sprint_name": sprint.name,
|
||||||
|
"planned_points": sprint.planned_points,
|
||||||
|
"velocity": sprint.velocity,
|
||||||
|
"velocity_ratio": velocity_ratio,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return velocity_data
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting velocity for project {project_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_sprints_with_issue_counts(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
project_id: UUID,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
|
"""Get sprints with issue counts in optimized queries."""
|
||||||
|
try:
|
||||||
|
# Get sprints
|
||||||
|
sprints, total = await self.get_by_project(
|
||||||
|
db, project_id=project_id, skip=skip, limit=limit
|
||||||
|
)
|
||||||
|
|
||||||
|
if not sprints:
|
||||||
|
return [], 0
|
||||||
|
|
||||||
|
sprint_ids = [s.id for s in sprints]
|
||||||
|
|
||||||
|
# Get issue counts in bulk
|
||||||
|
issue_counts = await db.execute(
|
||||||
|
select(
|
||||||
|
Issue.sprint_id,
|
||||||
|
func.count(Issue.id).label("total"),
|
||||||
|
func.count(Issue.id)
|
||||||
|
.filter(Issue.status == IssueStatus.OPEN)
|
||||||
|
.label("open"),
|
||||||
|
func.count(Issue.id)
|
||||||
|
.filter(Issue.status == IssueStatus.CLOSED)
|
||||||
|
.label("completed"),
|
||||||
|
)
|
||||||
|
.where(Issue.sprint_id.in_(sprint_ids))
|
||||||
|
.group_by(Issue.sprint_id)
|
||||||
|
)
|
||||||
|
counts_map = {
|
||||||
|
row.sprint_id: {
|
||||||
|
"issue_count": row.total,
|
||||||
|
"open_issues": row.open,
|
||||||
|
"completed_issues": row.completed,
|
||||||
|
}
|
||||||
|
for row in issue_counts
|
||||||
|
}
|
||||||
|
|
||||||
|
# Combine results
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
"sprint": sprint,
|
||||||
|
**counts_map.get(
|
||||||
|
sprint.id,
|
||||||
|
{"issue_count": 0, "open_issues": 0, "completed_issues": 0},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
for sprint in sprints
|
||||||
|
]
|
||||||
|
|
||||||
|
return results, total
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error getting sprints with counts for project {project_id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Create a singleton instance for use across the application
|
||||||
|
sprint = CRUDSprint(Sprint)
|
||||||
@@ -18,13 +18,26 @@ from .oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
|
|||||||
from .oauth_state import OAuthState
|
from .oauth_state import OAuthState
|
||||||
from .organization import Organization
|
from .organization import Organization
|
||||||
|
|
||||||
|
# Syndarix domain models
|
||||||
|
from .syndarix import (
|
||||||
|
AgentInstance,
|
||||||
|
AgentType,
|
||||||
|
Issue,
|
||||||
|
Project,
|
||||||
|
Sprint,
|
||||||
|
)
|
||||||
|
|
||||||
# Import models
|
# Import models
|
||||||
from .user import User
|
from .user import User
|
||||||
from .user_organization import OrganizationRole, UserOrganization
|
from .user_organization import OrganizationRole, UserOrganization
|
||||||
from .user_session import UserSession
|
from .user_session import UserSession
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Syndarix models
|
||||||
|
"AgentInstance",
|
||||||
|
"AgentType",
|
||||||
"Base",
|
"Base",
|
||||||
|
"Issue",
|
||||||
"OAuthAccount",
|
"OAuthAccount",
|
||||||
"OAuthAuthorizationCode",
|
"OAuthAuthorizationCode",
|
||||||
"OAuthClient",
|
"OAuthClient",
|
||||||
@@ -33,6 +46,8 @@ __all__ = [
|
|||||||
"OAuthState",
|
"OAuthState",
|
||||||
"Organization",
|
"Organization",
|
||||||
"OrganizationRole",
|
"OrganizationRole",
|
||||||
|
"Project",
|
||||||
|
"Sprint",
|
||||||
"TimestampMixin",
|
"TimestampMixin",
|
||||||
"UUIDMixin",
|
"UUIDMixin",
|
||||||
"User",
|
"User",
|
||||||
|
|||||||
47
backend/app/models/syndarix/__init__.py
Normal file
47
backend/app/models/syndarix/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# app/models/syndarix/__init__.py
|
||||||
|
"""
|
||||||
|
Syndarix domain models.
|
||||||
|
|
||||||
|
This package contains all the core entities for the Syndarix AI consulting platform:
|
||||||
|
- Project: Client engagements with autonomy settings
|
||||||
|
- AgentType: Templates for AI agent capabilities
|
||||||
|
- AgentInstance: Spawned agents working on projects
|
||||||
|
- Issue: Units of work with external tracker sync
|
||||||
|
- Sprint: Time-boxed iterations for organizing work
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .agent_instance import AgentInstance
|
||||||
|
from .agent_type import AgentType
|
||||||
|
from .enums import (
|
||||||
|
AgentStatus,
|
||||||
|
AutonomyLevel,
|
||||||
|
ClientMode,
|
||||||
|
IssuePriority,
|
||||||
|
IssueStatus,
|
||||||
|
IssueType,
|
||||||
|
ProjectComplexity,
|
||||||
|
ProjectStatus,
|
||||||
|
SprintStatus,
|
||||||
|
SyncStatus,
|
||||||
|
)
|
||||||
|
from .issue import Issue
|
||||||
|
from .project import Project
|
||||||
|
from .sprint import Sprint
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentInstance",
|
||||||
|
"AgentStatus",
|
||||||
|
"AgentType",
|
||||||
|
"AutonomyLevel",
|
||||||
|
"ClientMode",
|
||||||
|
"Issue",
|
||||||
|
"IssuePriority",
|
||||||
|
"IssueStatus",
|
||||||
|
"IssueType",
|
||||||
|
"Project",
|
||||||
|
"ProjectComplexity",
|
||||||
|
"ProjectStatus",
|
||||||
|
"Sprint",
|
||||||
|
"SprintStatus",
|
||||||
|
"SyncStatus",
|
||||||
|
]
|
||||||
111
backend/app/models/syndarix/agent_instance.py
Normal file
111
backend/app/models/syndarix/agent_instance.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
# app/models/syndarix/agent_instance.py
|
||||||
|
"""
|
||||||
|
AgentInstance model for Syndarix AI consulting platform.
|
||||||
|
|
||||||
|
An AgentInstance is a spawned instance of an AgentType, assigned to a
|
||||||
|
specific project to perform work.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
Numeric,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import (
|
||||||
|
JSONB,
|
||||||
|
UUID as PGUUID,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
from .enums import AgentStatus
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInstance(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
AgentInstance model representing a spawned agent working on a project.
|
||||||
|
|
||||||
|
Tracks:
|
||||||
|
- Current status and task
|
||||||
|
- Memory (short-term in DB, long-term reference to vector store)
|
||||||
|
- Session information for MCP connections
|
||||||
|
- Usage metrics (tasks completed, tokens, cost)
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "agent_instances"
|
||||||
|
|
||||||
|
# Foreign keys
|
||||||
|
agent_type_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("agent_types.id", ondelete="RESTRICT"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
project_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Agent instance name (e.g., "Dave", "Eve") for personality
|
||||||
|
name = Column(String(100), nullable=False, index=True)
|
||||||
|
|
||||||
|
# Status tracking
|
||||||
|
status: Column[AgentStatus] = Column(
|
||||||
|
Enum(AgentStatus),
|
||||||
|
default=AgentStatus.IDLE,
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Current task description (brief summary of what agent is doing)
|
||||||
|
current_task = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
# Short-term memory stored in database (conversation context, recent decisions)
|
||||||
|
short_term_memory = Column(JSONB, default=dict, nullable=False)
|
||||||
|
|
||||||
|
# Reference to long-term memory in vector store (e.g., "project-123/agent-456")
|
||||||
|
long_term_memory_ref = Column(String(500), nullable=True)
|
||||||
|
|
||||||
|
# Session ID for active MCP connections
|
||||||
|
session_id = Column(String(255), nullable=True, index=True)
|
||||||
|
|
||||||
|
# Activity tracking
|
||||||
|
last_activity_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||||
|
terminated_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||||
|
|
||||||
|
# Usage metrics
|
||||||
|
tasks_completed = Column(Integer, default=0, nullable=False)
|
||||||
|
tokens_used = Column(BigInteger, default=0, nullable=False)
|
||||||
|
cost_incurred = Column(Numeric(precision=10, scale=4), default=0, nullable=False)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
agent_type = relationship("AgentType", back_populates="instances")
|
||||||
|
project = relationship("Project", back_populates="agent_instances")
|
||||||
|
assigned_issues = relationship(
|
||||||
|
"Issue",
|
||||||
|
back_populates="assigned_agent",
|
||||||
|
foreign_keys="Issue.assigned_agent_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_agent_instances_project_status", "project_id", "status"),
|
||||||
|
Index("ix_agent_instances_type_status", "agent_type_id", "status"),
|
||||||
|
Index("ix_agent_instances_project_type", "project_id", "agent_type_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"<AgentInstance {self.name} ({self.id}) type={self.agent_type_id} "
|
||||||
|
f"project={self.project_id} status={self.status.value}>"
|
||||||
|
)
|
||||||
72
backend/app/models/syndarix/agent_type.py
Normal file
72
backend/app/models/syndarix/agent_type.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
# app/models/syndarix/agent_type.py
|
||||||
|
"""
|
||||||
|
AgentType model for Syndarix AI consulting platform.
|
||||||
|
|
||||||
|
An AgentType is a template that defines the capabilities, personality,
|
||||||
|
and model configuration for agent instances.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, Column, Index, String, Text
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
|
||||||
|
class AgentType(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
AgentType model representing a template for agent instances.
|
||||||
|
|
||||||
|
Each agent type defines:
|
||||||
|
- Expertise areas and personality prompt
|
||||||
|
- Model configuration (primary, fallback, parameters)
|
||||||
|
- MCP server access and tool permissions
|
||||||
|
|
||||||
|
Examples: ProductOwner, Architect, BackendEngineer, QAEngineer
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "agent_types"
|
||||||
|
|
||||||
|
name = Column(String(255), nullable=False, index=True)
|
||||||
|
slug = Column(String(255), unique=True, nullable=False, index=True)
|
||||||
|
description = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
# Areas of expertise for this agent type (e.g., ["python", "fastapi", "databases"])
|
||||||
|
expertise = Column(JSONB, default=list, nullable=False)
|
||||||
|
|
||||||
|
# System prompt defining the agent's personality and behavior
|
||||||
|
personality_prompt = Column(Text, nullable=False)
|
||||||
|
|
||||||
|
# Primary LLM model to use (e.g., "claude-opus-4-5-20251101")
|
||||||
|
primary_model = Column(String(100), nullable=False)
|
||||||
|
|
||||||
|
# Fallback models in order of preference
|
||||||
|
fallback_models = Column(JSONB, default=list, nullable=False)
|
||||||
|
|
||||||
|
# Model parameters (temperature, max_tokens, etc.)
|
||||||
|
model_params = Column(JSONB, default=dict, nullable=False)
|
||||||
|
|
||||||
|
# List of MCP servers this agent can connect to
|
||||||
|
mcp_servers = Column(JSONB, default=list, nullable=False)
|
||||||
|
|
||||||
|
# Tool permissions configuration
|
||||||
|
# Structure: {"allowed": ["*"], "denied": [], "require_approval": ["gitea:create_pr"]}
|
||||||
|
tool_permissions = Column(JSONB, default=dict, nullable=False)
|
||||||
|
|
||||||
|
# Whether this agent type is available for new instances
|
||||||
|
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
instances = relationship(
|
||||||
|
"AgentInstance",
|
||||||
|
back_populates="agent_type",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_agent_types_slug_active", "slug", "is_active"),
|
||||||
|
Index("ix_agent_types_name_active", "name", "is_active"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<AgentType {self.name} ({self.slug}) active={self.is_active}>"
|
||||||
169
backend/app/models/syndarix/enums.py
Normal file
169
backend/app/models/syndarix/enums.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
# app/models/syndarix/enums.py
|
||||||
|
"""
|
||||||
|
Enums for Syndarix domain models.
|
||||||
|
|
||||||
|
These enums represent the core state machines and categorizations
|
||||||
|
used throughout the Syndarix AI consulting platform.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum as PyEnum
|
||||||
|
|
||||||
|
|
||||||
|
class AutonomyLevel(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Defines how much control the human has over agent actions.
|
||||||
|
|
||||||
|
FULL_CONTROL: Human must approve every agent action
|
||||||
|
MILESTONE: Human approves at sprint boundaries and major decisions
|
||||||
|
AUTONOMOUS: Agents work independently, only escalating critical issues
|
||||||
|
"""
|
||||||
|
|
||||||
|
FULL_CONTROL = "full_control"
|
||||||
|
MILESTONE = "milestone"
|
||||||
|
AUTONOMOUS = "autonomous"
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectComplexity(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Project complexity level for estimation and planning.
|
||||||
|
|
||||||
|
SCRIPT: Simple automation or script-level work
|
||||||
|
SIMPLE: Straightforward feature or fix
|
||||||
|
MEDIUM: Standard complexity with some architectural considerations
|
||||||
|
COMPLEX: Large-scale feature requiring significant design work
|
||||||
|
"""
|
||||||
|
|
||||||
|
SCRIPT = "script"
|
||||||
|
SIMPLE = "simple"
|
||||||
|
MEDIUM = "medium"
|
||||||
|
COMPLEX = "complex"
|
||||||
|
|
||||||
|
|
||||||
|
class ClientMode(str, PyEnum):
|
||||||
|
"""
|
||||||
|
How the client prefers to interact with agents.
|
||||||
|
|
||||||
|
TECHNICAL: Client is technical and prefers detailed updates
|
||||||
|
AUTO: Agents automatically determine communication level
|
||||||
|
"""
|
||||||
|
|
||||||
|
TECHNICAL = "technical"
|
||||||
|
AUTO = "auto"
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectStatus(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Project lifecycle status.
|
||||||
|
|
||||||
|
ACTIVE: Project is actively being worked on
|
||||||
|
PAUSED: Project is temporarily on hold
|
||||||
|
COMPLETED: Project has been delivered successfully
|
||||||
|
ARCHIVED: Project is no longer accessible for work
|
||||||
|
"""
|
||||||
|
|
||||||
|
ACTIVE = "active"
|
||||||
|
PAUSED = "paused"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
ARCHIVED = "archived"
|
||||||
|
|
||||||
|
|
||||||
|
class AgentStatus(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Current operational status of an agent instance.
|
||||||
|
|
||||||
|
IDLE: Agent is available but not currently working
|
||||||
|
WORKING: Agent is actively processing a task
|
||||||
|
WAITING: Agent is waiting for external input or approval
|
||||||
|
PAUSED: Agent has been manually paused
|
||||||
|
TERMINATED: Agent instance has been shut down
|
||||||
|
"""
|
||||||
|
|
||||||
|
IDLE = "idle"
|
||||||
|
WORKING = "working"
|
||||||
|
WAITING = "waiting"
|
||||||
|
PAUSED = "paused"
|
||||||
|
TERMINATED = "terminated"
|
||||||
|
|
||||||
|
|
||||||
|
class IssueType(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Issue type for categorization and hierarchy.
|
||||||
|
|
||||||
|
EPIC: Large feature or body of work containing stories
|
||||||
|
STORY: User-facing feature or requirement
|
||||||
|
TASK: Technical work item
|
||||||
|
BUG: Defect or issue to be fixed
|
||||||
|
"""
|
||||||
|
|
||||||
|
EPIC = "epic"
|
||||||
|
STORY = "story"
|
||||||
|
TASK = "task"
|
||||||
|
BUG = "bug"
|
||||||
|
|
||||||
|
|
||||||
|
class IssueStatus(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Issue workflow status.
|
||||||
|
|
||||||
|
OPEN: Issue is ready to be worked on
|
||||||
|
IN_PROGRESS: Agent or human is actively working on the issue
|
||||||
|
IN_REVIEW: Work is complete, awaiting review
|
||||||
|
BLOCKED: Issue cannot proceed due to dependencies or blockers
|
||||||
|
CLOSED: Issue has been completed or cancelled
|
||||||
|
"""
|
||||||
|
|
||||||
|
OPEN = "open"
|
||||||
|
IN_PROGRESS = "in_progress"
|
||||||
|
IN_REVIEW = "in_review"
|
||||||
|
BLOCKED = "blocked"
|
||||||
|
CLOSED = "closed"
|
||||||
|
|
||||||
|
|
||||||
|
class IssuePriority(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Issue priority levels.
|
||||||
|
|
||||||
|
LOW: Nice to have, can be deferred
|
||||||
|
MEDIUM: Standard priority, should be done
|
||||||
|
HIGH: Important, should be prioritized
|
||||||
|
CRITICAL: Must be done immediately, blocking other work
|
||||||
|
"""
|
||||||
|
|
||||||
|
LOW = "low"
|
||||||
|
MEDIUM = "medium"
|
||||||
|
HIGH = "high"
|
||||||
|
CRITICAL = "critical"
|
||||||
|
|
||||||
|
|
||||||
|
class SyncStatus(str, PyEnum):
|
||||||
|
"""
|
||||||
|
External issue tracker synchronization status.
|
||||||
|
|
||||||
|
SYNCED: Local and remote are in sync
|
||||||
|
PENDING: Local changes waiting to be pushed
|
||||||
|
CONFLICT: Merge conflict between local and remote
|
||||||
|
ERROR: Synchronization failed due to an error
|
||||||
|
"""
|
||||||
|
|
||||||
|
SYNCED = "synced"
|
||||||
|
PENDING = "pending"
|
||||||
|
CONFLICT = "conflict"
|
||||||
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
class SprintStatus(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Sprint lifecycle status.
|
||||||
|
|
||||||
|
PLANNED: Sprint has been created but not started
|
||||||
|
ACTIVE: Sprint is currently in progress
|
||||||
|
IN_REVIEW: Sprint work is done, demo/review pending
|
||||||
|
COMPLETED: Sprint has been finished successfully
|
||||||
|
CANCELLED: Sprint was cancelled before completion
|
||||||
|
"""
|
||||||
|
|
||||||
|
PLANNED = "planned"
|
||||||
|
ACTIVE = "active"
|
||||||
|
IN_REVIEW = "in_review"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
176
backend/app/models/syndarix/issue.py
Normal file
176
backend/app/models/syndarix/issue.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
# app/models/syndarix/issue.py
|
||||||
|
"""
|
||||||
|
Issue model for Syndarix AI consulting platform.
|
||||||
|
|
||||||
|
An Issue represents a unit of work that can be assigned to agents or humans,
|
||||||
|
with optional synchronization to external issue trackers (Gitea, GitHub, GitLab).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Date,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import (
|
||||||
|
JSONB,
|
||||||
|
UUID as PGUUID,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
from .enums import IssuePriority, IssueStatus, IssueType, SyncStatus
|
||||||
|
|
||||||
|
|
||||||
|
class Issue(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Issue model representing a unit of work in a project.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Standard issue fields (title, body, status, priority)
|
||||||
|
- Assignment to agent instances or human assignees
|
||||||
|
- Sprint association for backlog management
|
||||||
|
- External tracker synchronization (Gitea, GitHub, GitLab)
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "issues"
|
||||||
|
|
||||||
|
# Foreign key to project
|
||||||
|
project_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parent issue for hierarchy (Epic -> Story -> Task)
|
||||||
|
parent_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("issues.id", ondelete="CASCADE"),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Issue type (Epic, Story, Task, Bug)
|
||||||
|
type: Column[IssueType] = Column(
|
||||||
|
Enum(IssueType),
|
||||||
|
default=IssueType.TASK,
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reporter (who created this issue - can be user or agent)
|
||||||
|
reporter_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
nullable=True, # System-generated issues may have no reporter
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Issue content
|
||||||
|
title = Column(String(500), nullable=False)
|
||||||
|
body = Column(Text, nullable=False, default="")
|
||||||
|
|
||||||
|
# Status and priority
|
||||||
|
status: Column[IssueStatus] = Column(
|
||||||
|
Enum(IssueStatus),
|
||||||
|
default=IssueStatus.OPEN,
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
priority: Column[IssuePriority] = Column(
|
||||||
|
Enum(IssuePriority),
|
||||||
|
default=IssuePriority.MEDIUM,
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Labels for categorization (e.g., ["bug", "frontend", "urgent"])
|
||||||
|
labels = Column(JSONB, default=list, nullable=False)
|
||||||
|
|
||||||
|
# Assignment - either to an agent or a human (mutually exclusive)
|
||||||
|
assigned_agent_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("agent_instances.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Human assignee (username or email, not a FK to allow external users)
|
||||||
|
human_assignee = Column(String(255), nullable=True, index=True)
|
||||||
|
|
||||||
|
# Sprint association
|
||||||
|
sprint_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("sprints.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Story points for estimation
|
||||||
|
story_points = Column(Integer, nullable=True)
|
||||||
|
|
||||||
|
# Due date for the issue
|
||||||
|
due_date = Column(Date, nullable=True, index=True)
|
||||||
|
|
||||||
|
# External tracker integration
|
||||||
|
external_tracker_type = Column(
|
||||||
|
String(50),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
) # 'gitea', 'github', 'gitlab'
|
||||||
|
|
||||||
|
external_issue_id = Column(String(255), nullable=True) # External system's ID
|
||||||
|
remote_url = Column(String(1000), nullable=True) # Link to external issue
|
||||||
|
external_issue_number = Column(Integer, nullable=True) # Issue number (e.g., #123)
|
||||||
|
|
||||||
|
# Sync status with external tracker
|
||||||
|
sync_status: Column[SyncStatus] = Column(
|
||||||
|
Enum(SyncStatus),
|
||||||
|
default=SyncStatus.SYNCED,
|
||||||
|
nullable=False,
|
||||||
|
# Note: Index defined in __table_args__ as ix_issues_sync_status
|
||||||
|
)
|
||||||
|
|
||||||
|
last_synced_at = Column(DateTime(timezone=True), nullable=True)
|
||||||
|
external_updated_at = Column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
# Lifecycle timestamp
|
||||||
|
closed_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
project = relationship("Project", back_populates="issues")
|
||||||
|
assigned_agent = relationship(
|
||||||
|
"AgentInstance",
|
||||||
|
back_populates="assigned_issues",
|
||||||
|
foreign_keys=[assigned_agent_id],
|
||||||
|
)
|
||||||
|
sprint = relationship("Sprint", back_populates="issues")
|
||||||
|
parent = relationship("Issue", remote_side="Issue.id", backref="children")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_issues_project_status", "project_id", "status"),
|
||||||
|
Index("ix_issues_project_priority", "project_id", "priority"),
|
||||||
|
Index("ix_issues_project_sprint", "project_id", "sprint_id"),
|
||||||
|
Index(
|
||||||
|
"ix_issues_external_tracker_id",
|
||||||
|
"external_tracker_type",
|
||||||
|
"external_issue_id",
|
||||||
|
),
|
||||||
|
Index("ix_issues_sync_status", "sync_status"),
|
||||||
|
Index("ix_issues_project_agent", "project_id", "assigned_agent_id"),
|
||||||
|
Index("ix_issues_project_type", "project_id", "type"),
|
||||||
|
Index("ix_issues_project_status_priority", "project_id", "status", "priority"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"<Issue {self.id} title='{self.title[:30]}...' "
|
||||||
|
f"status={self.status.value} priority={self.priority.value}>"
|
||||||
|
)
|
||||||
103
backend/app/models/syndarix/project.py
Normal file
103
backend/app/models/syndarix/project.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
# app/models/syndarix/project.py
|
||||||
|
"""
|
||||||
|
Project model for Syndarix AI consulting platform.
|
||||||
|
|
||||||
|
A Project represents a client engagement where AI agents collaborate
|
||||||
|
to deliver software solutions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import Column, Enum, ForeignKey, Index, String, Text
|
||||||
|
from sqlalchemy.dialects.postgresql import (
|
||||||
|
JSONB,
|
||||||
|
UUID as PGUUID,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
from .enums import AutonomyLevel, ClientMode, ProjectComplexity, ProjectStatus
|
||||||
|
|
||||||
|
|
||||||
|
class Project(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Project model representing a client engagement.
|
||||||
|
|
||||||
|
A project contains:
|
||||||
|
- Configuration for how autonomous agents should operate
|
||||||
|
- Settings for MCP server integrations
|
||||||
|
- Relationship to assigned agents, issues, and sprints
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "projects"
|
||||||
|
|
||||||
|
name = Column(String(255), nullable=False, index=True)
|
||||||
|
slug = Column(String(255), unique=True, nullable=False, index=True)
|
||||||
|
description = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
autonomy_level: Column[AutonomyLevel] = Column(
|
||||||
|
Enum(AutonomyLevel),
|
||||||
|
default=AutonomyLevel.MILESTONE,
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
status: Column[ProjectStatus] = Column(
|
||||||
|
Enum(ProjectStatus),
|
||||||
|
default=ProjectStatus.ACTIVE,
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
complexity: Column[ProjectComplexity] = Column(
|
||||||
|
Enum(ProjectComplexity),
|
||||||
|
default=ProjectComplexity.MEDIUM,
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
client_mode: Column[ClientMode] = Column(
|
||||||
|
Enum(ClientMode),
|
||||||
|
default=ClientMode.AUTO,
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# JSON field for flexible project configuration
|
||||||
|
# Can include: mcp_servers, webhook_urls, notification_settings, etc.
|
||||||
|
settings = Column(JSONB, default=dict, nullable=False)
|
||||||
|
|
||||||
|
# Foreign key to the User who owns this project
|
||||||
|
owner_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("users.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
owner = relationship("User", foreign_keys=[owner_id])
|
||||||
|
agent_instances = relationship(
|
||||||
|
"AgentInstance",
|
||||||
|
back_populates="project",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
issues = relationship(
|
||||||
|
"Issue",
|
||||||
|
back_populates="project",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
sprints = relationship(
|
||||||
|
"Sprint",
|
||||||
|
back_populates="project",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_projects_slug_status", "slug", "status"),
|
||||||
|
Index("ix_projects_owner_status", "owner_id", "status"),
|
||||||
|
Index("ix_projects_autonomy_status", "autonomy_level", "status"),
|
||||||
|
Index("ix_projects_complexity_status", "complexity", "status"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<Project {self.name} ({self.slug}) status={self.status.value}>"
|
||||||
86
backend/app/models/syndarix/sprint.py
Normal file
86
backend/app/models/syndarix/sprint.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# app/models/syndarix/sprint.py
|
||||||
|
"""
|
||||||
|
Sprint model for Syndarix AI consulting platform.
|
||||||
|
|
||||||
|
A Sprint represents a time-boxed iteration for organizing and delivering work.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Date,
|
||||||
|
Enum,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
from .enums import SprintStatus
|
||||||
|
|
||||||
|
|
||||||
|
class Sprint(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Sprint model representing a time-boxed iteration.
|
||||||
|
|
||||||
|
Tracks:
|
||||||
|
- Sprint metadata (name, number, goal)
|
||||||
|
- Date range (start/end)
|
||||||
|
- Progress metrics (planned vs completed points)
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "sprints"
|
||||||
|
|
||||||
|
# Foreign key to project
|
||||||
|
project_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sprint identification
|
||||||
|
name = Column(String(255), nullable=False)
|
||||||
|
number = Column(Integer, nullable=False) # Sprint number within project
|
||||||
|
|
||||||
|
# Sprint goal (what we aim to achieve)
|
||||||
|
goal = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
# Date range
|
||||||
|
start_date = Column(Date, nullable=False, index=True)
|
||||||
|
end_date = Column(Date, nullable=False, index=True)
|
||||||
|
|
||||||
|
# Status
|
||||||
|
status: Column[SprintStatus] = Column(
|
||||||
|
Enum(SprintStatus),
|
||||||
|
default=SprintStatus.PLANNED,
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Progress metrics
|
||||||
|
planned_points = Column(Integer, nullable=True) # Sum of story points at start
|
||||||
|
velocity = Column(Integer, nullable=True) # Sum of completed story points
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
project = relationship("Project", back_populates="sprints")
|
||||||
|
issues = relationship("Issue", back_populates="sprint")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_sprints_project_status", "project_id", "status"),
|
||||||
|
Index("ix_sprints_project_number", "project_id", "number"),
|
||||||
|
Index("ix_sprints_date_range", "start_date", "end_date"),
|
||||||
|
# Ensure sprint numbers are unique within a project
|
||||||
|
UniqueConstraint("project_id", "number", name="uq_sprint_project_number"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"<Sprint {self.name} (#{self.number}) "
|
||||||
|
f"project={self.project_id} status={self.status.value}>"
|
||||||
|
)
|
||||||
273
backend/app/schemas/events.py
Normal file
273
backend/app/schemas/events.py
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
"""
|
||||||
|
Event schemas for the Syndarix EventBus (Redis Pub/Sub).
|
||||||
|
|
||||||
|
This module defines event types and payload schemas for real-time communication
|
||||||
|
between services, agents, and the frontend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Literal
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class EventType(str, Enum):
|
||||||
|
"""
|
||||||
|
Event types for the EventBus.
|
||||||
|
|
||||||
|
Naming convention: {domain}.{action}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Agent Events
|
||||||
|
AGENT_SPAWNED = "agent.spawned"
|
||||||
|
AGENT_STATUS_CHANGED = "agent.status_changed"
|
||||||
|
AGENT_MESSAGE = "agent.message"
|
||||||
|
AGENT_TERMINATED = "agent.terminated"
|
||||||
|
|
||||||
|
# Issue Events
|
||||||
|
ISSUE_CREATED = "issue.created"
|
||||||
|
ISSUE_UPDATED = "issue.updated"
|
||||||
|
ISSUE_ASSIGNED = "issue.assigned"
|
||||||
|
ISSUE_CLOSED = "issue.closed"
|
||||||
|
|
||||||
|
# Sprint Events
|
||||||
|
SPRINT_STARTED = "sprint.started"
|
||||||
|
SPRINT_COMPLETED = "sprint.completed"
|
||||||
|
|
||||||
|
# Approval Events
|
||||||
|
APPROVAL_REQUESTED = "approval.requested"
|
||||||
|
APPROVAL_GRANTED = "approval.granted"
|
||||||
|
APPROVAL_DENIED = "approval.denied"
|
||||||
|
|
||||||
|
# Project Events
|
||||||
|
PROJECT_CREATED = "project.created"
|
||||||
|
PROJECT_UPDATED = "project.updated"
|
||||||
|
PROJECT_ARCHIVED = "project.archived"
|
||||||
|
|
||||||
|
# Workflow Events
|
||||||
|
WORKFLOW_STARTED = "workflow.started"
|
||||||
|
WORKFLOW_STEP_COMPLETED = "workflow.step_completed"
|
||||||
|
WORKFLOW_COMPLETED = "workflow.completed"
|
||||||
|
WORKFLOW_FAILED = "workflow.failed"
|
||||||
|
|
||||||
|
|
||||||
|
ActorType = Literal["agent", "user", "system"]
|
||||||
|
|
||||||
|
|
||||||
|
class Event(BaseModel):
|
||||||
|
"""
|
||||||
|
Base event schema for the EventBus.
|
||||||
|
|
||||||
|
All events published to the EventBus must conform to this schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str = Field(
|
||||||
|
...,
|
||||||
|
description="Unique event identifier (UUID string)",
|
||||||
|
examples=["550e8400-e29b-41d4-a716-446655440000"],
|
||||||
|
)
|
||||||
|
type: EventType = Field(
|
||||||
|
...,
|
||||||
|
description="Event type enum value",
|
||||||
|
examples=[EventType.AGENT_MESSAGE],
|
||||||
|
)
|
||||||
|
timestamp: datetime = Field(
|
||||||
|
...,
|
||||||
|
description="When the event occurred (UTC)",
|
||||||
|
examples=["2024-01-15T10:30:00Z"],
|
||||||
|
)
|
||||||
|
project_id: UUID = Field(
|
||||||
|
...,
|
||||||
|
description="Project this event belongs to",
|
||||||
|
examples=["550e8400-e29b-41d4-a716-446655440001"],
|
||||||
|
)
|
||||||
|
actor_id: UUID | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="ID of the agent or user who triggered the event",
|
||||||
|
examples=["550e8400-e29b-41d4-a716-446655440002"],
|
||||||
|
)
|
||||||
|
actor_type: ActorType = Field(
|
||||||
|
...,
|
||||||
|
description="Type of actor: 'agent', 'user', or 'system'",
|
||||||
|
examples=["agent"],
|
||||||
|
)
|
||||||
|
payload: dict = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Event-specific payload data",
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"json_schema_extra": {
|
||||||
|
"example": {
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"type": "agent.message",
|
||||||
|
"timestamp": "2024-01-15T10:30:00Z",
|
||||||
|
"project_id": "550e8400-e29b-41d4-a716-446655440001",
|
||||||
|
"actor_id": "550e8400-e29b-41d4-a716-446655440002",
|
||||||
|
"actor_type": "agent",
|
||||||
|
"payload": {"message": "Processing task...", "progress": 50},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Specific payload schemas for type safety
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSpawnedPayload(BaseModel):
|
||||||
|
"""Payload for AGENT_SPAWNED events."""
|
||||||
|
|
||||||
|
agent_instance_id: UUID = Field(..., description="ID of the spawned agent instance")
|
||||||
|
agent_type_id: UUID = Field(..., description="ID of the agent type")
|
||||||
|
agent_name: str = Field(..., description="Human-readable name of the agent")
|
||||||
|
role: str = Field(..., description="Agent role (e.g., 'product_owner', 'engineer')")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentStatusChangedPayload(BaseModel):
|
||||||
|
"""Payload for AGENT_STATUS_CHANGED events."""
|
||||||
|
|
||||||
|
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||||
|
previous_status: str = Field(..., description="Previous status")
|
||||||
|
new_status: str = Field(..., description="New status")
|
||||||
|
reason: str | None = Field(default=None, description="Reason for status change")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentMessagePayload(BaseModel):
|
||||||
|
"""Payload for AGENT_MESSAGE events."""
|
||||||
|
|
||||||
|
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||||
|
message: str = Field(..., description="Message content")
|
||||||
|
message_type: str = Field(
|
||||||
|
default="info",
|
||||||
|
description="Message type: 'info', 'warning', 'error', 'debug'",
|
||||||
|
)
|
||||||
|
metadata: dict = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Additional metadata (e.g., token usage, model info)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTerminatedPayload(BaseModel):
|
||||||
|
"""Payload for AGENT_TERMINATED events."""
|
||||||
|
|
||||||
|
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||||
|
termination_reason: str = Field(..., description="Reason for termination")
|
||||||
|
final_status: str = Field(..., description="Final status at termination")
|
||||||
|
|
||||||
|
|
||||||
|
class IssueCreatedPayload(BaseModel):
|
||||||
|
"""Payload for ISSUE_CREATED events."""
|
||||||
|
|
||||||
|
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||||
|
title: str = Field(..., description="Issue title")
|
||||||
|
priority: str | None = Field(default=None, description="Issue priority")
|
||||||
|
labels: list[str] = Field(default_factory=list, description="Issue labels")
|
||||||
|
|
||||||
|
|
||||||
|
class IssueUpdatedPayload(BaseModel):
|
||||||
|
"""Payload for ISSUE_UPDATED events."""
|
||||||
|
|
||||||
|
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||||
|
changes: dict = Field(..., description="Dictionary of field changes")
|
||||||
|
|
||||||
|
|
||||||
|
class IssueAssignedPayload(BaseModel):
|
||||||
|
"""Payload for ISSUE_ASSIGNED events."""
|
||||||
|
|
||||||
|
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||||
|
assignee_id: UUID | None = Field(
|
||||||
|
default=None, description="Agent or user assigned to"
|
||||||
|
)
|
||||||
|
assignee_name: str | None = Field(default=None, description="Assignee name")
|
||||||
|
|
||||||
|
|
||||||
|
class IssueClosedPayload(BaseModel):
|
||||||
|
"""Payload for ISSUE_CLOSED events."""
|
||||||
|
|
||||||
|
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||||
|
resolution: str = Field(..., description="Resolution status")
|
||||||
|
|
||||||
|
|
||||||
|
class SprintStartedPayload(BaseModel):
|
||||||
|
"""Payload for SPRINT_STARTED events."""
|
||||||
|
|
||||||
|
sprint_id: UUID = Field(..., description="Sprint ID")
|
||||||
|
sprint_name: str = Field(..., description="Sprint name")
|
||||||
|
goal: str | None = Field(default=None, description="Sprint goal")
|
||||||
|
issue_count: int = Field(default=0, description="Number of issues in sprint")
|
||||||
|
|
||||||
|
|
||||||
|
class SprintCompletedPayload(BaseModel):
|
||||||
|
"""Payload for SPRINT_COMPLETED events."""
|
||||||
|
|
||||||
|
sprint_id: UUID = Field(..., description="Sprint ID")
|
||||||
|
sprint_name: str = Field(..., description="Sprint name")
|
||||||
|
completed_issues: int = Field(default=0, description="Number of completed issues")
|
||||||
|
incomplete_issues: int = Field(default=0, description="Number of incomplete issues")
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalRequestedPayload(BaseModel):
|
||||||
|
"""Payload for APPROVAL_REQUESTED events."""
|
||||||
|
|
||||||
|
approval_id: UUID = Field(..., description="Approval request ID")
|
||||||
|
approval_type: str = Field(..., description="Type of approval needed")
|
||||||
|
description: str = Field(..., description="Description of what needs approval")
|
||||||
|
requested_by: UUID | None = Field(
|
||||||
|
default=None, description="Agent/user requesting approval"
|
||||||
|
)
|
||||||
|
timeout_minutes: int | None = Field(
|
||||||
|
default=None, description="Minutes before auto-escalation"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalGrantedPayload(BaseModel):
|
||||||
|
"""Payload for APPROVAL_GRANTED events."""
|
||||||
|
|
||||||
|
approval_id: UUID = Field(..., description="Approval request ID")
|
||||||
|
approved_by: UUID = Field(..., description="User who granted approval")
|
||||||
|
comments: str | None = Field(default=None, description="Approval comments")
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalDeniedPayload(BaseModel):
|
||||||
|
"""Payload for APPROVAL_DENIED events."""
|
||||||
|
|
||||||
|
approval_id: UUID = Field(..., description="Approval request ID")
|
||||||
|
denied_by: UUID = Field(..., description="User who denied approval")
|
||||||
|
reason: str = Field(..., description="Reason for denial")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowStartedPayload(BaseModel):
|
||||||
|
"""Payload for WORKFLOW_STARTED events."""
|
||||||
|
|
||||||
|
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||||
|
workflow_type: str = Field(..., description="Type of workflow")
|
||||||
|
total_steps: int = Field(default=0, description="Total number of steps")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowStepCompletedPayload(BaseModel):
|
||||||
|
"""Payload for WORKFLOW_STEP_COMPLETED events."""
|
||||||
|
|
||||||
|
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||||
|
step_name: str = Field(..., description="Name of completed step")
|
||||||
|
step_number: int = Field(..., description="Step number (1-indexed)")
|
||||||
|
total_steps: int = Field(..., description="Total number of steps")
|
||||||
|
result: dict = Field(default_factory=dict, description="Step result data")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowCompletedPayload(BaseModel):
|
||||||
|
"""Payload for WORKFLOW_COMPLETED events."""
|
||||||
|
|
||||||
|
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||||
|
duration_seconds: float = Field(..., description="Total execution duration")
|
||||||
|
result: dict = Field(default_factory=dict, description="Workflow result data")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowFailedPayload(BaseModel):
|
||||||
|
"""Payload for WORKFLOW_FAILED events."""
|
||||||
|
|
||||||
|
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||||
|
error_message: str = Field(..., description="Error message")
|
||||||
|
failed_step: str | None = Field(default=None, description="Step that failed")
|
||||||
|
recoverable: bool = Field(default=False, description="Whether error is recoverable")
|
||||||
113
backend/app/schemas/syndarix/__init__.py
Normal file
113
backend/app/schemas/syndarix/__init__.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
# app/schemas/syndarix/__init__.py
|
||||||
|
"""
|
||||||
|
Syndarix domain schemas.
|
||||||
|
|
||||||
|
This package contains Pydantic schemas for validating and serializing
|
||||||
|
Syndarix domain entities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .agent_instance import (
|
||||||
|
AgentInstanceCreate,
|
||||||
|
AgentInstanceInDB,
|
||||||
|
AgentInstanceListResponse,
|
||||||
|
AgentInstanceMetrics,
|
||||||
|
AgentInstanceResponse,
|
||||||
|
AgentInstanceTerminate,
|
||||||
|
AgentInstanceUpdate,
|
||||||
|
)
|
||||||
|
from .agent_type import (
|
||||||
|
AgentTypeCreate,
|
||||||
|
AgentTypeInDB,
|
||||||
|
AgentTypeListResponse,
|
||||||
|
AgentTypeResponse,
|
||||||
|
AgentTypeUpdate,
|
||||||
|
)
|
||||||
|
from .enums import (
|
||||||
|
AgentStatus,
|
||||||
|
AutonomyLevel,
|
||||||
|
IssuePriority,
|
||||||
|
IssueStatus,
|
||||||
|
ProjectStatus,
|
||||||
|
SprintStatus,
|
||||||
|
SyncStatus,
|
||||||
|
)
|
||||||
|
from .issue import (
|
||||||
|
IssueAssign,
|
||||||
|
IssueClose,
|
||||||
|
IssueCreate,
|
||||||
|
IssueInDB,
|
||||||
|
IssueListResponse,
|
||||||
|
IssueResponse,
|
||||||
|
IssueStats,
|
||||||
|
IssueSyncUpdate,
|
||||||
|
IssueUpdate,
|
||||||
|
)
|
||||||
|
from .project import (
|
||||||
|
ProjectCreate,
|
||||||
|
ProjectInDB,
|
||||||
|
ProjectListResponse,
|
||||||
|
ProjectResponse,
|
||||||
|
ProjectUpdate,
|
||||||
|
)
|
||||||
|
from .sprint import (
|
||||||
|
SprintBurndown,
|
||||||
|
SprintComplete,
|
||||||
|
SprintCreate,
|
||||||
|
SprintInDB,
|
||||||
|
SprintListResponse,
|
||||||
|
SprintResponse,
|
||||||
|
SprintStart,
|
||||||
|
SprintUpdate,
|
||||||
|
SprintVelocity,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# AgentInstance schemas
|
||||||
|
"AgentInstanceCreate",
|
||||||
|
"AgentInstanceInDB",
|
||||||
|
"AgentInstanceListResponse",
|
||||||
|
"AgentInstanceMetrics",
|
||||||
|
"AgentInstanceResponse",
|
||||||
|
"AgentInstanceTerminate",
|
||||||
|
"AgentInstanceUpdate",
|
||||||
|
# Enums
|
||||||
|
"AgentStatus",
|
||||||
|
# AgentType schemas
|
||||||
|
"AgentTypeCreate",
|
||||||
|
"AgentTypeInDB",
|
||||||
|
"AgentTypeListResponse",
|
||||||
|
"AgentTypeResponse",
|
||||||
|
"AgentTypeUpdate",
|
||||||
|
"AutonomyLevel",
|
||||||
|
# Issue schemas
|
||||||
|
"IssueAssign",
|
||||||
|
"IssueClose",
|
||||||
|
"IssueCreate",
|
||||||
|
"IssueInDB",
|
||||||
|
"IssueListResponse",
|
||||||
|
"IssuePriority",
|
||||||
|
"IssueResponse",
|
||||||
|
"IssueStats",
|
||||||
|
"IssueStatus",
|
||||||
|
"IssueSyncUpdate",
|
||||||
|
"IssueUpdate",
|
||||||
|
# Project schemas
|
||||||
|
"ProjectCreate",
|
||||||
|
"ProjectInDB",
|
||||||
|
"ProjectListResponse",
|
||||||
|
"ProjectResponse",
|
||||||
|
"ProjectStatus",
|
||||||
|
"ProjectUpdate",
|
||||||
|
# Sprint schemas
|
||||||
|
"SprintBurndown",
|
||||||
|
"SprintComplete",
|
||||||
|
"SprintCreate",
|
||||||
|
"SprintInDB",
|
||||||
|
"SprintListResponse",
|
||||||
|
"SprintResponse",
|
||||||
|
"SprintStart",
|
||||||
|
"SprintStatus",
|
||||||
|
"SprintUpdate",
|
||||||
|
"SprintVelocity",
|
||||||
|
"SyncStatus",
|
||||||
|
]
|
||||||
124
backend/app/schemas/syndarix/agent_instance.py
Normal file
124
backend/app/schemas/syndarix/agent_instance.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
# app/schemas/syndarix/agent_instance.py
|
||||||
|
"""
|
||||||
|
Pydantic schemas for AgentInstance entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
from .enums import AgentStatus
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInstanceBase(BaseModel):
|
||||||
|
"""Base agent instance schema with common fields."""
|
||||||
|
|
||||||
|
agent_type_id: UUID
|
||||||
|
project_id: UUID
|
||||||
|
status: AgentStatus = AgentStatus.IDLE
|
||||||
|
current_task: str | None = None
|
||||||
|
short_term_memory: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
long_term_memory_ref: str | None = Field(None, max_length=500)
|
||||||
|
session_id: str | None = Field(None, max_length=255)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInstanceCreate(BaseModel):
|
||||||
|
"""Schema for creating a new agent instance."""
|
||||||
|
|
||||||
|
agent_type_id: UUID
|
||||||
|
project_id: UUID
|
||||||
|
name: str = Field(..., min_length=1, max_length=100)
|
||||||
|
status: AgentStatus = AgentStatus.IDLE
|
||||||
|
current_task: str | None = None
|
||||||
|
short_term_memory: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
long_term_memory_ref: str | None = Field(None, max_length=500)
|
||||||
|
session_id: str | None = Field(None, max_length=255)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInstanceUpdate(BaseModel):
|
||||||
|
"""Schema for updating an agent instance."""
|
||||||
|
|
||||||
|
status: AgentStatus | None = None
|
||||||
|
current_task: str | None = None
|
||||||
|
short_term_memory: dict[str, Any] | None = None
|
||||||
|
long_term_memory_ref: str | None = None
|
||||||
|
session_id: str | None = None
|
||||||
|
last_activity_at: datetime | None = None
|
||||||
|
tasks_completed: int | None = Field(None, ge=0)
|
||||||
|
tokens_used: int | None = Field(None, ge=0)
|
||||||
|
cost_incurred: Decimal | None = Field(None, ge=0)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInstanceTerminate(BaseModel):
|
||||||
|
"""Schema for terminating an agent instance."""
|
||||||
|
|
||||||
|
reason: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInstanceInDB(AgentInstanceBase):
|
||||||
|
"""Schema for agent instance in database."""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
last_activity_at: datetime | None = None
|
||||||
|
terminated_at: datetime | None = None
|
||||||
|
tasks_completed: int = 0
|
||||||
|
tokens_used: int = 0
|
||||||
|
cost_incurred: Decimal = Decimal("0.0000")
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInstanceResponse(BaseModel):
|
||||||
|
"""Schema for agent instance API responses."""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
agent_type_id: UUID
|
||||||
|
project_id: UUID
|
||||||
|
name: str
|
||||||
|
status: AgentStatus
|
||||||
|
current_task: str | None = None
|
||||||
|
short_term_memory: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
long_term_memory_ref: str | None = None
|
||||||
|
session_id: str | None = None
|
||||||
|
last_activity_at: datetime | None = None
|
||||||
|
terminated_at: datetime | None = None
|
||||||
|
tasks_completed: int = 0
|
||||||
|
tokens_used: int = 0
|
||||||
|
cost_incurred: Decimal = Decimal("0.0000")
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
# Expanded fields from relationships
|
||||||
|
agent_type_name: str | None = None
|
||||||
|
agent_type_slug: str | None = None
|
||||||
|
project_name: str | None = None
|
||||||
|
project_slug: str | None = None
|
||||||
|
assigned_issues_count: int | None = 0
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInstanceListResponse(BaseModel):
|
||||||
|
"""Schema for paginated agent instance list responses."""
|
||||||
|
|
||||||
|
agent_instances: list[AgentInstanceResponse]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
page_size: int
|
||||||
|
pages: int
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInstanceMetrics(BaseModel):
|
||||||
|
"""Schema for agent instance metrics summary."""
|
||||||
|
|
||||||
|
total_instances: int
|
||||||
|
active_instances: int
|
||||||
|
idle_instances: int
|
||||||
|
total_tasks_completed: int
|
||||||
|
total_tokens_used: int
|
||||||
|
total_cost_incurred: Decimal
|
||||||
151
backend/app/schemas/syndarix/agent_type.py
Normal file
151
backend/app/schemas/syndarix/agent_type.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
# app/schemas/syndarix/agent_type.py
|
||||||
|
"""
|
||||||
|
Pydantic schemas for AgentType entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTypeBase(BaseModel):
|
||||||
|
"""Base agent type schema with common fields."""
|
||||||
|
|
||||||
|
name: str = Field(..., min_length=1, max_length=255)
|
||||||
|
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||||
|
description: str | None = None
|
||||||
|
expertise: list[str] = Field(default_factory=list)
|
||||||
|
personality_prompt: str = Field(..., min_length=1)
|
||||||
|
primary_model: str = Field(..., min_length=1, max_length=100)
|
||||||
|
fallback_models: list[str] = Field(default_factory=list)
|
||||||
|
model_params: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
mcp_servers: list[str] = Field(default_factory=list)
|
||||||
|
tool_permissions: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
is_active: bool = True
|
||||||
|
|
||||||
|
@field_validator("slug")
|
||||||
|
@classmethod
|
||||||
|
def validate_slug(cls, v: str | None) -> str | None:
|
||||||
|
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
if not re.match(r"^[a-z0-9-]+$", v):
|
||||||
|
raise ValueError(
|
||||||
|
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||||
|
)
|
||||||
|
if v.startswith("-") or v.endswith("-"):
|
||||||
|
raise ValueError("Slug cannot start or end with a hyphen")
|
||||||
|
if "--" in v:
|
||||||
|
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("name")
|
||||||
|
@classmethod
|
||||||
|
def validate_name(cls, v: str) -> str:
|
||||||
|
"""Validate agent type name."""
|
||||||
|
if not v or v.strip() == "":
|
||||||
|
raise ValueError("Agent type name cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
@field_validator("expertise")
|
||||||
|
@classmethod
|
||||||
|
def validate_expertise(cls, v: list[str]) -> list[str]:
|
||||||
|
"""Validate and normalize expertise list."""
|
||||||
|
return [e.strip().lower() for e in v if e.strip()]
|
||||||
|
|
||||||
|
@field_validator("mcp_servers")
|
||||||
|
@classmethod
|
||||||
|
def validate_mcp_servers(cls, v: list[str]) -> list[str]:
|
||||||
|
"""Validate MCP server list."""
|
||||||
|
return [s.strip() for s in v if s.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTypeCreate(AgentTypeBase):
|
||||||
|
"""Schema for creating a new agent type."""
|
||||||
|
|
||||||
|
name: str = Field(..., min_length=1, max_length=255)
|
||||||
|
slug: str = Field(..., min_length=1, max_length=255)
|
||||||
|
personality_prompt: str = Field(..., min_length=1)
|
||||||
|
primary_model: str = Field(..., min_length=1, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTypeUpdate(BaseModel):
|
||||||
|
"""Schema for updating an agent type."""
|
||||||
|
|
||||||
|
name: str | None = Field(None, min_length=1, max_length=255)
|
||||||
|
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||||
|
description: str | None = None
|
||||||
|
expertise: list[str] | None = None
|
||||||
|
personality_prompt: str | None = None
|
||||||
|
primary_model: str | None = Field(None, min_length=1, max_length=100)
|
||||||
|
fallback_models: list[str] | None = None
|
||||||
|
model_params: dict[str, Any] | None = None
|
||||||
|
mcp_servers: list[str] | None = None
|
||||||
|
tool_permissions: dict[str, Any] | None = None
|
||||||
|
is_active: bool | None = None
|
||||||
|
|
||||||
|
@field_validator("slug")
|
||||||
|
@classmethod
|
||||||
|
def validate_slug(cls, v: str | None) -> str | None:
|
||||||
|
"""Validate slug format."""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
if not re.match(r"^[a-z0-9-]+$", v):
|
||||||
|
raise ValueError(
|
||||||
|
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||||
|
)
|
||||||
|
if v.startswith("-") or v.endswith("-"):
|
||||||
|
raise ValueError("Slug cannot start or end with a hyphen")
|
||||||
|
if "--" in v:
|
||||||
|
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("name")
|
||||||
|
@classmethod
|
||||||
|
def validate_name(cls, v: str | None) -> str | None:
|
||||||
|
"""Validate agent type name."""
|
||||||
|
if v is not None and (not v or v.strip() == ""):
|
||||||
|
raise ValueError("Agent type name cannot be empty")
|
||||||
|
return v.strip() if v else v
|
||||||
|
|
||||||
|
@field_validator("expertise")
|
||||||
|
@classmethod
|
||||||
|
def validate_expertise(cls, v: list[str] | None) -> list[str] | None:
|
||||||
|
"""Validate and normalize expertise list."""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
return [e.strip().lower() for e in v if e.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTypeInDB(AgentTypeBase):
|
||||||
|
"""Schema for agent type in database."""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTypeResponse(AgentTypeBase):
|
||||||
|
"""Schema for agent type API responses."""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
instance_count: int | None = 0
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTypeListResponse(BaseModel):
|
||||||
|
"""Schema for paginated agent type list responses."""
|
||||||
|
|
||||||
|
agent_types: list[AgentTypeResponse]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
page_size: int
|
||||||
|
pages: int
|
||||||
26
backend/app/schemas/syndarix/enums.py
Normal file
26
backend/app/schemas/syndarix/enums.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# app/schemas/syndarix/enums.py
|
||||||
|
"""
|
||||||
|
Re-export enums from models for use in schemas.
|
||||||
|
|
||||||
|
This allows schemas to import enums without depending on SQLAlchemy models directly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.models.syndarix.enums import (
|
||||||
|
AgentStatus,
|
||||||
|
AutonomyLevel,
|
||||||
|
IssuePriority,
|
||||||
|
IssueStatus,
|
||||||
|
ProjectStatus,
|
||||||
|
SprintStatus,
|
||||||
|
SyncStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentStatus",
|
||||||
|
"AutonomyLevel",
|
||||||
|
"IssuePriority",
|
||||||
|
"IssueStatus",
|
||||||
|
"ProjectStatus",
|
||||||
|
"SprintStatus",
|
||||||
|
"SyncStatus",
|
||||||
|
]
|
||||||
191
backend/app/schemas/syndarix/issue.py
Normal file
191
backend/app/schemas/syndarix/issue.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
# app/schemas/syndarix/issue.py
|
||||||
|
"""
|
||||||
|
Pydantic schemas for Issue entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Literal
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
from .enums import IssuePriority, IssueStatus, SyncStatus
|
||||||
|
|
||||||
|
|
||||||
|
class IssueBase(BaseModel):
|
||||||
|
"""Base issue schema with common fields."""
|
||||||
|
|
||||||
|
title: str = Field(..., min_length=1, max_length=500)
|
||||||
|
body: str = ""
|
||||||
|
status: IssueStatus = IssueStatus.OPEN
|
||||||
|
priority: IssuePriority = IssuePriority.MEDIUM
|
||||||
|
labels: list[str] = Field(default_factory=list)
|
||||||
|
story_points: int | None = Field(None, ge=0, le=100)
|
||||||
|
|
||||||
|
@field_validator("title")
|
||||||
|
@classmethod
|
||||||
|
def validate_title(cls, v: str) -> str:
|
||||||
|
"""Validate issue title."""
|
||||||
|
if not v or v.strip() == "":
|
||||||
|
raise ValueError("Issue title cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
@field_validator("labels")
|
||||||
|
@classmethod
|
||||||
|
def validate_labels(cls, v: list[str]) -> list[str]:
|
||||||
|
"""Validate and normalize labels."""
|
||||||
|
return [label.strip().lower() for label in v if label.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
class IssueCreate(IssueBase):
|
||||||
|
"""Schema for creating a new issue."""
|
||||||
|
|
||||||
|
project_id: UUID
|
||||||
|
assigned_agent_id: UUID | None = None
|
||||||
|
human_assignee: str | None = Field(None, max_length=255)
|
||||||
|
sprint_id: UUID | None = None
|
||||||
|
|
||||||
|
# External tracker fields (optional, for importing from external systems)
|
||||||
|
external_tracker_type: Literal["gitea", "github", "gitlab"] | None = None
|
||||||
|
external_issue_id: str | None = Field(None, max_length=255)
|
||||||
|
remote_url: str | None = Field(None, max_length=1000)
|
||||||
|
external_issue_number: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IssueUpdate(BaseModel):
|
||||||
|
"""Schema for updating an issue."""
|
||||||
|
|
||||||
|
title: str | None = Field(None, min_length=1, max_length=500)
|
||||||
|
body: str | None = None
|
||||||
|
status: IssueStatus | None = None
|
||||||
|
priority: IssuePriority | None = None
|
||||||
|
labels: list[str] | None = None
|
||||||
|
assigned_agent_id: UUID | None = None
|
||||||
|
human_assignee: str | None = Field(None, max_length=255)
|
||||||
|
sprint_id: UUID | None = None
|
||||||
|
story_points: int | None = Field(None, ge=0, le=100)
|
||||||
|
sync_status: SyncStatus | None = None
|
||||||
|
|
||||||
|
@field_validator("title")
|
||||||
|
@classmethod
|
||||||
|
def validate_title(cls, v: str | None) -> str | None:
|
||||||
|
"""Validate issue title."""
|
||||||
|
if v is not None and (not v or v.strip() == ""):
|
||||||
|
raise ValueError("Issue title cannot be empty")
|
||||||
|
return v.strip() if v else v
|
||||||
|
|
||||||
|
@field_validator("labels")
|
||||||
|
@classmethod
|
||||||
|
def validate_labels(cls, v: list[str] | None) -> list[str] | None:
|
||||||
|
"""Validate and normalize labels."""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
return [label.strip().lower() for label in v if label.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
class IssueClose(BaseModel):
|
||||||
|
"""Schema for closing an issue."""
|
||||||
|
|
||||||
|
resolution: str | None = None # Optional resolution note
|
||||||
|
|
||||||
|
|
||||||
|
class IssueAssign(BaseModel):
|
||||||
|
"""Schema for assigning an issue."""
|
||||||
|
|
||||||
|
assigned_agent_id: UUID | None = None
|
||||||
|
human_assignee: str | None = Field(None, max_length=255)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_assignment(self) -> "IssueAssign":
|
||||||
|
"""Ensure only one type of assignee is set."""
|
||||||
|
if self.assigned_agent_id and self.human_assignee:
|
||||||
|
raise ValueError("Cannot assign to both an agent and a human. Choose one.")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class IssueSyncUpdate(BaseModel):
|
||||||
|
"""Schema for updating sync-related fields."""
|
||||||
|
|
||||||
|
sync_status: SyncStatus
|
||||||
|
last_synced_at: datetime | None = None
|
||||||
|
external_updated_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class IssueInDB(IssueBase):
|
||||||
|
"""Schema for issue in database."""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
project_id: UUID
|
||||||
|
assigned_agent_id: UUID | None = None
|
||||||
|
human_assignee: str | None = None
|
||||||
|
sprint_id: UUID | None = None
|
||||||
|
external_tracker_type: str | None = None
|
||||||
|
external_issue_id: str | None = None
|
||||||
|
remote_url: str | None = None
|
||||||
|
external_issue_number: int | None = None
|
||||||
|
sync_status: SyncStatus = SyncStatus.SYNCED
|
||||||
|
last_synced_at: datetime | None = None
|
||||||
|
external_updated_at: datetime | None = None
|
||||||
|
closed_at: datetime | None = None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class IssueResponse(BaseModel):
|
||||||
|
"""Schema for issue API responses."""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
project_id: UUID
|
||||||
|
title: str
|
||||||
|
body: str
|
||||||
|
status: IssueStatus
|
||||||
|
priority: IssuePriority
|
||||||
|
labels: list[str] = Field(default_factory=list)
|
||||||
|
assigned_agent_id: UUID | None = None
|
||||||
|
human_assignee: str | None = None
|
||||||
|
sprint_id: UUID | None = None
|
||||||
|
story_points: int | None = None
|
||||||
|
external_tracker_type: str | None = None
|
||||||
|
external_issue_id: str | None = None
|
||||||
|
remote_url: str | None = None
|
||||||
|
external_issue_number: int | None = None
|
||||||
|
sync_status: SyncStatus = SyncStatus.SYNCED
|
||||||
|
last_synced_at: datetime | None = None
|
||||||
|
external_updated_at: datetime | None = None
|
||||||
|
closed_at: datetime | None = None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
# Expanded fields from relationships
|
||||||
|
project_name: str | None = None
|
||||||
|
project_slug: str | None = None
|
||||||
|
sprint_name: str | None = None
|
||||||
|
assigned_agent_type_name: str | None = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class IssueListResponse(BaseModel):
|
||||||
|
"""Schema for paginated issue list responses."""
|
||||||
|
|
||||||
|
issues: list[IssueResponse]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
page_size: int
|
||||||
|
pages: int
|
||||||
|
|
||||||
|
|
||||||
|
class IssueStats(BaseModel):
|
||||||
|
"""Schema for issue statistics."""
|
||||||
|
|
||||||
|
total: int
|
||||||
|
open: int
|
||||||
|
in_progress: int
|
||||||
|
in_review: int
|
||||||
|
blocked: int
|
||||||
|
closed: int
|
||||||
|
by_priority: dict[str, int]
|
||||||
|
total_story_points: int | None = None
|
||||||
|
completed_story_points: int | None = None
|
||||||
131
backend/app/schemas/syndarix/project.py
Normal file
131
backend/app/schemas/syndarix/project.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
# app/schemas/syndarix/project.py
|
||||||
|
"""
|
||||||
|
Pydantic schemas for Project entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
|
from .enums import AutonomyLevel, ProjectStatus
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectBase(BaseModel):
|
||||||
|
"""Base project schema with common fields."""
|
||||||
|
|
||||||
|
name: str = Field(..., min_length=1, max_length=255)
|
||||||
|
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||||
|
description: str | None = None
|
||||||
|
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE
|
||||||
|
status: ProjectStatus = ProjectStatus.ACTIVE
|
||||||
|
settings: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
@field_validator("slug")
|
||||||
|
@classmethod
|
||||||
|
def validate_slug(cls, v: str | None) -> str | None:
|
||||||
|
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
if not re.match(r"^[a-z0-9-]+$", v):
|
||||||
|
raise ValueError(
|
||||||
|
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||||
|
)
|
||||||
|
if v.startswith("-") or v.endswith("-"):
|
||||||
|
raise ValueError("Slug cannot start or end with a hyphen")
|
||||||
|
if "--" in v:
|
||||||
|
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("name")
|
||||||
|
@classmethod
|
||||||
|
def validate_name(cls, v: str) -> str:
|
||||||
|
"""Validate project name."""
|
||||||
|
if not v or v.strip() == "":
|
||||||
|
raise ValueError("Project name cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectCreate(ProjectBase):
|
||||||
|
"""Schema for creating a new project."""
|
||||||
|
|
||||||
|
name: str = Field(..., min_length=1, max_length=255)
|
||||||
|
slug: str = Field(..., min_length=1, max_length=255)
|
||||||
|
owner_id: UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectUpdate(BaseModel):
|
||||||
|
"""Schema for updating a project.
|
||||||
|
|
||||||
|
Note: owner_id is intentionally excluded to prevent IDOR vulnerabilities.
|
||||||
|
Project ownership transfer should be done via a dedicated endpoint with
|
||||||
|
proper authorization checks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str | None = Field(None, min_length=1, max_length=255)
|
||||||
|
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||||
|
description: str | None = None
|
||||||
|
autonomy_level: AutonomyLevel | None = None
|
||||||
|
status: ProjectStatus | None = None
|
||||||
|
settings: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
@field_validator("slug")
|
||||||
|
@classmethod
|
||||||
|
def validate_slug(cls, v: str | None) -> str | None:
|
||||||
|
"""Validate slug format."""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
if not re.match(r"^[a-z0-9-]+$", v):
|
||||||
|
raise ValueError(
|
||||||
|
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||||
|
)
|
||||||
|
if v.startswith("-") or v.endswith("-"):
|
||||||
|
raise ValueError("Slug cannot start or end with a hyphen")
|
||||||
|
if "--" in v:
|
||||||
|
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("name")
|
||||||
|
@classmethod
|
||||||
|
def validate_name(cls, v: str | None) -> str | None:
|
||||||
|
"""Validate project name."""
|
||||||
|
if v is not None and (not v or v.strip() == ""):
|
||||||
|
raise ValueError("Project name cannot be empty")
|
||||||
|
return v.strip() if v else v
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectInDB(ProjectBase):
|
||||||
|
"""Schema for project in database."""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
owner_id: UUID | None = None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectResponse(ProjectBase):
|
||||||
|
"""Schema for project API responses."""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
owner_id: UUID | None = None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
agent_count: int | None = 0
|
||||||
|
issue_count: int | None = 0
|
||||||
|
active_sprint_name: str | None = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectListResponse(BaseModel):
|
||||||
|
"""Schema for paginated project list responses."""
|
||||||
|
|
||||||
|
projects: list[ProjectResponse]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
page_size: int
|
||||||
|
pages: int
|
||||||
135
backend/app/schemas/syndarix/sprint.py
Normal file
135
backend/app/schemas/syndarix/sprint.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
# app/schemas/syndarix/sprint.py
|
||||||
|
"""
|
||||||
|
Pydantic schemas for Sprint entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import date, datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
from .enums import SprintStatus
|
||||||
|
|
||||||
|
|
||||||
|
class SprintBase(BaseModel):
|
||||||
|
"""Base sprint schema with common fields."""
|
||||||
|
|
||||||
|
name: str = Field(..., min_length=1, max_length=255)
|
||||||
|
number: int = Field(..., ge=1)
|
||||||
|
goal: str | None = None
|
||||||
|
start_date: date
|
||||||
|
end_date: date
|
||||||
|
status: SprintStatus = SprintStatus.PLANNED
|
||||||
|
planned_points: int | None = Field(None, ge=0)
|
||||||
|
velocity: int | None = Field(None, ge=0)
|
||||||
|
|
||||||
|
@field_validator("name")
|
||||||
|
@classmethod
|
||||||
|
def validate_name(cls, v: str) -> str:
|
||||||
|
"""Validate sprint name."""
|
||||||
|
if not v or v.strip() == "":
|
||||||
|
raise ValueError("Sprint name cannot be empty")
|
||||||
|
return v.strip()
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_dates(self) -> "SprintBase":
|
||||||
|
"""Validate that end_date is after start_date."""
|
||||||
|
if self.end_date < self.start_date:
|
||||||
|
raise ValueError("End date must be after or equal to start date")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class SprintCreate(SprintBase):
|
||||||
|
"""Schema for creating a new sprint."""
|
||||||
|
|
||||||
|
project_id: UUID
|
||||||
|
|
||||||
|
|
||||||
|
class SprintUpdate(BaseModel):
|
||||||
|
"""Schema for updating a sprint."""
|
||||||
|
|
||||||
|
name: str | None = Field(None, min_length=1, max_length=255)
|
||||||
|
goal: str | None = None
|
||||||
|
start_date: date | None = None
|
||||||
|
end_date: date | None = None
|
||||||
|
status: SprintStatus | None = None
|
||||||
|
planned_points: int | None = Field(None, ge=0)
|
||||||
|
velocity: int | None = Field(None, ge=0)
|
||||||
|
|
||||||
|
@field_validator("name")
|
||||||
|
@classmethod
|
||||||
|
def validate_name(cls, v: str | None) -> str | None:
|
||||||
|
"""Validate sprint name."""
|
||||||
|
if v is not None and (not v or v.strip() == ""):
|
||||||
|
raise ValueError("Sprint name cannot be empty")
|
||||||
|
return v.strip() if v else v
|
||||||
|
|
||||||
|
|
||||||
|
class SprintStart(BaseModel):
|
||||||
|
"""Schema for starting a sprint."""
|
||||||
|
|
||||||
|
start_date: date | None = None # Optionally override start date
|
||||||
|
|
||||||
|
|
||||||
|
class SprintComplete(BaseModel):
|
||||||
|
"""Schema for completing a sprint."""
|
||||||
|
|
||||||
|
velocity: int | None = Field(None, ge=0)
|
||||||
|
notes: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SprintInDB(SprintBase):
|
||||||
|
"""Schema for sprint in database."""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
project_id: UUID
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class SprintResponse(SprintBase):
|
||||||
|
"""Schema for sprint API responses."""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
project_id: UUID
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
# Expanded fields from relationships
|
||||||
|
project_name: str | None = None
|
||||||
|
project_slug: str | None = None
|
||||||
|
issue_count: int | None = 0
|
||||||
|
open_issues: int | None = 0
|
||||||
|
completed_issues: int | None = 0
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class SprintListResponse(BaseModel):
|
||||||
|
"""Schema for paginated sprint list responses."""
|
||||||
|
|
||||||
|
sprints: list[SprintResponse]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
page_size: int
|
||||||
|
pages: int
|
||||||
|
|
||||||
|
|
||||||
|
class SprintVelocity(BaseModel):
|
||||||
|
"""Schema for sprint velocity metrics."""
|
||||||
|
|
||||||
|
sprint_number: int
|
||||||
|
sprint_name: str
|
||||||
|
planned_points: int | None
|
||||||
|
velocity: int | None # Sum of completed story points
|
||||||
|
velocity_ratio: float | None # velocity/planned ratio
|
||||||
|
|
||||||
|
|
||||||
|
class SprintBurndown(BaseModel):
|
||||||
|
"""Schema for sprint burndown data point."""
|
||||||
|
|
||||||
|
date: date
|
||||||
|
remaining_points: int
|
||||||
|
ideal_remaining: float
|
||||||
611
backend/app/services/event_bus.py
Normal file
611
backend/app/services/event_bus.py
Normal file
@@ -0,0 +1,611 @@
|
|||||||
|
"""
|
||||||
|
EventBus service for Redis Pub/Sub communication.
|
||||||
|
|
||||||
|
This module provides a centralized event bus for publishing and subscribing to
|
||||||
|
events across the Syndarix platform. It uses Redis Pub/Sub for real-time
|
||||||
|
message delivery between services, agents, and the frontend.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
- Publishers emit events to project/agent-specific Redis channels
|
||||||
|
- SSE endpoints subscribe to channels and stream events to clients
|
||||||
|
- Events include metadata for reconnection support (Last-Event-ID)
|
||||||
|
- Events are typed with the EventType enum for consistency
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Publishing events
|
||||||
|
event_bus = EventBus()
|
||||||
|
await event_bus.connect()
|
||||||
|
|
||||||
|
event = event_bus.create_event(
|
||||||
|
event_type=EventType.AGENT_MESSAGE,
|
||||||
|
project_id=project_id,
|
||||||
|
actor_type="agent",
|
||||||
|
payload={"message": "Processing..."}
|
||||||
|
)
|
||||||
|
await event_bus.publish(event_bus.get_project_channel(project_id), event)
|
||||||
|
|
||||||
|
# Subscribing to events
|
||||||
|
async for event in event_bus.subscribe(["project:123", "agent:456"]):
|
||||||
|
handle_event(event)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await event_bus.disconnect()
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncGenerator, AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
import redis.asyncio as redis
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.schemas.events import ActorType, Event, EventType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EventBusError(Exception):
|
||||||
|
"""Base exception for EventBus errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class EventBusConnectionError(EventBusError):
|
||||||
|
"""Raised when connection to Redis fails."""
|
||||||
|
|
||||||
|
|
||||||
|
class EventBusPublishError(EventBusError):
|
||||||
|
"""Raised when publishing an event fails."""
|
||||||
|
|
||||||
|
|
||||||
|
class EventBusSubscriptionError(EventBusError):
|
||||||
|
"""Raised when subscribing to channels fails."""
|
||||||
|
|
||||||
|
|
||||||
|
class EventBus:
|
||||||
|
"""
|
||||||
|
EventBus for Redis Pub/Sub communication.
|
||||||
|
|
||||||
|
Provides methods to publish events to channels and subscribe to events
|
||||||
|
from multiple channels. Handles connection management, serialization,
|
||||||
|
and error recovery.
|
||||||
|
|
||||||
|
This class provides:
|
||||||
|
- Event publishing to project/agent-specific channels
|
||||||
|
- Subscription management for SSE endpoints
|
||||||
|
- Reconnection support via event IDs (Last-Event-ID)
|
||||||
|
- Keepalive messages for connection health
|
||||||
|
- Type-safe event creation with the Event schema
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
redis_url: Redis connection URL
|
||||||
|
redis_client: Async Redis client instance
|
||||||
|
pubsub: Redis PubSub instance for subscriptions
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Channel prefixes for different entity types
|
||||||
|
PROJECT_CHANNEL_PREFIX = "project"
|
||||||
|
AGENT_CHANNEL_PREFIX = "agent"
|
||||||
|
USER_CHANNEL_PREFIX = "user"
|
||||||
|
GLOBAL_CHANNEL = "syndarix:global"
|
||||||
|
|
||||||
|
def __init__(self, redis_url: str | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the EventBus.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis_url: Redis connection URL. Defaults to settings.REDIS_URL.
|
||||||
|
"""
|
||||||
|
self.redis_url = redis_url or settings.REDIS_URL
|
||||||
|
self._redis_client: redis.Redis | None = None
|
||||||
|
self._pubsub: redis.client.PubSub | None = None
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def redis_client(self) -> redis.Redis:
|
||||||
|
"""Get the Redis client, raising if not connected."""
|
||||||
|
if self._redis_client is None:
|
||||||
|
raise EventBusConnectionError(
|
||||||
|
"EventBus not connected. Call connect() first."
|
||||||
|
)
|
||||||
|
return self._redis_client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pubsub(self) -> redis.client.PubSub:
|
||||||
|
"""Get the PubSub instance, raising if not connected."""
|
||||||
|
if self._pubsub is None:
|
||||||
|
raise EventBusConnectionError(
|
||||||
|
"EventBus not connected. Call connect() first."
|
||||||
|
)
|
||||||
|
return self._pubsub
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""Check if the EventBus is connected to Redis."""
|
||||||
|
return self._connected and self._redis_client is not None
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
"""
|
||||||
|
Connect to Redis and initialize the PubSub client.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
EventBusConnectionError: If connection to Redis fails.
|
||||||
|
"""
|
||||||
|
if self._connected:
|
||||||
|
logger.debug("EventBus already connected")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._redis_client = redis.from_url(
|
||||||
|
self.redis_url,
|
||||||
|
encoding="utf-8",
|
||||||
|
decode_responses=True,
|
||||||
|
)
|
||||||
|
# Test connection - ping() returns a coroutine for async Redis
|
||||||
|
ping_result = self._redis_client.ping()
|
||||||
|
if hasattr(ping_result, "__await__"):
|
||||||
|
await ping_result
|
||||||
|
self._pubsub = self._redis_client.pubsub()
|
||||||
|
self._connected = True
|
||||||
|
logger.info("EventBus connected to Redis")
|
||||||
|
except redis.ConnectionError as e:
|
||||||
|
logger.error(f"Failed to connect to Redis: {e}", exc_info=True)
|
||||||
|
raise EventBusConnectionError(f"Failed to connect to Redis: {e}") from e
|
||||||
|
except redis.RedisError as e:
|
||||||
|
logger.error(f"Redis error during connection: {e}", exc_info=True)
|
||||||
|
raise EventBusConnectionError(f"Redis error: {e}") from e
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""
|
||||||
|
Disconnect from Redis and cleanup resources.
|
||||||
|
"""
|
||||||
|
if self._pubsub:
|
||||||
|
try:
|
||||||
|
await self._pubsub.unsubscribe()
|
||||||
|
await self._pubsub.close()
|
||||||
|
except redis.RedisError as e:
|
||||||
|
logger.warning(f"Error closing PubSub: {e}")
|
||||||
|
finally:
|
||||||
|
self._pubsub = None
|
||||||
|
|
||||||
|
if self._redis_client:
|
||||||
|
try:
|
||||||
|
await self._redis_client.aclose()
|
||||||
|
except redis.RedisError as e:
|
||||||
|
logger.warning(f"Error closing Redis client: {e}")
|
||||||
|
finally:
|
||||||
|
self._redis_client = None
|
||||||
|
|
||||||
|
self._connected = False
|
||||||
|
logger.info("EventBus disconnected from Redis")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def connection(self) -> AsyncIterator["EventBus"]:
|
||||||
|
"""
|
||||||
|
Context manager for automatic connection handling.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
async with event_bus.connection() as bus:
|
||||||
|
await bus.publish(channel, event)
|
||||||
|
"""
|
||||||
|
await self.connect()
|
||||||
|
try:
|
||||||
|
yield self
|
||||||
|
finally:
|
||||||
|
await self.disconnect()
|
||||||
|
|
||||||
|
def get_project_channel(self, project_id: UUID | str) -> str:
|
||||||
|
"""
|
||||||
|
Get the channel name for a project.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: The project UUID or string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Channel name string in format "project:{uuid}"
|
||||||
|
"""
|
||||||
|
return f"{self.PROJECT_CHANNEL_PREFIX}:{project_id}"
|
||||||
|
|
||||||
|
def get_agent_channel(self, agent_id: UUID | str) -> str:
|
||||||
|
"""
|
||||||
|
Get the channel name for an agent instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: The agent instance UUID or string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Channel name string in format "agent:{uuid}"
|
||||||
|
"""
|
||||||
|
return f"{self.AGENT_CHANNEL_PREFIX}:{agent_id}"
|
||||||
|
|
||||||
|
def get_user_channel(self, user_id: UUID | str) -> str:
|
||||||
|
"""
|
||||||
|
Get the channel name for a user (personal notifications).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user UUID or string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Channel name string in format "user:{uuid}"
|
||||||
|
"""
|
||||||
|
return f"{self.USER_CHANNEL_PREFIX}:{user_id}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_event(
|
||||||
|
event_type: EventType,
|
||||||
|
project_id: UUID,
|
||||||
|
actor_type: ActorType,
|
||||||
|
payload: dict | None = None,
|
||||||
|
actor_id: UUID | None = None,
|
||||||
|
event_id: str | None = None,
|
||||||
|
timestamp: datetime | None = None,
|
||||||
|
) -> Event:
|
||||||
|
"""
|
||||||
|
Factory method to create a new Event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type: The type of event
|
||||||
|
project_id: The project this event belongs to
|
||||||
|
actor_type: Type of actor ('agent', 'user', or 'system')
|
||||||
|
payload: Event-specific payload data
|
||||||
|
actor_id: ID of the agent or user who triggered the event
|
||||||
|
event_id: Optional custom event ID (UUID string)
|
||||||
|
timestamp: Optional custom timestamp (defaults to now UTC)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new Event instance
|
||||||
|
"""
|
||||||
|
return Event(
|
||||||
|
id=event_id or str(uuid4()),
|
||||||
|
type=event_type,
|
||||||
|
timestamp=timestamp or datetime.now(UTC),
|
||||||
|
project_id=project_id,
|
||||||
|
actor_id=actor_id,
|
||||||
|
actor_type=actor_type,
|
||||||
|
payload=payload or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _serialize_event(self, event: Event) -> str:
|
||||||
|
"""
|
||||||
|
Serialize an event to JSON string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The Event to serialize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON string representation of the event
|
||||||
|
"""
|
||||||
|
return event.model_dump_json()
|
||||||
|
|
||||||
|
def _deserialize_event(self, data: str) -> Event:
|
||||||
|
"""
|
||||||
|
Deserialize a JSON string to an Event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: JSON string to deserialize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deserialized Event instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If the data doesn't match the Event schema
|
||||||
|
"""
|
||||||
|
return Event.model_validate_json(data)
|
||||||
|
|
||||||
|
async def publish(self, channel: str, event: Event) -> int:
|
||||||
|
"""
|
||||||
|
Publish an event to a channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: The channel name to publish to
|
||||||
|
event: The Event to publish
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of subscribers that received the message
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
EventBusConnectionError: If not connected to Redis
|
||||||
|
EventBusPublishError: If publishing fails
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise EventBusConnectionError("EventBus not connected")
|
||||||
|
|
||||||
|
try:
|
||||||
|
message = self._serialize_event(event)
|
||||||
|
subscriber_count = await self.redis_client.publish(channel, message)
|
||||||
|
logger.debug(
|
||||||
|
f"Published event {event.type} to {channel} "
|
||||||
|
f"(received by {subscriber_count} subscribers)"
|
||||||
|
)
|
||||||
|
return subscriber_count
|
||||||
|
except redis.RedisError as e:
|
||||||
|
logger.error(f"Failed to publish event to {channel}: {e}", exc_info=True)
|
||||||
|
raise EventBusPublishError(f"Failed to publish event: {e}") from e
|
||||||
|
|
||||||
|
async def publish_to_project(self, event: Event) -> int:
|
||||||
|
"""
|
||||||
|
Publish an event to the project's channel.
|
||||||
|
|
||||||
|
Convenience method that publishes to the project channel based on
|
||||||
|
the event's project_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The Event to publish (must have project_id set)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of subscribers that received the message
|
||||||
|
"""
|
||||||
|
channel = self.get_project_channel(event.project_id)
|
||||||
|
return await self.publish(channel, event)
|
||||||
|
|
||||||
|
async def publish_multi(self, channels: list[str], event: Event) -> dict[str, int]:
|
||||||
|
"""
|
||||||
|
Publish an event to multiple channels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels: List of channel names to publish to
|
||||||
|
event: The Event to publish
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping channel names to subscriber counts
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
for channel in channels:
|
||||||
|
try:
|
||||||
|
results[channel] = await self.publish(channel, event)
|
||||||
|
except EventBusPublishError as e:
|
||||||
|
logger.warning(f"Failed to publish to {channel}: {e}")
|
||||||
|
results[channel] = 0
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def subscribe(
|
||||||
|
self, channels: list[str], *, max_wait: float | None = None
|
||||||
|
) -> AsyncIterator[Event]:
|
||||||
|
"""
|
||||||
|
Subscribe to one or more channels and yield events.
|
||||||
|
|
||||||
|
This is an async generator that yields Event objects as they arrive.
|
||||||
|
Use max_wait to limit how long to wait for messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels: List of channel names to subscribe to
|
||||||
|
max_wait: Optional maximum wait time in seconds for each message.
|
||||||
|
If None, waits indefinitely.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Event objects received from subscribed channels
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
EventBusConnectionError: If not connected to Redis
|
||||||
|
EventBusSubscriptionError: If subscription fails
|
||||||
|
|
||||||
|
Example:
|
||||||
|
async for event in event_bus.subscribe(["project:123"], max_wait=30):
|
||||||
|
print(f"Received: {event.type}")
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise EventBusConnectionError("EventBus not connected")
|
||||||
|
|
||||||
|
# Create a new pubsub for this subscription
|
||||||
|
subscription_pubsub = self.redis_client.pubsub()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await subscription_pubsub.subscribe(*channels)
|
||||||
|
logger.info(f"Subscribed to channels: {channels}")
|
||||||
|
except redis.RedisError as e:
|
||||||
|
logger.error(f"Failed to subscribe to channels: {e}", exc_info=True)
|
||||||
|
await subscription_pubsub.close()
|
||||||
|
raise EventBusSubscriptionError(f"Failed to subscribe: {e}") from e
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
if max_wait is not None:
|
||||||
|
async with asyncio.timeout(max_wait):
|
||||||
|
message = await subscription_pubsub.get_message(
|
||||||
|
ignore_subscribe_messages=True, timeout=1.0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message = await subscription_pubsub.get_message(
|
||||||
|
ignore_subscribe_messages=True, timeout=1.0
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
# Timeout reached, stop iteration
|
||||||
|
return
|
||||||
|
|
||||||
|
if message is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if message["type"] == "message":
|
||||||
|
try:
|
||||||
|
event = self._deserialize_event(message["data"])
|
||||||
|
yield event
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid event data received: {e}",
|
||||||
|
extra={"channel": message.get("channel")},
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to decode event JSON: {e}",
|
||||||
|
extra={"channel": message.get("channel")},
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await subscription_pubsub.unsubscribe(*channels)
|
||||||
|
await subscription_pubsub.close()
|
||||||
|
logger.debug(f"Unsubscribed from channels: {channels}")
|
||||||
|
except redis.RedisError as e:
|
||||||
|
logger.warning(f"Error unsubscribing from channels: {e}")
|
||||||
|
|
||||||
|
async def subscribe_sse(
|
||||||
|
self,
|
||||||
|
project_id: str | UUID,
|
||||||
|
last_event_id: str | None = None,
|
||||||
|
keepalive_interval: int = 30,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""
|
||||||
|
Subscribe to events for a project in SSE format.
|
||||||
|
|
||||||
|
This is an async generator that yields SSE-formatted event strings.
|
||||||
|
It includes keepalive messages at the specified interval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: The project to subscribe to
|
||||||
|
last_event_id: Optional last received event ID for reconnection
|
||||||
|
keepalive_interval: Seconds between keepalive messages (default 30)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
SSE-formatted event strings (ready to send to client)
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise EventBusConnectionError("EventBus not connected")
|
||||||
|
|
||||||
|
project_id_str = str(project_id)
|
||||||
|
channel = self.get_project_channel(project_id_str)
|
||||||
|
|
||||||
|
subscription_pubsub = self.redis_client.pubsub()
|
||||||
|
await subscription_pubsub.subscribe(channel)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Subscribed to SSE events for project {project_id_str} "
|
||||||
|
f"(last_event_id={last_event_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Wait for messages with a timeout for keepalive
|
||||||
|
message = await asyncio.wait_for(
|
||||||
|
subscription_pubsub.get_message(ignore_subscribe_messages=True),
|
||||||
|
timeout=keepalive_interval,
|
||||||
|
)
|
||||||
|
|
||||||
|
if message is not None and message["type"] == "message":
|
||||||
|
event_data = message["data"]
|
||||||
|
|
||||||
|
# If reconnecting, check if we should skip this event
|
||||||
|
if last_event_id:
|
||||||
|
try:
|
||||||
|
event_dict = json.loads(event_data)
|
||||||
|
if event_dict.get("id") == last_event_id:
|
||||||
|
# Found the last event, start yielding from next
|
||||||
|
last_event_id = None
|
||||||
|
continue
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
yield event_data
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
# Send keepalive comment
|
||||||
|
yield "" # Empty string signals keepalive
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info(f"SSE subscription cancelled for project {project_id_str}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await subscription_pubsub.unsubscribe(channel)
|
||||||
|
await subscription_pubsub.close()
|
||||||
|
logger.info(f"Unsubscribed SSE from project {project_id_str}")
|
||||||
|
|
||||||
|
async def subscribe_with_callback(
|
||||||
|
self,
|
||||||
|
channels: list[str],
|
||||||
|
callback: Any, # Callable[[Event], Awaitable[None]]
|
||||||
|
stop_event: asyncio.Event | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Subscribe to channels and process events with a callback.
|
||||||
|
|
||||||
|
This method runs until stop_event is set or an unrecoverable error occurs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels: List of channel names to subscribe to
|
||||||
|
callback: Async function to call for each event
|
||||||
|
stop_event: Optional asyncio.Event to signal stop
|
||||||
|
|
||||||
|
Example:
|
||||||
|
async def handle_event(event: Event):
|
||||||
|
print(f"Handling: {event.type}")
|
||||||
|
|
||||||
|
stop = asyncio.Event()
|
||||||
|
asyncio.create_task(
|
||||||
|
event_bus.subscribe_with_callback(["project:123"], handle_event, stop)
|
||||||
|
)
|
||||||
|
# Later...
|
||||||
|
stop.set()
|
||||||
|
"""
|
||||||
|
if stop_event is None:
|
||||||
|
stop_event = asyncio.Event()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for event in self.subscribe(channels):
|
||||||
|
if stop_event.is_set():
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
await callback(event)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in event callback: {e}", exc_info=True)
|
||||||
|
except EventBusSubscriptionError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error in subscription loop: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance for application-wide use
|
||||||
|
_event_bus: EventBus | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_event_bus() -> EventBus:
|
||||||
|
"""
|
||||||
|
Get the singleton EventBus instance.
|
||||||
|
|
||||||
|
Creates a new instance if one doesn't exist. Note that you still need
|
||||||
|
to call connect() before using the EventBus.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The singleton EventBus instance
|
||||||
|
"""
|
||||||
|
global _event_bus
|
||||||
|
if _event_bus is None:
|
||||||
|
_event_bus = EventBus()
|
||||||
|
return _event_bus
|
||||||
|
|
||||||
|
|
||||||
|
async def get_connected_event_bus() -> EventBus:
|
||||||
|
"""
|
||||||
|
Get a connected EventBus instance.
|
||||||
|
|
||||||
|
Ensures the EventBus is connected before returning. For use in
|
||||||
|
FastAPI dependency injection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A connected EventBus instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
EventBusConnectionError: If connection fails
|
||||||
|
"""
|
||||||
|
event_bus = get_event_bus()
|
||||||
|
if not event_bus.is_connected:
|
||||||
|
await event_bus.connect()
|
||||||
|
return event_bus
|
||||||
|
|
||||||
|
|
||||||
|
async def close_event_bus() -> None:
|
||||||
|
"""
|
||||||
|
Close the global EventBus instance.
|
||||||
|
|
||||||
|
Should be called during application shutdown.
|
||||||
|
"""
|
||||||
|
global _event_bus
|
||||||
|
if _event_bus is not None:
|
||||||
|
await _event_bus.disconnect()
|
||||||
|
_event_bus = None
|
||||||
85
backend/app/services/mcp/__init__.py
Normal file
85
backend/app/services/mcp/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""
|
||||||
|
MCP Client Service Package
|
||||||
|
|
||||||
|
Provides infrastructure for communicating with MCP (Model Context Protocol)
|
||||||
|
servers. This is the foundation for AI agent tool integration.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from app.services.mcp import get_mcp_client, MCPClientManager
|
||||||
|
|
||||||
|
# In FastAPI route
|
||||||
|
async def my_route(mcp: MCPClientManager = Depends(get_mcp_client)):
|
||||||
|
result = await mcp.call_tool("llm-gateway", "chat", {"prompt": "Hello"})
|
||||||
|
|
||||||
|
# Direct usage
|
||||||
|
manager = MCPClientManager()
|
||||||
|
await manager.initialize()
|
||||||
|
result = await manager.call_tool("issues", "create_issue", {...})
|
||||||
|
await manager.shutdown()
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .client_manager import (
|
||||||
|
MCPClientManager,
|
||||||
|
ServerHealth,
|
||||||
|
get_mcp_client,
|
||||||
|
reset_mcp_client,
|
||||||
|
shutdown_mcp_client,
|
||||||
|
)
|
||||||
|
from .config import (
|
||||||
|
MCPConfig,
|
||||||
|
MCPServerConfig,
|
||||||
|
TransportType,
|
||||||
|
create_default_config,
|
||||||
|
load_mcp_config,
|
||||||
|
)
|
||||||
|
from .connection import ConnectionPool, ConnectionState, MCPConnection
|
||||||
|
from .exceptions import (
|
||||||
|
MCPCircuitOpenError,
|
||||||
|
MCPConnectionError,
|
||||||
|
MCPError,
|
||||||
|
MCPServerNotFoundError,
|
||||||
|
MCPTimeoutError,
|
||||||
|
MCPToolError,
|
||||||
|
MCPToolNotFoundError,
|
||||||
|
MCPValidationError,
|
||||||
|
)
|
||||||
|
from .registry import MCPServerRegistry, ServerCapabilities, get_registry
|
||||||
|
from .routing import AsyncCircuitBreaker, CircuitState, ToolInfo, ToolResult, ToolRouter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Main facade
|
||||||
|
"MCPClientManager",
|
||||||
|
"get_mcp_client",
|
||||||
|
"shutdown_mcp_client",
|
||||||
|
"reset_mcp_client",
|
||||||
|
"ServerHealth",
|
||||||
|
# Configuration
|
||||||
|
"MCPConfig",
|
||||||
|
"MCPServerConfig",
|
||||||
|
"TransportType",
|
||||||
|
"load_mcp_config",
|
||||||
|
"create_default_config",
|
||||||
|
# Registry
|
||||||
|
"MCPServerRegistry",
|
||||||
|
"ServerCapabilities",
|
||||||
|
"get_registry",
|
||||||
|
# Connection
|
||||||
|
"ConnectionPool",
|
||||||
|
"ConnectionState",
|
||||||
|
"MCPConnection",
|
||||||
|
# Routing
|
||||||
|
"ToolRouter",
|
||||||
|
"ToolInfo",
|
||||||
|
"ToolResult",
|
||||||
|
"AsyncCircuitBreaker",
|
||||||
|
"CircuitState",
|
||||||
|
# Exceptions
|
||||||
|
"MCPError",
|
||||||
|
"MCPConnectionError",
|
||||||
|
"MCPTimeoutError",
|
||||||
|
"MCPToolError",
|
||||||
|
"MCPServerNotFoundError",
|
||||||
|
"MCPToolNotFoundError",
|
||||||
|
"MCPCircuitOpenError",
|
||||||
|
"MCPValidationError",
|
||||||
|
]
|
||||||
417
backend/app/services/mcp/client_manager.py
Normal file
417
backend/app/services/mcp/client_manager.py
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
"""
|
||||||
|
MCP Client Manager
|
||||||
|
|
||||||
|
Main facade for all MCP operations. Manages server connections,
|
||||||
|
tool discovery, and provides a unified interface for tool calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
||||||
|
from .connection import ConnectionPool, ConnectionState
|
||||||
|
from .exceptions import MCPServerNotFoundError
|
||||||
|
from .registry import MCPServerRegistry, get_registry
|
||||||
|
from .routing import ToolInfo, ToolResult, ToolRouter
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ServerHealth:
|
||||||
|
"""Health status for an MCP server."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
healthy: bool
|
||||||
|
state: str
|
||||||
|
url: str
|
||||||
|
error: str | None = None
|
||||||
|
tools_count: int = 0
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"healthy": self.healthy,
|
||||||
|
"state": self.state,
|
||||||
|
"url": self.url,
|
||||||
|
"error": self.error,
|
||||||
|
"tools_count": self.tools_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClientManager:
|
||||||
|
"""
|
||||||
|
Central manager for all MCP client operations.
|
||||||
|
|
||||||
|
Provides a unified interface for:
|
||||||
|
- Connecting to MCP servers
|
||||||
|
- Discovering and calling tools
|
||||||
|
- Health monitoring
|
||||||
|
- Connection lifecycle management
|
||||||
|
|
||||||
|
This is the main entry point for MCP operations in the application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: MCPConfig | None = None,
|
||||||
|
registry: MCPServerRegistry | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the MCP client manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Optional MCP configuration. If None, loads from default.
|
||||||
|
registry: Optional registry instance. If None, uses singleton.
|
||||||
|
"""
|
||||||
|
self._registry = registry or get_registry()
|
||||||
|
self._pool = ConnectionPool()
|
||||||
|
self._router: ToolRouter | None = None
|
||||||
|
self._initialized = False
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# Load configuration if provided
|
||||||
|
if config is not None:
|
||||||
|
self._registry.load_config(config)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_initialized(self) -> bool:
|
||||||
|
"""Check if the manager is initialized."""
|
||||||
|
return self._initialized
|
||||||
|
|
||||||
|
async def initialize(self, config: MCPConfig | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the MCP client manager.
|
||||||
|
|
||||||
|
Loads configuration, creates connections, and discovers tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Optional configuration to load
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if self._initialized:
|
||||||
|
logger.warning("MCPClientManager already initialized")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Initializing MCP Client Manager")
|
||||||
|
|
||||||
|
# Load configuration
|
||||||
|
if config is not None:
|
||||||
|
self._registry.load_config(config)
|
||||||
|
elif len(self._registry.list_servers()) == 0:
|
||||||
|
# Try to load from default location
|
||||||
|
self._registry.load_config(load_mcp_config())
|
||||||
|
|
||||||
|
# Create router
|
||||||
|
self._router = ToolRouter(self._registry, self._pool)
|
||||||
|
|
||||||
|
# Connect to all enabled servers
|
||||||
|
await self._connect_all_servers()
|
||||||
|
|
||||||
|
# Discover tools from all servers
|
||||||
|
if self._router:
|
||||||
|
await self._router.discover_tools()
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
logger.info(
|
||||||
|
"MCP Client Manager initialized with %d servers",
|
||||||
|
len(self._registry.list_enabled_servers()),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _connect_all_servers(self) -> None:
|
||||||
|
"""Connect to all enabled MCP servers."""
|
||||||
|
enabled_servers = self._registry.get_enabled_configs()
|
||||||
|
|
||||||
|
for name, config in enabled_servers.items():
|
||||||
|
try:
|
||||||
|
await self._pool.get_connection(name, config)
|
||||||
|
logger.info("Connected to MCP server: %s", name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to connect to MCP server %s: %s", name, e)
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
"""
|
||||||
|
Shutdown the MCP client manager.
|
||||||
|
|
||||||
|
Closes all connections and cleans up resources.
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if not self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Shutting down MCP Client Manager")
|
||||||
|
|
||||||
|
await self._pool.close_all()
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
logger.info("MCP Client Manager shutdown complete")
|
||||||
|
|
||||||
|
async def connect(self, server_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Connect to a specific MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server to connect to
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPServerNotFoundError: If server is not registered
|
||||||
|
"""
|
||||||
|
config = self._registry.get(server_name)
|
||||||
|
await self._pool.get_connection(server_name, config)
|
||||||
|
logger.info("Connected to MCP server: %s", server_name)
|
||||||
|
|
||||||
|
async def disconnect(self, server_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Disconnect from a specific MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server to disconnect from
|
||||||
|
"""
|
||||||
|
await self._pool.close_connection(server_name)
|
||||||
|
logger.info("Disconnected from MCP server: %s", server_name)
|
||||||
|
|
||||||
|
async def disconnect_all(self) -> None:
|
||||||
|
"""Disconnect from all MCP servers."""
|
||||||
|
await self._pool.close_all()
|
||||||
|
|
||||||
|
async def call_tool(
|
||||||
|
self,
|
||||||
|
server: str,
|
||||||
|
tool: str,
|
||||||
|
args: dict[str, Any] | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> ToolResult:
|
||||||
|
"""
|
||||||
|
Call a tool on a specific MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: Name of the MCP server
|
||||||
|
tool: Name of the tool to call
|
||||||
|
args: Tool arguments
|
||||||
|
timeout: Optional timeout override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool execution result
|
||||||
|
"""
|
||||||
|
if not self._initialized or self._router is None:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
assert self._router is not None # Guaranteed after initialize()
|
||||||
|
return await self._router.call_tool(
|
||||||
|
server_name=server,
|
||||||
|
tool_name=tool,
|
||||||
|
arguments=args,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def route_tool(
|
||||||
|
self,
|
||||||
|
tool: str,
|
||||||
|
args: dict[str, Any] | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> ToolResult:
|
||||||
|
"""
|
||||||
|
Route a tool call to the appropriate server automatically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool: Name of the tool to call
|
||||||
|
args: Tool arguments
|
||||||
|
timeout: Optional timeout override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool execution result
|
||||||
|
"""
|
||||||
|
if not self._initialized or self._router is None:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
assert self._router is not None # Guaranteed after initialize()
|
||||||
|
return await self._router.route_tool(
|
||||||
|
tool_name=tool,
|
||||||
|
arguments=args,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_tools(self, server: str) -> list[ToolInfo]:
|
||||||
|
"""
|
||||||
|
List all tools available on a specific server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: Name of the MCP server
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool information
|
||||||
|
"""
|
||||||
|
capabilities = await self._registry.get_capabilities(server)
|
||||||
|
return [
|
||||||
|
ToolInfo(
|
||||||
|
name=t.get("name", ""),
|
||||||
|
description=t.get("description"),
|
||||||
|
server_name=server,
|
||||||
|
input_schema=t.get("input_schema"),
|
||||||
|
)
|
||||||
|
for t in capabilities.tools
|
||||||
|
]
|
||||||
|
|
||||||
|
async def list_all_tools(self) -> list[ToolInfo]:
|
||||||
|
"""
|
||||||
|
List all tools from all servers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool information
|
||||||
|
"""
|
||||||
|
if not self._initialized or self._router is None:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
assert self._router is not None # Guaranteed after initialize()
|
||||||
|
return await self._router.list_all_tools()
|
||||||
|
|
||||||
|
async def health_check(self) -> dict[str, ServerHealth]:
|
||||||
|
"""
|
||||||
|
Perform health check on all MCP servers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping server names to health status
|
||||||
|
"""
|
||||||
|
results: dict[str, ServerHealth] = {}
|
||||||
|
pool_status = self._pool.get_status()
|
||||||
|
pool_health = await self._pool.health_check_all()
|
||||||
|
|
||||||
|
for server_name in self._registry.list_servers():
|
||||||
|
try:
|
||||||
|
config = self._registry.get(server_name)
|
||||||
|
status = pool_status.get(server_name, {})
|
||||||
|
healthy = pool_health.get(server_name, False)
|
||||||
|
|
||||||
|
capabilities = self._registry.get_cached_capabilities(server_name)
|
||||||
|
|
||||||
|
results[server_name] = ServerHealth(
|
||||||
|
name=server_name,
|
||||||
|
healthy=healthy,
|
||||||
|
state=status.get("state", ConnectionState.DISCONNECTED.value),
|
||||||
|
url=config.url,
|
||||||
|
tools_count=len(capabilities.tools),
|
||||||
|
)
|
||||||
|
except MCPServerNotFoundError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
results[server_name] = ServerHealth(
|
||||||
|
name=server_name,
|
||||||
|
healthy=False,
|
||||||
|
state=ConnectionState.ERROR.value,
|
||||||
|
url="unknown",
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def list_servers(self) -> list[str]:
|
||||||
|
"""Get list of all registered server names."""
|
||||||
|
return self._registry.list_servers()
|
||||||
|
|
||||||
|
def list_enabled_servers(self) -> list[str]:
|
||||||
|
"""Get list of enabled server names."""
|
||||||
|
return self._registry.list_enabled_servers()
|
||||||
|
|
||||||
|
def get_server_config(self, server_name: str) -> MCPServerConfig:
|
||||||
|
"""
|
||||||
|
Get configuration for a specific server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPServerNotFoundError: If server is not registered
|
||||||
|
"""
|
||||||
|
return self._registry.get(server_name)
|
||||||
|
|
||||||
|
def register_server(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Register a new MCP server at runtime.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Unique server name
|
||||||
|
config: Server configuration
|
||||||
|
"""
|
||||||
|
self._registry.register(name, config)
|
||||||
|
|
||||||
|
def unregister_server(self, name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Unregister an MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name to unregister
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if server was found and removed
|
||||||
|
"""
|
||||||
|
return self._registry.unregister(name)
|
||||||
|
|
||||||
|
def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Get status of all circuit breakers."""
|
||||||
|
if self._router is None:
|
||||||
|
return {}
|
||||||
|
return self._router.get_circuit_breaker_status()
|
||||||
|
|
||||||
|
async def reset_circuit_breaker(self, server_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Reset a circuit breaker for a server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if circuit breaker was reset
|
||||||
|
"""
|
||||||
|
if self._router is None:
|
||||||
|
return False
|
||||||
|
return await self._router.reset_circuit_breaker(server_name)
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
_manager_instance: MCPClientManager | None = None
|
||||||
|
_manager_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_mcp_client() -> MCPClientManager:
|
||||||
|
"""
|
||||||
|
Get the global MCP client manager instance.
|
||||||
|
|
||||||
|
This is the main dependency injection point for FastAPI.
|
||||||
|
Uses proper locking to avoid race conditions in async contexts.
|
||||||
|
"""
|
||||||
|
global _manager_instance
|
||||||
|
|
||||||
|
# Use lock for the entire check-and-create operation to avoid race conditions
|
||||||
|
async with _manager_lock:
|
||||||
|
if _manager_instance is None:
|
||||||
|
_manager_instance = MCPClientManager()
|
||||||
|
await _manager_instance.initialize()
|
||||||
|
|
||||||
|
return _manager_instance
|
||||||
|
|
||||||
|
|
||||||
|
async def shutdown_mcp_client() -> None:
|
||||||
|
"""Shutdown the global MCP client manager."""
|
||||||
|
global _manager_instance
|
||||||
|
|
||||||
|
# Use lock to prevent race with get_mcp_client()
|
||||||
|
async with _manager_lock:
|
||||||
|
if _manager_instance is not None:
|
||||||
|
await _manager_instance.shutdown()
|
||||||
|
_manager_instance = None
|
||||||
|
|
||||||
|
|
||||||
|
def reset_mcp_client() -> None:
|
||||||
|
"""Reset the global MCP client manager (for testing)."""
|
||||||
|
global _manager_instance
|
||||||
|
_manager_instance = None
|
||||||
234
backend/app/services/mcp/config.py
Normal file
234
backend/app/services/mcp/config.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
"""
|
||||||
|
MCP Configuration System
|
||||||
|
|
||||||
|
Pydantic models for MCP server configuration with YAML file loading
|
||||||
|
and environment variable overrides.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class TransportType(str, Enum):
|
||||||
|
"""Supported MCP transport types."""
|
||||||
|
|
||||||
|
HTTP = "http"
|
||||||
|
STDIO = "stdio"
|
||||||
|
SSE = "sse"
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerConfig(BaseModel):
|
||||||
|
"""Configuration for a single MCP server."""
|
||||||
|
|
||||||
|
url: str = Field(..., description="Server URL (supports ${ENV_VAR} syntax)")
|
||||||
|
transport: TransportType = Field(
|
||||||
|
default=TransportType.HTTP,
|
||||||
|
description="Transport protocol to use",
|
||||||
|
)
|
||||||
|
timeout: int = Field(
|
||||||
|
default=30,
|
||||||
|
ge=1,
|
||||||
|
le=600,
|
||||||
|
description="Request timeout in seconds",
|
||||||
|
)
|
||||||
|
retry_attempts: int = Field(
|
||||||
|
default=3,
|
||||||
|
ge=0,
|
||||||
|
le=10,
|
||||||
|
description="Number of retry attempts on failure",
|
||||||
|
)
|
||||||
|
retry_delay: float = Field(
|
||||||
|
default=1.0,
|
||||||
|
ge=0.1,
|
||||||
|
le=60.0,
|
||||||
|
description="Initial delay between retries in seconds",
|
||||||
|
)
|
||||||
|
retry_max_delay: float = Field(
|
||||||
|
default=30.0,
|
||||||
|
ge=1.0,
|
||||||
|
le=300.0,
|
||||||
|
description="Maximum delay between retries in seconds",
|
||||||
|
)
|
||||||
|
circuit_breaker_threshold: int = Field(
|
||||||
|
default=5,
|
||||||
|
ge=1,
|
||||||
|
le=50,
|
||||||
|
description="Number of failures before opening circuit",
|
||||||
|
)
|
||||||
|
circuit_breaker_timeout: float = Field(
|
||||||
|
default=30.0,
|
||||||
|
ge=5.0,
|
||||||
|
le=300.0,
|
||||||
|
description="Seconds to wait before attempting to close circuit",
|
||||||
|
)
|
||||||
|
enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Whether this server is enabled",
|
||||||
|
)
|
||||||
|
description: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Human-readable description of the server",
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("url", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def expand_env_vars(cls, v: str) -> str:
|
||||||
|
"""Expand environment variables in URL using ${VAR:-default} syntax."""
|
||||||
|
if not isinstance(v, str):
|
||||||
|
return v
|
||||||
|
|
||||||
|
result = v
|
||||||
|
# Find all ${VAR} or ${VAR:-default} patterns
|
||||||
|
import re
|
||||||
|
|
||||||
|
pattern = r"\$\{([^}]+)\}"
|
||||||
|
matches = re.findall(pattern, v)
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
if ":-" in match:
|
||||||
|
var_name, default = match.split(":-", 1)
|
||||||
|
else:
|
||||||
|
var_name, default = match, ""
|
||||||
|
|
||||||
|
env_value = os.environ.get(var_name.strip(), default)
|
||||||
|
result = result.replace(f"${{{match}}}", env_value)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConfig(BaseModel):
|
||||||
|
"""Root configuration for all MCP servers."""
|
||||||
|
|
||||||
|
mcp_servers: dict[str, MCPServerConfig] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Map of server names to their configurations",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Global defaults
|
||||||
|
default_timeout: int = Field(
|
||||||
|
default=30,
|
||||||
|
description="Default timeout for all servers",
|
||||||
|
)
|
||||||
|
default_retry_attempts: int = Field(
|
||||||
|
default=3,
|
||||||
|
description="Default retry attempts for all servers",
|
||||||
|
)
|
||||||
|
connection_pool_size: int = Field(
|
||||||
|
default=10,
|
||||||
|
ge=1,
|
||||||
|
le=100,
|
||||||
|
description="Maximum connections per server",
|
||||||
|
)
|
||||||
|
health_check_interval: int = Field(
|
||||||
|
default=30,
|
||||||
|
ge=5,
|
||||||
|
le=300,
|
||||||
|
description="Seconds between health checks",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, path: str | Path) -> "MCPConfig":
|
||||||
|
"""Load configuration from a YAML file."""
|
||||||
|
path = Path(path)
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"MCP config file not found: {path}")
|
||||||
|
|
||||||
|
with path.open("r") as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
data = {}
|
||||||
|
|
||||||
|
return cls.model_validate(data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "MCPConfig":
|
||||||
|
"""Load configuration from a dictionary."""
|
||||||
|
return cls.model_validate(data)
|
||||||
|
|
||||||
|
def get_server(self, name: str) -> MCPServerConfig | None:
|
||||||
|
"""Get a server configuration by name."""
|
||||||
|
return self.mcp_servers.get(name)
|
||||||
|
|
||||||
|
def get_enabled_servers(self) -> dict[str, MCPServerConfig]:
|
||||||
|
"""Get all enabled server configurations."""
|
||||||
|
return {
|
||||||
|
name: config
|
||||||
|
for name, config in self.mcp_servers.items()
|
||||||
|
if config.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
def list_server_names(self) -> list[str]:
|
||||||
|
"""Get list of all configured server names."""
|
||||||
|
return list(self.mcp_servers.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# Default configuration path
|
||||||
|
DEFAULT_CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "mcp_servers.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
|
||||||
|
"""
|
||||||
|
Load MCP configuration from file or environment.
|
||||||
|
|
||||||
|
Priority:
|
||||||
|
1. Explicit path parameter
|
||||||
|
2. MCP_CONFIG_PATH environment variable
|
||||||
|
3. Default path (backend/mcp_servers.yaml)
|
||||||
|
4. Empty config if no file exists
|
||||||
|
"""
|
||||||
|
if path is None:
|
||||||
|
path = os.environ.get("MCP_CONFIG_PATH", str(DEFAULT_CONFIG_PATH))
|
||||||
|
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
if not path.exists():
|
||||||
|
# Return empty config if no file exists (allows runtime registration)
|
||||||
|
return MCPConfig()
|
||||||
|
|
||||||
|
return MCPConfig.from_yaml(path)
|
||||||
|
|
||||||
|
|
||||||
|
def create_default_config() -> MCPConfig:
|
||||||
|
"""
|
||||||
|
Create a default MCP configuration with standard servers.
|
||||||
|
|
||||||
|
This is useful for development and as a template.
|
||||||
|
"""
|
||||||
|
return MCPConfig(
|
||||||
|
mcp_servers={
|
||||||
|
"llm-gateway": MCPServerConfig(
|
||||||
|
url="${LLM_GATEWAY_URL:-http://localhost:8001}",
|
||||||
|
transport=TransportType.HTTP,
|
||||||
|
timeout=60,
|
||||||
|
description="LLM Gateway for multi-provider AI interactions",
|
||||||
|
),
|
||||||
|
"knowledge-base": MCPServerConfig(
|
||||||
|
url="${KNOWLEDGE_BASE_URL:-http://localhost:8002}",
|
||||||
|
transport=TransportType.HTTP,
|
||||||
|
timeout=30,
|
||||||
|
description="Knowledge Base for RAG and document retrieval",
|
||||||
|
),
|
||||||
|
"git-ops": MCPServerConfig(
|
||||||
|
url="${GIT_OPS_URL:-http://localhost:8003}",
|
||||||
|
transport=TransportType.HTTP,
|
||||||
|
timeout=120,
|
||||||
|
description="Git Operations for repository management",
|
||||||
|
),
|
||||||
|
"issues": MCPServerConfig(
|
||||||
|
url="${ISSUES_URL:-http://localhost:8004}",
|
||||||
|
transport=TransportType.HTTP,
|
||||||
|
timeout=30,
|
||||||
|
description="Issue Tracker for Gitea/GitHub/GitLab",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
default_timeout=30,
|
||||||
|
default_retry_attempts=3,
|
||||||
|
connection_pool_size=10,
|
||||||
|
health_check_interval=30,
|
||||||
|
)
|
||||||
435
backend/app/services/mcp/connection.py
Normal file
435
backend/app/services/mcp/connection.py
Normal file
@@ -0,0 +1,435 @@
|
|||||||
|
"""
|
||||||
|
MCP Connection Management
|
||||||
|
|
||||||
|
Handles connection lifecycle, pooling, and automatic reconnection
|
||||||
|
for MCP servers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .config import MCPServerConfig, TransportType
|
||||||
|
from .exceptions import MCPConnectionError, MCPTimeoutError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionState(str, Enum):
|
||||||
|
"""Connection state enumeration."""
|
||||||
|
|
||||||
|
DISCONNECTED = "disconnected"
|
||||||
|
CONNECTING = "connecting"
|
||||||
|
CONNECTED = "connected"
|
||||||
|
RECONNECTING = "reconnecting"
|
||||||
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConnection:
|
||||||
|
"""
|
||||||
|
Manages a single connection to an MCP server.
|
||||||
|
|
||||||
|
Handles connection lifecycle, health checking, and automatic reconnection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the MCP server
|
||||||
|
config: Server configuration
|
||||||
|
"""
|
||||||
|
self.server_name = server_name
|
||||||
|
self.config = config
|
||||||
|
self._state = ConnectionState.DISCONNECTED
|
||||||
|
self._client: httpx.AsyncClient | None = None
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._last_activity: float | None = None
|
||||||
|
self._connection_attempts = 0
|
||||||
|
self._last_error: Exception | None = None
|
||||||
|
|
||||||
|
# Reconnection settings
|
||||||
|
self._base_delay = config.retry_delay
|
||||||
|
self._max_delay = config.retry_max_delay
|
||||||
|
self._max_attempts = config.retry_attempts
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> ConnectionState:
|
||||||
|
"""Get current connection state."""
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""Check if connection is established."""
|
||||||
|
return self._state == ConnectionState.CONNECTED
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_error(self) -> Exception | None:
|
||||||
|
"""Get the last error that occurred."""
|
||||||
|
return self._last_error
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
"""
|
||||||
|
Establish connection to the MCP server.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPConnectionError: If connection fails after all retries
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if self._state == ConnectionState.CONNECTED:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._state = ConnectionState.CONNECTING
|
||||||
|
self._connection_attempts = 0
|
||||||
|
self._last_error = None
|
||||||
|
|
||||||
|
while self._connection_attempts < self._max_attempts:
|
||||||
|
try:
|
||||||
|
await self._do_connect()
|
||||||
|
self._state = ConnectionState.CONNECTED
|
||||||
|
self._last_activity = time.time()
|
||||||
|
logger.info(
|
||||||
|
"Connected to MCP server: %s at %s",
|
||||||
|
self.server_name,
|
||||||
|
self.config.url,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
self._connection_attempts += 1
|
||||||
|
self._last_error = e
|
||||||
|
logger.warning(
|
||||||
|
"Connection attempt %d/%d failed for %s: %s",
|
||||||
|
self._connection_attempts,
|
||||||
|
self._max_attempts,
|
||||||
|
self.server_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._connection_attempts < self._max_attempts:
|
||||||
|
delay = self._calculate_backoff_delay()
|
||||||
|
logger.debug(
|
||||||
|
"Retrying connection to %s in %.1fs",
|
||||||
|
self.server_name,
|
||||||
|
delay,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
# All attempts failed
|
||||||
|
self._state = ConnectionState.ERROR
|
||||||
|
raise MCPConnectionError(
|
||||||
|
f"Failed to connect after {self._max_attempts} attempts",
|
||||||
|
server_name=self.server_name,
|
||||||
|
url=self.config.url,
|
||||||
|
cause=self._last_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _do_connect(self) -> None:
|
||||||
|
"""Perform the actual connection (transport-specific)."""
|
||||||
|
if self.config.transport == TransportType.HTTP:
|
||||||
|
self._client = httpx.AsyncClient(
|
||||||
|
base_url=self.config.url,
|
||||||
|
timeout=httpx.Timeout(self.config.timeout),
|
||||||
|
headers={
|
||||||
|
"User-Agent": "Syndarix-MCP-Client/1.0",
|
||||||
|
"Accept": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Verify connectivity with a simple request
|
||||||
|
try:
|
||||||
|
# Try to hit the MCP capabilities endpoint
|
||||||
|
response = await self._client.get("/mcp/capabilities")
|
||||||
|
if response.status_code not in (200, 404):
|
||||||
|
# 404 is acceptable - server might not have capabilities endpoint
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if e.response.status_code != 404:
|
||||||
|
raise
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
raise MCPConnectionError(
|
||||||
|
"Failed to connect to server",
|
||||||
|
server_name=self.server_name,
|
||||||
|
url=self.config.url,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For STDIO and SSE transports, we'll implement later
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Transport {self.config.transport} not yet implemented"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _calculate_backoff_delay(self) -> float:
|
||||||
|
"""Calculate exponential backoff delay with jitter."""
|
||||||
|
import random
|
||||||
|
|
||||||
|
delay = self._base_delay * (2 ** (self._connection_attempts - 1))
|
||||||
|
delay = min(delay, self._max_delay)
|
||||||
|
# Add jitter (±25%)
|
||||||
|
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||||
|
return delay + jitter
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""Disconnect from the MCP server."""
|
||||||
|
async with self._lock:
|
||||||
|
if self._client is not None:
|
||||||
|
try:
|
||||||
|
await self._client.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Error closing connection to %s: %s",
|
||||||
|
self.server_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
self._state = ConnectionState.DISCONNECTED
|
||||||
|
logger.info("Disconnected from MCP server: %s", self.server_name)
|
||||||
|
|
||||||
|
async def reconnect(self) -> None:
|
||||||
|
"""Reconnect to the MCP server."""
|
||||||
|
async with self._lock:
|
||||||
|
self._state = ConnectionState.RECONNECTING
|
||||||
|
await self.disconnect()
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
async def health_check(self) -> bool:
|
||||||
|
"""
|
||||||
|
Perform a health check on the connection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection is healthy
|
||||||
|
"""
|
||||||
|
if not self.is_connected or self._client is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.config.transport == TransportType.HTTP:
|
||||||
|
response = await self._client.get(
|
||||||
|
"/health",
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
return response.status_code == 200
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Health check failed for %s: %s",
|
||||||
|
self.server_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def execute_request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Execute an HTTP request to the MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP method (GET, POST, etc.)
|
||||||
|
path: Request path
|
||||||
|
data: Optional request body
|
||||||
|
timeout: Optional timeout override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPConnectionError: If not connected
|
||||||
|
MCPTimeoutError: If request times out
|
||||||
|
"""
|
||||||
|
if not self.is_connected or self._client is None:
|
||||||
|
raise MCPConnectionError(
|
||||||
|
"Not connected to server",
|
||||||
|
server_name=self.server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
effective_timeout = timeout or self.config.timeout
|
||||||
|
|
||||||
|
try:
|
||||||
|
if method.upper() == "GET":
|
||||||
|
response = await self._client.get(
|
||||||
|
path,
|
||||||
|
timeout=effective_timeout,
|
||||||
|
)
|
||||||
|
elif method.upper() == "POST":
|
||||||
|
response = await self._client.post(
|
||||||
|
path,
|
||||||
|
json=data,
|
||||||
|
timeout=effective_timeout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await self._client.request(
|
||||||
|
method.upper(),
|
||||||
|
path,
|
||||||
|
json=data,
|
||||||
|
timeout=effective_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._last_activity = time.time()
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
except httpx.TimeoutException as e:
|
||||||
|
raise MCPTimeoutError(
|
||||||
|
"Request timed out",
|
||||||
|
server_name=self.server_name,
|
||||||
|
timeout_seconds=effective_timeout,
|
||||||
|
operation=f"{method} {path}",
|
||||||
|
) from e
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise MCPConnectionError(
|
||||||
|
f"HTTP error: {e.response.status_code}",
|
||||||
|
server_name=self.server_name,
|
||||||
|
url=f"{self.config.url}{path}",
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise MCPConnectionError(
|
||||||
|
f"Request failed: {e}",
|
||||||
|
server_name=self.server_name,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionPool:
|
||||||
|
"""
|
||||||
|
Pool of connections to MCP servers.
|
||||||
|
|
||||||
|
Manages connection lifecycle and provides connection reuse.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_connections_per_server: int = 10) -> None:
|
||||||
|
"""
|
||||||
|
Initialize connection pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_connections_per_server: Maximum connections per server
|
||||||
|
"""
|
||||||
|
self._connections: dict[str, MCPConnection] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._max_per_server = max_connections_per_server
|
||||||
|
|
||||||
|
async def get_connection(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
) -> MCPConnection:
|
||||||
|
"""
|
||||||
|
Get or create a connection to a server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server
|
||||||
|
config: Server configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Active connection
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if server_name not in self._connections:
|
||||||
|
connection = MCPConnection(server_name, config)
|
||||||
|
await connection.connect()
|
||||||
|
self._connections[server_name] = connection
|
||||||
|
|
||||||
|
connection = self._connections[server_name]
|
||||||
|
|
||||||
|
# Reconnect if not connected
|
||||||
|
if not connection.is_connected:
|
||||||
|
await connection.connect()
|
||||||
|
|
||||||
|
return connection
|
||||||
|
|
||||||
|
async def release_connection(self, server_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Release a connection (currently just tracks usage).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server
|
||||||
|
"""
|
||||||
|
# For now, we keep connections alive
|
||||||
|
# Future: implement connection reaping for idle connections
|
||||||
|
|
||||||
|
async def close_connection(self, server_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Close and remove a connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if server_name in self._connections:
|
||||||
|
await self._connections[server_name].disconnect()
|
||||||
|
del self._connections[server_name]
|
||||||
|
|
||||||
|
async def close_all(self) -> None:
|
||||||
|
"""Close all connections in the pool."""
|
||||||
|
async with self._lock:
|
||||||
|
for connection in self._connections.values():
|
||||||
|
try:
|
||||||
|
await connection.disconnect()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Error closing connection: %s", e)
|
||||||
|
|
||||||
|
self._connections.clear()
|
||||||
|
logger.info("Closed all MCP connections")
|
||||||
|
|
||||||
|
async def health_check_all(self) -> dict[str, bool]:
|
||||||
|
"""
|
||||||
|
Perform health check on all connections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping server names to health status
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
for name, connection in self._connections.items():
|
||||||
|
results[name] = await connection.health_check()
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_status(self) -> dict[str, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get status of all connections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping server names to status info
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
name: {
|
||||||
|
"state": conn.state.value,
|
||||||
|
"is_connected": conn.is_connected,
|
||||||
|
"url": conn.config.url,
|
||||||
|
}
|
||||||
|
for name, conn in self._connections.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def connection(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
) -> AsyncGenerator[MCPConnection, None]:
|
||||||
|
"""
|
||||||
|
Context manager for getting a connection.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
async with pool.connection("server", config) as conn:
|
||||||
|
result = await conn.execute_request("POST", "/tool", data)
|
||||||
|
"""
|
||||||
|
conn = await self.get_connection(server_name, config)
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
await self.release_connection(server_name)
|
||||||
201
backend/app/services/mcp/exceptions.py
Normal file
201
backend/app/services/mcp/exceptions.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""
|
||||||
|
MCP Exception Classes
|
||||||
|
|
||||||
|
Custom exceptions for MCP client operations with detailed error context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class MCPError(Exception):
|
||||||
|
"""Base exception for all MCP-related errors."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
server_name: str | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
self.server_name = server_name
|
||||||
|
self.details = details or {}
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
parts = [self.message]
|
||||||
|
if self.server_name:
|
||||||
|
parts.append(f"server={self.server_name}")
|
||||||
|
if self.details:
|
||||||
|
parts.append(f"details={self.details}")
|
||||||
|
return " | ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConnectionError(MCPError):
|
||||||
|
"""Raised when connection to an MCP server fails."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
server_name: str | None = None,
|
||||||
|
url: str | None = None,
|
||||||
|
cause: Exception | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, server_name=server_name, details=details)
|
||||||
|
self.url = url
|
||||||
|
self.cause = cause
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.url:
|
||||||
|
base = f"{base} | url={self.url}"
|
||||||
|
if self.cause:
|
||||||
|
base = f"{base} | cause={type(self.cause).__name__}: {self.cause}"
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPTimeoutError(MCPError):
|
||||||
|
"""Raised when an MCP operation times out."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
server_name: str | None = None,
|
||||||
|
timeout_seconds: float | None = None,
|
||||||
|
operation: str | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, server_name=server_name, details=details)
|
||||||
|
self.timeout_seconds = timeout_seconds
|
||||||
|
self.operation = operation
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.timeout_seconds is not None:
|
||||||
|
base = f"{base} | timeout={self.timeout_seconds}s"
|
||||||
|
if self.operation:
|
||||||
|
base = f"{base} | operation={self.operation}"
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolError(MCPError):
|
||||||
|
"""Raised when a tool execution fails."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
server_name: str | None = None,
|
||||||
|
tool_name: str | None = None,
|
||||||
|
tool_args: dict[str, Any] | None = None,
|
||||||
|
error_code: str | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, server_name=server_name, details=details)
|
||||||
|
self.tool_name = tool_name
|
||||||
|
self.tool_args = tool_args
|
||||||
|
self.error_code = error_code
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.tool_name:
|
||||||
|
base = f"{base} | tool={self.tool_name}"
|
||||||
|
if self.error_code:
|
||||||
|
base = f"{base} | error_code={self.error_code}"
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerNotFoundError(MCPError):
|
||||||
|
"""Raised when a requested MCP server is not registered."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
*,
|
||||||
|
available_servers: list[str] | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
message = f"MCP server not found: {server_name}"
|
||||||
|
super().__init__(message, server_name=server_name, details=details)
|
||||||
|
self.available_servers = available_servers or []
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.available_servers:
|
||||||
|
base = f"{base} | available={self.available_servers}"
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolNotFoundError(MCPError):
|
||||||
|
"""Raised when a requested tool is not found on any server."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
*,
|
||||||
|
server_name: str | None = None,
|
||||||
|
available_tools: list[str] | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
message = f"Tool not found: {tool_name}"
|
||||||
|
super().__init__(message, server_name=server_name, details=details)
|
||||||
|
self.tool_name = tool_name
|
||||||
|
self.available_tools = available_tools or []
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.available_tools:
|
||||||
|
base = f"{base} | available_tools={self.available_tools[:5]}..."
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPCircuitOpenError(MCPError):
|
||||||
|
"""Raised when a circuit breaker is open (server temporarily unavailable)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
*,
|
||||||
|
failure_count: int | None = None,
|
||||||
|
reset_timeout: float | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
message = f"Circuit breaker open for server: {server_name}"
|
||||||
|
super().__init__(message, server_name=server_name, details=details)
|
||||||
|
self.failure_count = failure_count
|
||||||
|
self.reset_timeout = reset_timeout
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.failure_count is not None:
|
||||||
|
base = f"{base} | failures={self.failure_count}"
|
||||||
|
if self.reset_timeout is not None:
|
||||||
|
base = f"{base} | reset_in={self.reset_timeout}s"
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPValidationError(MCPError):
|
||||||
|
"""Raised when tool arguments fail validation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
tool_name: str | None = None,
|
||||||
|
field_errors: dict[str, str] | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, details=details)
|
||||||
|
self.tool_name = tool_name
|
||||||
|
self.field_errors = field_errors or {}
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.tool_name:
|
||||||
|
base = f"{base} | tool={self.tool_name}"
|
||||||
|
if self.field_errors:
|
||||||
|
base = f"{base} | fields={list(self.field_errors.keys())}"
|
||||||
|
return base
|
||||||
305
backend/app/services/mcp/registry.py
Normal file
305
backend/app/services/mcp/registry.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
"""
|
||||||
|
MCP Server Registry
|
||||||
|
|
||||||
|
Thread-safe singleton registry for managing MCP server configurations
|
||||||
|
and their capabilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
||||||
|
from .exceptions import MCPServerNotFoundError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ServerCapabilities:
|
||||||
|
"""Cached capabilities for an MCP server."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
resources: list[dict[str, Any]] | None = None,
|
||||||
|
prompts: list[dict[str, Any]] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.tools = tools or []
|
||||||
|
self.resources = resources or []
|
||||||
|
self.prompts = prompts or []
|
||||||
|
self._loaded = False
|
||||||
|
self._load_time: float | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_loaded(self) -> bool:
|
||||||
|
"""Check if capabilities have been loaded."""
|
||||||
|
return self._loaded
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tool_names(self) -> list[str]:
|
||||||
|
"""Get list of tool names."""
|
||||||
|
return [t.get("name", "") for t in self.tools if t.get("name")]
|
||||||
|
|
||||||
|
def mark_loaded(self) -> None:
|
||||||
|
"""Mark capabilities as loaded."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
self._loaded = True
|
||||||
|
self._load_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerRegistry:
|
||||||
|
"""
|
||||||
|
Thread-safe singleton registry for MCP servers.
|
||||||
|
|
||||||
|
Manages server configurations and caches their capabilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance: "MCPServerRegistry | None" = None
|
||||||
|
_lock = Lock()
|
||||||
|
|
||||||
|
def __new__(cls) -> "MCPServerRegistry":
|
||||||
|
"""Ensure singleton pattern."""
|
||||||
|
with cls._lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._initialized = False
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize registry (only runs once due to singleton)."""
|
||||||
|
if getattr(self, "_initialized", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
self._config: MCPConfig = MCPConfig()
|
||||||
|
self._capabilities: dict[str, ServerCapabilities] = {}
|
||||||
|
self._capabilities_lock = asyncio.Lock()
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
logger.info("MCP Server Registry initialized")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls) -> "MCPServerRegistry":
|
||||||
|
"""Get the singleton registry instance."""
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reset_instance(cls) -> None:
|
||||||
|
"""Reset the singleton (for testing)."""
|
||||||
|
with cls._lock:
|
||||||
|
cls._instance = None
|
||||||
|
|
||||||
|
def load_config(self, config: MCPConfig | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Load configuration into the registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Optional config to load. If None, loads from default path.
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = load_mcp_config()
|
||||||
|
|
||||||
|
self._config = config
|
||||||
|
self._capabilities.clear()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Loaded MCP configuration with %d servers",
|
||||||
|
len(config.mcp_servers),
|
||||||
|
)
|
||||||
|
for name in config.list_server_names():
|
||||||
|
logger.debug("Registered MCP server: %s", name)
|
||||||
|
|
||||||
|
def register(self, name: str, config: MCPServerConfig) -> None:
|
||||||
|
"""
|
||||||
|
Register a new MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Unique server name
|
||||||
|
config: Server configuration
|
||||||
|
"""
|
||||||
|
self._config.mcp_servers[name] = config
|
||||||
|
self._capabilities.pop(name, None) # Clear any cached capabilities
|
||||||
|
|
||||||
|
logger.info("Registered MCP server: %s at %s", name, config.url)
|
||||||
|
|
||||||
|
def unregister(self, name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Unregister an MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name to unregister
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if server was found and removed
|
||||||
|
"""
|
||||||
|
if name in self._config.mcp_servers:
|
||||||
|
del self._config.mcp_servers[name]
|
||||||
|
self._capabilities.pop(name, None)
|
||||||
|
logger.info("Unregistered MCP server: %s", name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get(self, name: str) -> MCPServerConfig:
|
||||||
|
"""
|
||||||
|
Get a server configuration by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPServerNotFoundError: If server is not registered
|
||||||
|
"""
|
||||||
|
config = self._config.get_server(name)
|
||||||
|
if config is None:
|
||||||
|
raise MCPServerNotFoundError(
|
||||||
|
server_name=name,
|
||||||
|
available_servers=self.list_servers(),
|
||||||
|
)
|
||||||
|
return config
|
||||||
|
|
||||||
|
def get_or_none(self, name: str) -> MCPServerConfig | None:
|
||||||
|
"""
|
||||||
|
Get a server configuration by name, or None if not found.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server configuration or None
|
||||||
|
"""
|
||||||
|
return self._config.get_server(name)
|
||||||
|
|
||||||
|
def list_servers(self) -> list[str]:
|
||||||
|
"""Get list of all registered server names."""
|
||||||
|
return self._config.list_server_names()
|
||||||
|
|
||||||
|
def list_enabled_servers(self) -> list[str]:
|
||||||
|
"""Get list of enabled server names."""
|
||||||
|
return list(self._config.get_enabled_servers().keys())
|
||||||
|
|
||||||
|
def get_all_configs(self) -> dict[str, MCPServerConfig]:
|
||||||
|
"""Get all server configurations."""
|
||||||
|
return dict(self._config.mcp_servers)
|
||||||
|
|
||||||
|
def get_enabled_configs(self) -> dict[str, MCPServerConfig]:
|
||||||
|
"""Get all enabled server configurations."""
|
||||||
|
return self._config.get_enabled_servers()
|
||||||
|
|
||||||
|
async def get_capabilities(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
force_refresh: bool = False,
|
||||||
|
) -> ServerCapabilities:
|
||||||
|
"""
|
||||||
|
Get capabilities for a server (lazy-loaded and cached).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name
|
||||||
|
force_refresh: If True, refresh cached capabilities
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server capabilities
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPServerNotFoundError: If server is not registered
|
||||||
|
"""
|
||||||
|
# Verify server exists
|
||||||
|
self.get(name)
|
||||||
|
|
||||||
|
async with self._capabilities_lock:
|
||||||
|
if name not in self._capabilities or force_refresh:
|
||||||
|
# Will be populated by connection manager when connecting
|
||||||
|
self._capabilities[name] = ServerCapabilities()
|
||||||
|
|
||||||
|
return self._capabilities[name]
|
||||||
|
|
||||||
|
def set_capabilities(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
resources: list[dict[str, Any]] | None = None,
|
||||||
|
prompts: list[dict[str, Any]] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Set capabilities for a server (called by connection manager).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name
|
||||||
|
tools: List of tool definitions
|
||||||
|
resources: List of resource definitions
|
||||||
|
prompts: List of prompt definitions
|
||||||
|
"""
|
||||||
|
capabilities = ServerCapabilities(
|
||||||
|
tools=tools,
|
||||||
|
resources=resources,
|
||||||
|
prompts=prompts,
|
||||||
|
)
|
||||||
|
capabilities.mark_loaded()
|
||||||
|
self._capabilities[name] = capabilities
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Updated capabilities for %s: %d tools, %d resources, %d prompts",
|
||||||
|
name,
|
||||||
|
len(capabilities.tools),
|
||||||
|
len(capabilities.resources),
|
||||||
|
len(capabilities.prompts),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_cached_capabilities(self, name: str) -> ServerCapabilities:
|
||||||
|
"""
|
||||||
|
Get cached capabilities without async loading.
|
||||||
|
|
||||||
|
Use this for synchronous access when you only need
|
||||||
|
cached values (e.g., for health check responses).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached capabilities or empty ServerCapabilities
|
||||||
|
"""
|
||||||
|
return self._capabilities.get(name, ServerCapabilities())
|
||||||
|
|
||||||
|
def find_server_for_tool(self, tool_name: str) -> str | None:
|
||||||
|
"""
|
||||||
|
Find which server provides a specific tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool to find
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server name or None if not found
|
||||||
|
"""
|
||||||
|
for name, caps in self._capabilities.items():
|
||||||
|
if tool_name in caps.tool_names:
|
||||||
|
return name
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_all_tools(self) -> dict[str, list[dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Get all tools from all servers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping server name to list of tool definitions
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
name: caps.tools
|
||||||
|
for name, caps in self._capabilities.items()
|
||||||
|
if caps.is_loaded
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_config(self) -> MCPConfig:
|
||||||
|
"""Get the global MCP configuration."""
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level convenience function
|
||||||
|
def get_registry() -> MCPServerRegistry:
|
||||||
|
"""Get the global MCP server registry instance."""
|
||||||
|
return MCPServerRegistry.get_instance()
|
||||||
619
backend/app/services/mcp/routing.py
Normal file
619
backend/app/services/mcp/routing.py
Normal file
@@ -0,0 +1,619 @@
|
|||||||
|
"""
|
||||||
|
MCP Tool Call Routing
|
||||||
|
|
||||||
|
Routes tool calls to appropriate servers with retry logic,
|
||||||
|
circuit breakers, and request/response serialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .config import MCPServerConfig
|
||||||
|
from .connection import ConnectionPool, MCPConnection
|
||||||
|
from .exceptions import (
|
||||||
|
MCPCircuitOpenError,
|
||||||
|
MCPError,
|
||||||
|
MCPTimeoutError,
|
||||||
|
MCPToolError,
|
||||||
|
MCPToolNotFoundError,
|
||||||
|
)
|
||||||
|
from .registry import MCPServerRegistry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitState(Enum):
|
||||||
|
"""Circuit breaker states."""
|
||||||
|
|
||||||
|
CLOSED = "closed"
|
||||||
|
OPEN = "open"
|
||||||
|
HALF_OPEN = "half-open"
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCircuitBreaker:
|
||||||
|
"""
|
||||||
|
Async-compatible circuit breaker implementation.
|
||||||
|
|
||||||
|
Unlike pybreaker which wraps sync functions, this implementation
|
||||||
|
provides explicit success/failure tracking for async code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fail_max: int = 5,
|
||||||
|
reset_timeout: float = 30.0,
|
||||||
|
name: str = "",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize circuit breaker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fail_max: Maximum failures before opening circuit
|
||||||
|
reset_timeout: Seconds to wait before trying again
|
||||||
|
name: Name for logging
|
||||||
|
"""
|
||||||
|
self.fail_max = fail_max
|
||||||
|
self.reset_timeout = reset_timeout
|
||||||
|
self.name = name
|
||||||
|
self._state = CircuitState.CLOSED
|
||||||
|
self._fail_counter = 0
|
||||||
|
self._last_failure_time: float | None = None
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_state(self) -> str:
|
||||||
|
"""Get current state as string."""
|
||||||
|
# Check if we should transition from OPEN to HALF_OPEN
|
||||||
|
if self._state == CircuitState.OPEN:
|
||||||
|
if self._should_try_reset():
|
||||||
|
return CircuitState.HALF_OPEN.value
|
||||||
|
return self._state.value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fail_counter(self) -> int:
|
||||||
|
"""Get current failure count."""
|
||||||
|
return self._fail_counter
|
||||||
|
|
||||||
|
def _should_try_reset(self) -> bool:
|
||||||
|
"""Check if enough time has passed to try resetting."""
|
||||||
|
if self._last_failure_time is None:
|
||||||
|
return True
|
||||||
|
return (time.time() - self._last_failure_time) >= self.reset_timeout
|
||||||
|
|
||||||
|
async def success(self) -> None:
|
||||||
|
"""Record a successful call."""
|
||||||
|
async with self._lock:
|
||||||
|
self._fail_counter = 0
|
||||||
|
self._state = CircuitState.CLOSED
|
||||||
|
self._last_failure_time = None
|
||||||
|
|
||||||
|
async def failure(self) -> None:
|
||||||
|
"""Record a failed call."""
|
||||||
|
async with self._lock:
|
||||||
|
self._fail_counter += 1
|
||||||
|
self._last_failure_time = time.time()
|
||||||
|
|
||||||
|
if self._fail_counter >= self.fail_max:
|
||||||
|
self._state = CircuitState.OPEN
|
||||||
|
logger.warning(
|
||||||
|
"Circuit breaker %s opened after %d failures",
|
||||||
|
self.name,
|
||||||
|
self._fail_counter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_open(self) -> bool:
|
||||||
|
"""Check if circuit is open (not allowing calls)."""
|
||||||
|
if self._state == CircuitState.OPEN:
|
||||||
|
return not self._should_try_reset()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def reset(self) -> None:
|
||||||
|
"""Manually reset the circuit breaker."""
|
||||||
|
async with self._lock:
|
||||||
|
self._state = CircuitState.CLOSED
|
||||||
|
self._fail_counter = 0
|
||||||
|
self._last_failure_time = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolInfo:
|
||||||
|
"""Information about an available tool."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
server_name: str | None = None
|
||||||
|
input_schema: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"server_name": self.server_name,
|
||||||
|
"input_schema": self.input_schema,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolResult:
|
||||||
|
"""Result of a tool execution."""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
data: Any = None
|
||||||
|
error: str | None = None
|
||||||
|
error_code: str | None = None
|
||||||
|
tool_name: str | None = None
|
||||||
|
server_name: str | None = None
|
||||||
|
execution_time_ms: float = 0.0
|
||||||
|
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"success": self.success,
|
||||||
|
"data": self.data,
|
||||||
|
"error": self.error,
|
||||||
|
"error_code": self.error_code,
|
||||||
|
"tool_name": self.tool_name,
|
||||||
|
"server_name": self.server_name,
|
||||||
|
"execution_time_ms": self.execution_time_ms,
|
||||||
|
"request_id": self.request_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRouter:
|
||||||
|
"""
|
||||||
|
Routes tool calls to the appropriate MCP server.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Tool name to server mapping
|
||||||
|
- Retry logic with exponential backoff
|
||||||
|
- Circuit breaker pattern for fault tolerance
|
||||||
|
- Request/response serialization
|
||||||
|
- Execution timing and metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
registry: MCPServerRegistry,
|
||||||
|
connection_pool: ConnectionPool,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the tool router.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: MCP server registry
|
||||||
|
connection_pool: Connection pool for servers
|
||||||
|
"""
|
||||||
|
self._registry = registry
|
||||||
|
self._pool = connection_pool
|
||||||
|
self._circuit_breakers: dict[str, AsyncCircuitBreaker] = {}
|
||||||
|
self._tool_to_server: dict[str, str] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def _get_circuit_breaker(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
) -> AsyncCircuitBreaker:
|
||||||
|
"""Get or create a circuit breaker for a server."""
|
||||||
|
if server_name not in self._circuit_breakers:
|
||||||
|
self._circuit_breakers[server_name] = AsyncCircuitBreaker(
|
||||||
|
fail_max=config.circuit_breaker_threshold,
|
||||||
|
reset_timeout=config.circuit_breaker_timeout,
|
||||||
|
name=f"mcp-{server_name}",
|
||||||
|
)
|
||||||
|
return self._circuit_breakers[server_name]
|
||||||
|
|
||||||
|
async def register_tool_mapping(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
server_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Register a mapping from tool name to server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool
|
||||||
|
server_name: Name of the server providing the tool
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
self._tool_to_server[tool_name] = server_name
|
||||||
|
logger.debug("Registered tool %s -> server %s", tool_name, server_name)
|
||||||
|
|
||||||
|
async def discover_tools(self) -> None:
|
||||||
|
"""
|
||||||
|
Discover all tools from registered servers and build mappings.
|
||||||
|
"""
|
||||||
|
for server_name in self._registry.list_enabled_servers():
|
||||||
|
try:
|
||||||
|
config = self._registry.get(server_name)
|
||||||
|
connection = await self._pool.get_connection(server_name, config)
|
||||||
|
|
||||||
|
# Fetch tools from server
|
||||||
|
tools = await self._fetch_tools_from_server(connection)
|
||||||
|
|
||||||
|
# Update registry with capabilities
|
||||||
|
self._registry.set_capabilities(
|
||||||
|
server_name,
|
||||||
|
tools=[t.to_dict() for t in tools],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update tool mappings
|
||||||
|
for tool in tools:
|
||||||
|
await self.register_tool_mapping(tool.name, server_name)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Discovered %d tools from server %s",
|
||||||
|
len(tools),
|
||||||
|
server_name,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to discover tools from %s: %s",
|
||||||
|
server_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _fetch_tools_from_server(
|
||||||
|
self,
|
||||||
|
connection: MCPConnection,
|
||||||
|
) -> list[ToolInfo]:
|
||||||
|
"""Fetch available tools from an MCP server."""
|
||||||
|
try:
|
||||||
|
response = await connection.execute_request(
|
||||||
|
"GET",
|
||||||
|
"/mcp/tools",
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = []
|
||||||
|
for tool_data in response.get("tools", []):
|
||||||
|
tools.append(
|
||||||
|
ToolInfo(
|
||||||
|
name=tool_data.get("name", ""),
|
||||||
|
description=tool_data.get("description"),
|
||||||
|
server_name=connection.server_name,
|
||||||
|
input_schema=tool_data.get("inputSchema"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Error fetching tools from %s: %s",
|
||||||
|
connection.server_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def find_server_for_tool(self, tool_name: str) -> str | None:
|
||||||
|
"""
|
||||||
|
Find which server provides a specific tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server name or None if not found
|
||||||
|
"""
|
||||||
|
return self._tool_to_server.get(tool_name)
|
||||||
|
|
||||||
|
async def call_tool(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any] | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> ToolResult:
|
||||||
|
"""
|
||||||
|
Call a tool on a specific server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the MCP server
|
||||||
|
tool_name: Name of the tool to call
|
||||||
|
arguments: Tool arguments
|
||||||
|
timeout: Optional timeout override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool execution result
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Tool call [%s]: %s.%s with args %s",
|
||||||
|
request_id,
|
||||||
|
server_name,
|
||||||
|
tool_name,
|
||||||
|
arguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = self._registry.get(server_name)
|
||||||
|
circuit_breaker = self._get_circuit_breaker(server_name, config)
|
||||||
|
|
||||||
|
# Check circuit breaker state
|
||||||
|
if circuit_breaker.is_open():
|
||||||
|
raise MCPCircuitOpenError(
|
||||||
|
server_name=server_name,
|
||||||
|
failure_count=circuit_breaker.fail_counter,
|
||||||
|
reset_timeout=config.circuit_breaker_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute with retry logic
|
||||||
|
result = await self._execute_with_retry(
|
||||||
|
server_name=server_name,
|
||||||
|
config=config,
|
||||||
|
tool_name=tool_name,
|
||||||
|
arguments=arguments or {},
|
||||||
|
timeout=timeout,
|
||||||
|
circuit_breaker=circuit_breaker,
|
||||||
|
)
|
||||||
|
|
||||||
|
execution_time = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data=result,
|
||||||
|
tool_name=tool_name,
|
||||||
|
server_name=server_name,
|
||||||
|
execution_time_ms=execution_time,
|
||||||
|
request_id=request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except MCPCircuitOpenError:
|
||||||
|
raise
|
||||||
|
except MCPError as e:
|
||||||
|
execution_time = (time.time() - start_time) * 1000
|
||||||
|
logger.error(
|
||||||
|
"Tool call failed [%s]: %s.%s - %s",
|
||||||
|
request_id,
|
||||||
|
server_name,
|
||||||
|
tool_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
error_code=type(e).__name__,
|
||||||
|
tool_name=tool_name,
|
||||||
|
server_name=server_name,
|
||||||
|
execution_time_ms=execution_time,
|
||||||
|
request_id=request_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
execution_time = (time.time() - start_time) * 1000
|
||||||
|
logger.exception(
|
||||||
|
"Unexpected error in tool call [%s]: %s.%s",
|
||||||
|
request_id,
|
||||||
|
server_name,
|
||||||
|
tool_name,
|
||||||
|
)
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
error_code="UnexpectedError",
|
||||||
|
tool_name=tool_name,
|
||||||
|
server_name=server_name,
|
||||||
|
execution_time_ms=execution_time,
|
||||||
|
request_id=request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _execute_with_retry(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any],
|
||||||
|
timeout: float | None,
|
||||||
|
circuit_breaker: AsyncCircuitBreaker,
|
||||||
|
) -> Any:
|
||||||
|
"""Execute tool call with retry logic."""
|
||||||
|
last_error: Exception | None = None
|
||||||
|
attempts = 0
|
||||||
|
max_attempts = config.retry_attempts + 1 # +1 for initial attempt
|
||||||
|
|
||||||
|
while attempts < max_attempts:
|
||||||
|
attempts += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use circuit breaker to track failures
|
||||||
|
result = await self._execute_tool_call(
|
||||||
|
server_name=server_name,
|
||||||
|
config=config,
|
||||||
|
tool_name=tool_name,
|
||||||
|
arguments=arguments,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Success - record it
|
||||||
|
await circuit_breaker.success()
|
||||||
|
return result
|
||||||
|
|
||||||
|
except MCPCircuitOpenError:
|
||||||
|
raise
|
||||||
|
except MCPTimeoutError:
|
||||||
|
# Timeout - don't retry
|
||||||
|
await circuit_breaker.failure()
|
||||||
|
raise
|
||||||
|
except MCPToolError:
|
||||||
|
# Tool-level error - don't retry (user error)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
await circuit_breaker.failure()
|
||||||
|
|
||||||
|
if attempts < max_attempts:
|
||||||
|
delay = self._calculate_retry_delay(attempts, config)
|
||||||
|
logger.warning(
|
||||||
|
"Tool call attempt %d/%d failed for %s.%s: %s. "
|
||||||
|
"Retrying in %.1fs",
|
||||||
|
attempts,
|
||||||
|
max_attempts,
|
||||||
|
server_name,
|
||||||
|
tool_name,
|
||||||
|
e,
|
||||||
|
delay,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
# All attempts failed
|
||||||
|
raise MCPToolError(
|
||||||
|
f"Tool call failed after {max_attempts} attempts",
|
||||||
|
server_name=server_name,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_args=arguments,
|
||||||
|
details={"last_error": str(last_error)},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _calculate_retry_delay(
|
||||||
|
self,
|
||||||
|
attempt: int,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
) -> float:
|
||||||
|
"""Calculate exponential backoff delay with jitter."""
|
||||||
|
import random
|
||||||
|
|
||||||
|
delay = config.retry_delay * (2 ** (attempt - 1))
|
||||||
|
delay = min(delay, config.retry_max_delay)
|
||||||
|
# Add jitter (±25%)
|
||||||
|
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||||
|
return max(0.1, delay + jitter)
|
||||||
|
|
||||||
|
async def _execute_tool_call(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any],
|
||||||
|
timeout: float | None,
|
||||||
|
) -> Any:
|
||||||
|
"""Execute a single tool call."""
|
||||||
|
connection = await self._pool.get_connection(server_name, config)
|
||||||
|
|
||||||
|
# Build MCP tool call request
|
||||||
|
request_body = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": tool_name,
|
||||||
|
"arguments": arguments,
|
||||||
|
},
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await connection.execute_request(
|
||||||
|
method="POST",
|
||||||
|
path="/mcp",
|
||||||
|
data=request_body,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle JSON-RPC response
|
||||||
|
if "error" in response:
|
||||||
|
error = response["error"]
|
||||||
|
raise MCPToolError(
|
||||||
|
error.get("message", "Tool execution failed"),
|
||||||
|
server_name=server_name,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_args=arguments,
|
||||||
|
error_code=str(error.get("code", "UNKNOWN")),
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.get("result")
|
||||||
|
|
||||||
|
async def route_tool(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any] | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> ToolResult:
|
||||||
|
"""
|
||||||
|
Route a tool call to the appropriate server.
|
||||||
|
|
||||||
|
Automatically discovers which server provides the tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool to call
|
||||||
|
arguments: Tool arguments
|
||||||
|
timeout: Optional timeout override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool execution result
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPToolNotFoundError: If no server provides the tool
|
||||||
|
"""
|
||||||
|
server_name = self.find_server_for_tool(tool_name)
|
||||||
|
|
||||||
|
if server_name is None:
|
||||||
|
# Try to find from registry
|
||||||
|
server_name = self._registry.find_server_for_tool(tool_name)
|
||||||
|
|
||||||
|
if server_name is None:
|
||||||
|
raise MCPToolNotFoundError(
|
||||||
|
tool_name=tool_name,
|
||||||
|
available_tools=list(self._tool_to_server.keys()),
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self.call_tool(
|
||||||
|
server_name=server_name,
|
||||||
|
tool_name=tool_name,
|
||||||
|
arguments=arguments,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_all_tools(self) -> list[ToolInfo]:
|
||||||
|
"""
|
||||||
|
Get all available tools from all servers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool information
|
||||||
|
"""
|
||||||
|
tools = []
|
||||||
|
all_server_tools = self._registry.get_all_tools()
|
||||||
|
|
||||||
|
for server_name, server_tools in all_server_tools.items():
|
||||||
|
for tool_data in server_tools:
|
||||||
|
tools.append(
|
||||||
|
ToolInfo(
|
||||||
|
name=tool_data.get("name", ""),
|
||||||
|
description=tool_data.get("description"),
|
||||||
|
server_name=server_name,
|
||||||
|
input_schema=tool_data.get("input_schema"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Get status of all circuit breakers."""
|
||||||
|
return {
|
||||||
|
name: {
|
||||||
|
"state": cb.current_state,
|
||||||
|
"failure_count": cb.fail_counter,
|
||||||
|
}
|
||||||
|
for name, cb in self._circuit_breakers.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
async def reset_circuit_breaker(self, server_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Manually reset a circuit breaker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if circuit breaker was reset
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if server_name in self._circuit_breakers:
|
||||||
|
# Reset by removing (will be recreated on next call)
|
||||||
|
del self._circuit_breakers[server_name]
|
||||||
|
logger.info("Reset circuit breaker for %s", server_name)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
@@ -343,7 +343,9 @@ class OAuthService:
|
|||||||
await oauth_account.update_tokens(
|
await oauth_account.update_tokens(
|
||||||
db,
|
db,
|
||||||
account=existing_oauth,
|
account=existing_oauth,
|
||||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
access_token_encrypted=token.get("access_token"),
|
||||||
|
refresh_token_encrypted=token.get("refresh_token"),
|
||||||
|
token_expires_at=datetime.now(UTC)
|
||||||
+ timedelta(seconds=token.get("expires_in", 3600)),
|
+ timedelta(seconds=token.get("expires_in", 3600)),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -375,7 +377,9 @@ class OAuthService:
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
provider_user_id=provider_user_id,
|
provider_user_id=provider_user_id,
|
||||||
provider_email=provider_email,
|
provider_email=provider_email,
|
||||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
access_token_encrypted=token.get("access_token"),
|
||||||
|
refresh_token_encrypted=token.get("refresh_token"),
|
||||||
|
token_expires_at=datetime.now(UTC)
|
||||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||||
if token.get("expires_in")
|
if token.get("expires_in")
|
||||||
else None,
|
else None,
|
||||||
@@ -644,7 +648,9 @@ class OAuthService:
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
provider_user_id=provider_user_id,
|
provider_user_id=provider_user_id,
|
||||||
provider_email=email,
|
provider_email=email,
|
||||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
access_token_encrypted=token.get("access_token"),
|
||||||
|
refresh_token_encrypted=token.get("refresh_token"),
|
||||||
|
token_expires_at=datetime.now(UTC)
|
||||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||||
if token.get("expires_in")
|
if token.get("expires_in")
|
||||||
else None,
|
else None,
|
||||||
|
|||||||
23
backend/app/tasks/__init__.py
Normal file
23
backend/app/tasks/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# app/tasks/__init__.py
|
||||||
|
"""
|
||||||
|
Celery background tasks for Syndarix.
|
||||||
|
|
||||||
|
This package contains all Celery tasks organized by domain:
|
||||||
|
|
||||||
|
Modules:
|
||||||
|
agent: Agent execution tasks (run_agent_step, spawn_agent, terminate_agent)
|
||||||
|
git: Git operation tasks (clone, commit, branch, push, PR)
|
||||||
|
sync: Issue synchronization tasks (incremental/full sync, webhooks)
|
||||||
|
workflow: Workflow state management tasks
|
||||||
|
cost: Cost tracking and budget monitoring tasks
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.tasks import agent, cost, git, sync, workflow
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"agent",
|
||||||
|
"cost",
|
||||||
|
"git",
|
||||||
|
"sync",
|
||||||
|
"workflow",
|
||||||
|
]
|
||||||
146
backend/app/tasks/agent.py
Normal file
146
backend/app/tasks/agent.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
# app/tasks/agent.py
|
||||||
|
"""
|
||||||
|
Agent execution tasks for Syndarix.
|
||||||
|
|
||||||
|
These tasks handle the lifecycle of AI agent instances:
|
||||||
|
- Spawning new agent instances from agent types
|
||||||
|
- Executing agent steps (LLM calls, tool execution)
|
||||||
|
- Terminating agent instances
|
||||||
|
|
||||||
|
Tasks are routed to the 'agent' queue for dedicated processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.agent.run_agent_step")
|
||||||
|
def run_agent_step(
|
||||||
|
self,
|
||||||
|
agent_instance_id: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Execute a single step of an agent's workflow.
|
||||||
|
|
||||||
|
This task performs one iteration of the agent loop:
|
||||||
|
1. Load agent instance state
|
||||||
|
2. Call LLM with context and available tools
|
||||||
|
3. Execute tool calls if any
|
||||||
|
4. Update agent state
|
||||||
|
5. Return result for next step or completion
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_instance_id: UUID of the agent instance
|
||||||
|
context: Current execution context including:
|
||||||
|
- messages: Conversation history
|
||||||
|
- tools: Available tool definitions
|
||||||
|
- state: Agent state data
|
||||||
|
- metadata: Project/task metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and agent_instance_id
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
f"Running agent step for instance {agent_instance_id} with context keys: {list(context.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Implement actual agent step execution
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading agent instance from database
|
||||||
|
# 2. Calling LLM provider (via litellm or anthropic SDK)
|
||||||
|
# 3. Processing tool calls through MCP servers
|
||||||
|
# 4. Updating agent state in database
|
||||||
|
# 5. Scheduling next step if needed
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"agent_instance_id": agent_instance_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.agent.spawn_agent")
|
||||||
|
def spawn_agent(
|
||||||
|
self,
|
||||||
|
agent_type_id: str,
|
||||||
|
project_id: str,
|
||||||
|
initial_context: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Spawn a new agent instance from an agent type.
|
||||||
|
|
||||||
|
This task creates a new agent instance:
|
||||||
|
1. Load agent type configuration (model, expertise, personality)
|
||||||
|
2. Create agent instance record in database
|
||||||
|
3. Initialize agent state with project context
|
||||||
|
4. Start first agent step
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_type_id: UUID of the agent type template
|
||||||
|
project_id: UUID of the project this agent will work on
|
||||||
|
initial_context: Starting context including:
|
||||||
|
- goal: High-level objective
|
||||||
|
- constraints: Any limitations or requirements
|
||||||
|
- assigned_issues: Issues to work on
|
||||||
|
- autonomy_level: FULL_CONTROL, MILESTONE, or AUTONOMOUS
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status, agent_type_id, and project_id
|
||||||
|
"""
|
||||||
|
logger.info(f"Spawning agent of type {agent_type_id} for project {project_id}")
|
||||||
|
|
||||||
|
# TODO: Implement agent spawning
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading agent type from database
|
||||||
|
# 2. Creating agent instance record
|
||||||
|
# 3. Setting up MCP tool access
|
||||||
|
# 4. Initializing agent state
|
||||||
|
# 5. Kicking off first step
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "spawned",
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"project_id": project_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.agent.terminate_agent")
|
||||||
|
def terminate_agent(
|
||||||
|
self,
|
||||||
|
agent_instance_id: str,
|
||||||
|
reason: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Terminate an agent instance.
|
||||||
|
|
||||||
|
This task gracefully shuts down an agent:
|
||||||
|
1. Mark agent instance as terminated
|
||||||
|
2. Save final state for audit
|
||||||
|
3. Release any held resources
|
||||||
|
4. Notify relevant subscribers
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_instance_id: UUID of the agent instance
|
||||||
|
reason: Reason for termination (completion, error, manual, budget)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and agent_instance_id
|
||||||
|
"""
|
||||||
|
logger.info(f"Terminating agent instance {agent_instance_id} with reason: {reason}")
|
||||||
|
|
||||||
|
# TODO: Implement agent termination
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading agent instance
|
||||||
|
# 2. Updating status to terminated
|
||||||
|
# 3. Saving termination reason
|
||||||
|
# 4. Cleaning up any pending tasks
|
||||||
|
# 5. Sending termination event
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "terminated",
|
||||||
|
"agent_instance_id": agent_instance_id,
|
||||||
|
}
|
||||||
201
backend/app/tasks/cost.py
Normal file
201
backend/app/tasks/cost.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
# app/tasks/cost.py
|
||||||
|
"""
|
||||||
|
Cost tracking and budget management tasks for Syndarix.
|
||||||
|
|
||||||
|
These tasks implement multi-layered cost tracking per ADR-012:
|
||||||
|
- Per-agent token usage tracking
|
||||||
|
- Project budget monitoring
|
||||||
|
- Daily cost aggregation
|
||||||
|
- Budget threshold alerts
|
||||||
|
- Cost reporting
|
||||||
|
|
||||||
|
Costs are tracked in real-time in Redis for speed,
|
||||||
|
then aggregated to PostgreSQL for durability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.cost.aggregate_daily_costs")
|
||||||
|
def aggregate_daily_costs(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Aggregate daily costs from Redis to PostgreSQL.
|
||||||
|
|
||||||
|
This periodic task (runs daily):
|
||||||
|
1. Read accumulated costs from Redis
|
||||||
|
2. Aggregate by project, agent, and model
|
||||||
|
3. Store in PostgreSQL cost_records table
|
||||||
|
4. Clear Redis counters for new day
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status
|
||||||
|
"""
|
||||||
|
logger.info("Starting daily cost aggregation")
|
||||||
|
|
||||||
|
# TODO: Implement cost aggregation
|
||||||
|
# This will involve:
|
||||||
|
# 1. Fetching cost data from Redis
|
||||||
|
# 2. Grouping by project_id, agent_id, model
|
||||||
|
# 3. Inserting into PostgreSQL cost tables
|
||||||
|
# 4. Resetting Redis counters
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.cost.check_budget_thresholds")
|
||||||
|
def check_budget_thresholds(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Check if a project has exceeded budget thresholds.
|
||||||
|
|
||||||
|
This task checks budget limits:
|
||||||
|
1. Get current spend from Redis counters
|
||||||
|
2. Compare against project budget limits
|
||||||
|
3. Send alerts if thresholds exceeded
|
||||||
|
4. Pause agents if hard limit reached
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: UUID of the project
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and project_id
|
||||||
|
"""
|
||||||
|
logger.info(f"Checking budget thresholds for project {project_id}")
|
||||||
|
|
||||||
|
# TODO: Implement budget checking
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading project budget configuration
|
||||||
|
# 2. Getting current spend from Redis
|
||||||
|
# 3. Comparing against soft/hard limits
|
||||||
|
# 4. Sending alerts or pausing agents
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"project_id": project_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.cost.record_llm_usage")
|
||||||
|
def record_llm_usage(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
project_id: str,
|
||||||
|
model: str,
|
||||||
|
prompt_tokens: int,
|
||||||
|
completion_tokens: int,
|
||||||
|
cost_usd: float,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Record LLM usage from an agent call.
|
||||||
|
|
||||||
|
This task tracks each LLM API call:
|
||||||
|
1. Increment Redis counters for real-time tracking
|
||||||
|
2. Store raw usage event for audit
|
||||||
|
3. Trigger budget check if threshold approaching
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: UUID of the agent instance
|
||||||
|
project_id: UUID of the project
|
||||||
|
model: Model identifier (e.g., claude-opus-4-5-20251101)
|
||||||
|
prompt_tokens: Number of input tokens
|
||||||
|
completion_tokens: Number of output tokens
|
||||||
|
cost_usd: Calculated cost in USD
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status, agent_id, project_id, and cost_usd
|
||||||
|
"""
|
||||||
|
logger.debug(
|
||||||
|
f"Recording LLM usage for model {model}: "
|
||||||
|
f"{prompt_tokens} prompt + {completion_tokens} completion tokens = ${cost_usd}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Implement usage recording
|
||||||
|
# This will involve:
|
||||||
|
# 1. Incrementing Redis counters
|
||||||
|
# 2. Storing usage event
|
||||||
|
# 3. Checking if near budget threshold
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"project_id": project_id,
|
||||||
|
"cost_usd": cost_usd,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.cost.generate_cost_report")
|
||||||
|
def generate_cost_report(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate a cost report for a project.
|
||||||
|
|
||||||
|
This task creates a detailed cost breakdown:
|
||||||
|
1. Query cost records for date range
|
||||||
|
2. Group by agent, model, and day
|
||||||
|
3. Calculate totals and trends
|
||||||
|
4. Format report for display
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: UUID of the project
|
||||||
|
start_date: Report start date (YYYY-MM-DD)
|
||||||
|
end_date: Report end date (YYYY-MM-DD)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status, project_id, and date range
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
f"Generating cost report for project {project_id} from {start_date} to {end_date}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Implement report generation
|
||||||
|
# This will involve:
|
||||||
|
# 1. Querying PostgreSQL for cost records
|
||||||
|
# 2. Aggregating by various dimensions
|
||||||
|
# 3. Calculating totals and averages
|
||||||
|
# 4. Formatting report data
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"project_id": project_id,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.cost.reset_daily_budget_counters")
|
||||||
|
def reset_daily_budget_counters(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Reset daily budget counters in Redis.
|
||||||
|
|
||||||
|
This periodic task (runs daily at midnight UTC):
|
||||||
|
1. Archive current day's counters
|
||||||
|
2. Reset all daily budget counters
|
||||||
|
3. Prepare for new day's tracking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status
|
||||||
|
"""
|
||||||
|
logger.info("Resetting daily budget counters")
|
||||||
|
|
||||||
|
# TODO: Implement counter reset
|
||||||
|
# This will involve:
|
||||||
|
# 1. Getting all daily counter keys from Redis
|
||||||
|
# 2. Archiving current values
|
||||||
|
# 3. Resetting counters to zero
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
}
|
||||||
221
backend/app/tasks/git.py
Normal file
221
backend/app/tasks/git.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
# app/tasks/git.py
|
||||||
|
"""
|
||||||
|
Git operation tasks for Syndarix.
|
||||||
|
|
||||||
|
These tasks handle Git operations for projects:
|
||||||
|
- Cloning repositories
|
||||||
|
- Creating branches
|
||||||
|
- Committing changes
|
||||||
|
- Pushing to remotes
|
||||||
|
- Creating pull requests
|
||||||
|
|
||||||
|
Tasks are routed to the 'git' queue for dedicated processing.
|
||||||
|
All operations are scoped by project_id for multi-tenancy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.git.clone_repository")
|
||||||
|
def clone_repository(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
repo_url: str,
|
||||||
|
branch: str = "main",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Clone a repository for a project.
|
||||||
|
|
||||||
|
This task clones a Git repository to the project workspace:
|
||||||
|
1. Prepare workspace directory
|
||||||
|
2. Clone repository with credentials
|
||||||
|
3. Checkout specified branch
|
||||||
|
4. Update project metadata
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: UUID of the project
|
||||||
|
repo_url: Git repository URL (HTTPS or SSH)
|
||||||
|
branch: Branch to checkout (default: main)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and project_id
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
f"Cloning repository {repo_url} for project {project_id} on branch {branch}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Implement repository cloning
|
||||||
|
# This will involve:
|
||||||
|
# 1. Getting project credentials from secrets store
|
||||||
|
# 2. Creating workspace directory
|
||||||
|
# 3. Running git clone with proper auth
|
||||||
|
# 4. Checking out the target branch
|
||||||
|
# 5. Updating project record with clone status
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"project_id": project_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.git.commit_changes")
|
||||||
|
def commit_changes(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
message: str,
|
||||||
|
files: list[str] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Commit changes in a project repository.
|
||||||
|
|
||||||
|
This task creates a Git commit:
|
||||||
|
1. Stage specified files (or all if None)
|
||||||
|
2. Create commit with message
|
||||||
|
3. Update commit history record
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: UUID of the project
|
||||||
|
message: Commit message (follows conventional commits)
|
||||||
|
files: List of files to stage, or None for all staged
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and project_id
|
||||||
|
"""
|
||||||
|
logger.info(f"Committing changes for project {project_id}: {message}")
|
||||||
|
|
||||||
|
# TODO: Implement commit operation
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading project workspace path
|
||||||
|
# 2. Running git add for specified files
|
||||||
|
# 3. Running git commit with message
|
||||||
|
# 4. Recording commit hash in database
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"project_id": project_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.git.create_branch")
|
||||||
|
def create_branch(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
branch_name: str,
|
||||||
|
from_ref: str = "HEAD",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create a new branch in a project repository.
|
||||||
|
|
||||||
|
This task creates a Git branch:
|
||||||
|
1. Checkout from reference
|
||||||
|
2. Create new branch
|
||||||
|
3. Update branch tracking
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: UUID of the project
|
||||||
|
branch_name: Name of the new branch (e.g., feature/123-description)
|
||||||
|
from_ref: Reference to branch from (default: HEAD)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and project_id
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
f"Creating branch {branch_name} from {from_ref} for project {project_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Implement branch creation
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading project workspace
|
||||||
|
# 2. Running git checkout -b from_ref
|
||||||
|
# 3. Recording branch in database
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"project_id": project_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.git.create_pull_request")
|
||||||
|
def create_pull_request(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
title: str,
|
||||||
|
body: str,
|
||||||
|
head_branch: str,
|
||||||
|
base_branch: str = "main",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create a pull request for a project.
|
||||||
|
|
||||||
|
This task creates a PR on the external Git provider:
|
||||||
|
1. Push branch if needed
|
||||||
|
2. Create PR via API (Gitea, GitHub, GitLab)
|
||||||
|
3. Store PR reference
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: UUID of the project
|
||||||
|
title: PR title
|
||||||
|
body: PR description (markdown)
|
||||||
|
head_branch: Branch with changes
|
||||||
|
base_branch: Target branch (default: main)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and project_id
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
f"Creating PR '{title}' from {head_branch} to {base_branch} for project {project_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Implement PR creation
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading project and Git provider config
|
||||||
|
# 2. Ensuring head_branch is pushed
|
||||||
|
# 3. Calling provider API to create PR
|
||||||
|
# 4. Storing PR URL and number
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"project_id": project_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.git.push_changes")
|
||||||
|
def push_changes(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
branch: str,
|
||||||
|
force: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Push changes to remote repository.
|
||||||
|
|
||||||
|
This task pushes commits to the remote:
|
||||||
|
1. Verify authentication
|
||||||
|
2. Push branch to remote
|
||||||
|
3. Handle push failures
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: UUID of the project
|
||||||
|
branch: Branch to push
|
||||||
|
force: Whether to force push (use with caution)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and project_id
|
||||||
|
"""
|
||||||
|
logger.info(f"Pushing branch {branch} for project {project_id} (force={force})")
|
||||||
|
|
||||||
|
# TODO: Implement push operation
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading project credentials
|
||||||
|
# 2. Running git push (with --force if specified)
|
||||||
|
# 3. Handling authentication and conflicts
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"project_id": project_id,
|
||||||
|
}
|
||||||
194
backend/app/tasks/sync.py
Normal file
194
backend/app/tasks/sync.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
# app/tasks/sync.py
|
||||||
|
"""
|
||||||
|
Issue synchronization tasks for Syndarix.
|
||||||
|
|
||||||
|
These tasks handle bidirectional issue synchronization:
|
||||||
|
- Incremental sync (polling for recent changes)
|
||||||
|
- Full reconciliation (daily comprehensive sync)
|
||||||
|
- Webhook event processing
|
||||||
|
- Pushing local changes to external trackers
|
||||||
|
|
||||||
|
Tasks are routed to the 'sync' queue for dedicated processing.
|
||||||
|
Per ADR-011, sync follows a master/replica model with configurable direction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.sync.sync_issues_incremental")
|
||||||
|
def sync_issues_incremental(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Perform incremental issue synchronization across all projects.
|
||||||
|
|
||||||
|
This periodic task (runs every 5 minutes):
|
||||||
|
1. Query each project's external tracker for recent changes
|
||||||
|
2. Compare with local issue cache
|
||||||
|
3. Apply updates to local database
|
||||||
|
4. Handle conflicts based on sync direction config
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and type
|
||||||
|
"""
|
||||||
|
logger.info("Starting incremental issue sync across all projects")
|
||||||
|
|
||||||
|
# TODO: Implement incremental sync
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading all active projects with sync enabled
|
||||||
|
# 2. For each project, querying external tracker since last_sync_at
|
||||||
|
# 3. Upserting issues into local database
|
||||||
|
# 4. Updating last_sync_at timestamp
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"type": "incremental",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.sync.sync_issues_full")
|
||||||
|
def sync_issues_full(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Perform full issue reconciliation across all projects.
|
||||||
|
|
||||||
|
This periodic task (runs daily):
|
||||||
|
1. Fetch all issues from external trackers
|
||||||
|
2. Compare with local database
|
||||||
|
3. Handle orphaned issues
|
||||||
|
4. Resolve any drift between systems
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and type
|
||||||
|
"""
|
||||||
|
logger.info("Starting full issue reconciliation across all projects")
|
||||||
|
|
||||||
|
# TODO: Implement full sync
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading all active projects
|
||||||
|
# 2. Fetching complete issue lists from external trackers
|
||||||
|
# 3. Comparing with local database
|
||||||
|
# 4. Handling deletes and orphans
|
||||||
|
# 5. Resolving conflicts based on sync config
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"type": "full",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.sync.process_webhook_event")
|
||||||
|
def process_webhook_event(
|
||||||
|
self,
|
||||||
|
provider: str,
|
||||||
|
event_type: str,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Process a webhook event from an external Git provider.
|
||||||
|
|
||||||
|
This task handles real-time updates from:
|
||||||
|
- Gitea: issue.created, issue.updated, pull_request.*, etc.
|
||||||
|
- GitHub: issues, pull_request, push, etc.
|
||||||
|
- GitLab: issue events, merge request events, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Git provider name (gitea, github, gitlab)
|
||||||
|
event_type: Event type from provider
|
||||||
|
payload: Raw webhook payload
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status, provider, and event_type
|
||||||
|
"""
|
||||||
|
logger.info(f"Processing webhook event from {provider}: {event_type}")
|
||||||
|
|
||||||
|
# TODO: Implement webhook processing
|
||||||
|
# This will involve:
|
||||||
|
# 1. Validating webhook signature
|
||||||
|
# 2. Parsing provider-specific payload
|
||||||
|
# 3. Mapping to internal event format
|
||||||
|
# 4. Updating local database
|
||||||
|
# 5. Triggering any dependent workflows
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"provider": provider,
|
||||||
|
"event_type": event_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.sync.sync_project_issues")
|
||||||
|
def sync_project_issues(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
full: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Synchronize issues for a specific project.
|
||||||
|
|
||||||
|
This task can be triggered manually or by webhooks:
|
||||||
|
1. Connect to project's external tracker
|
||||||
|
2. Fetch issues (incremental or full)
|
||||||
|
3. Update local database
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: UUID of the project
|
||||||
|
full: Whether to do full sync or incremental
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and project_id
|
||||||
|
"""
|
||||||
|
logger.info(f"Syncing issues for project {project_id} (full={full})")
|
||||||
|
|
||||||
|
# TODO: Implement project-specific sync
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading project configuration
|
||||||
|
# 2. Connecting to external tracker
|
||||||
|
# 3. Fetching issues based on full flag
|
||||||
|
# 4. Upserting to database
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"project_id": project_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.sync.push_issue_to_external")
|
||||||
|
def push_issue_to_external(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
issue_id: str,
|
||||||
|
operation: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Push a local issue change to the external tracker.
|
||||||
|
|
||||||
|
This task handles outbound sync when Syndarix is the master:
|
||||||
|
- create: Create new issue in external tracker
|
||||||
|
- update: Update existing issue
|
||||||
|
- close: Close issue in external tracker
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: UUID of the project
|
||||||
|
issue_id: UUID of the local issue
|
||||||
|
operation: Operation type (create, update, close)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status, issue_id, and operation
|
||||||
|
"""
|
||||||
|
logger.info(f"Pushing {operation} for issue {issue_id} in project {project_id}")
|
||||||
|
|
||||||
|
# TODO: Implement outbound sync
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading issue and project config
|
||||||
|
# 2. Mapping to external tracker format
|
||||||
|
# 3. Calling provider API
|
||||||
|
# 4. Updating external_id mapping
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"issue_id": issue_id,
|
||||||
|
"operation": operation,
|
||||||
|
}
|
||||||
209
backend/app/tasks/workflow.py
Normal file
209
backend/app/tasks/workflow.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
# app/tasks/workflow.py
|
||||||
|
"""
|
||||||
|
Workflow state management tasks for Syndarix.
|
||||||
|
|
||||||
|
These tasks manage workflow execution and state transitions:
|
||||||
|
- Sprint workflows (planning -> implementation -> review -> done)
|
||||||
|
- Story workflows (todo -> in_progress -> review -> done)
|
||||||
|
- Approval checkpoints for autonomy levels
|
||||||
|
- Stale workflow recovery
|
||||||
|
|
||||||
|
Per ADR-007 and ADR-010, workflow state is durable in PostgreSQL
|
||||||
|
with defined state transitions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.workflow.recover_stale_workflows")
|
||||||
|
def recover_stale_workflows(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Recover workflows that have become stale.
|
||||||
|
|
||||||
|
This periodic task (runs every 5 minutes):
|
||||||
|
1. Find workflows stuck in intermediate states
|
||||||
|
2. Check for timed-out agent operations
|
||||||
|
3. Retry or escalate based on configuration
|
||||||
|
4. Notify relevant users if needed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and recovered count
|
||||||
|
"""
|
||||||
|
logger.info("Checking for stale workflows to recover")
|
||||||
|
|
||||||
|
# TODO: Implement stale workflow recovery
|
||||||
|
# This will involve:
|
||||||
|
# 1. Querying for workflows with last_updated > threshold
|
||||||
|
# 2. Checking if associated agents are still running
|
||||||
|
# 3. Retrying or resetting stuck workflows
|
||||||
|
# 4. Sending notifications for manual intervention
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"recovered": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.workflow.execute_workflow_step")
|
||||||
|
def execute_workflow_step(
|
||||||
|
self,
|
||||||
|
workflow_id: str,
|
||||||
|
transition: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Execute a state transition for a workflow.
|
||||||
|
|
||||||
|
This task applies a transition to a workflow:
|
||||||
|
1. Validate transition is allowed from current state
|
||||||
|
2. Execute any pre-transition hooks
|
||||||
|
3. Update workflow state
|
||||||
|
4. Execute any post-transition hooks
|
||||||
|
5. Trigger follow-up tasks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_id: UUID of the workflow
|
||||||
|
transition: Transition to execute (start, approve, reject, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status, workflow_id, and transition
|
||||||
|
"""
|
||||||
|
logger.info(f"Executing transition '{transition}' for workflow {workflow_id}")
|
||||||
|
|
||||||
|
# TODO: Implement workflow transition
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading workflow from database
|
||||||
|
# 2. Validating transition from current state
|
||||||
|
# 3. Running pre-transition hooks
|
||||||
|
# 4. Updating state in database
|
||||||
|
# 5. Running post-transition hooks
|
||||||
|
# 6. Scheduling follow-up tasks
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"workflow_id": workflow_id,
|
||||||
|
"transition": transition,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.workflow.handle_approval_response")
|
||||||
|
def handle_approval_response(
|
||||||
|
self,
|
||||||
|
workflow_id: str,
|
||||||
|
approved: bool,
|
||||||
|
comment: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Handle a user approval response for a workflow checkpoint.
|
||||||
|
|
||||||
|
This task processes approval decisions:
|
||||||
|
1. Record approval decision with timestamp
|
||||||
|
2. Update workflow state accordingly
|
||||||
|
3. Resume or halt workflow execution
|
||||||
|
4. Notify relevant parties
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_id: UUID of the workflow
|
||||||
|
approved: Whether the checkpoint was approved
|
||||||
|
comment: Optional comment from approver
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status, workflow_id, and approved flag
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
f"Handling approval response for workflow {workflow_id}: approved={approved}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Implement approval handling
|
||||||
|
# This will involve:
|
||||||
|
# 1. Loading workflow and approval checkpoint
|
||||||
|
# 2. Recording decision with user and timestamp
|
||||||
|
# 3. Transitioning workflow state
|
||||||
|
# 4. Resuming or stopping execution
|
||||||
|
# 5. Sending notifications
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"workflow_id": workflow_id,
|
||||||
|
"approved": approved,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.workflow.start_sprint_workflow")
|
||||||
|
def start_sprint_workflow(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
sprint_id: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Start a new sprint workflow.
|
||||||
|
|
||||||
|
This task initializes sprint execution:
|
||||||
|
1. Create sprint workflow record
|
||||||
|
2. Set up sprint planning phase
|
||||||
|
3. Spawn Product Owner agent for planning
|
||||||
|
4. Begin story assignment
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: UUID of the project
|
||||||
|
sprint_id: UUID of the sprint
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and sprint_id
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
f"Starting sprint workflow for sprint {sprint_id} in project {project_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Implement sprint workflow initialization
|
||||||
|
# This will involve:
|
||||||
|
# 1. Creating workflow record for sprint
|
||||||
|
# 2. Setting initial state to PLANNING
|
||||||
|
# 3. Spawning PO agent for sprint planning
|
||||||
|
# 4. Setting up monitoring and checkpoints
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"sprint_id": sprint_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(bind=True, name="app.tasks.workflow.start_story_workflow")
|
||||||
|
def start_story_workflow(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
story_id: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Start a new story workflow.
|
||||||
|
|
||||||
|
This task initializes story execution:
|
||||||
|
1. Create story workflow record
|
||||||
|
2. Spawn appropriate developer agent
|
||||||
|
3. Set up implementation tracking
|
||||||
|
4. Configure approval checkpoints based on autonomy level
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: UUID of the project
|
||||||
|
story_id: UUID of the story/issue
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with status and story_id
|
||||||
|
"""
|
||||||
|
logger.info(f"Starting story workflow for story {story_id} in project {project_id}")
|
||||||
|
|
||||||
|
# TODO: Implement story workflow initialization
|
||||||
|
# This will involve:
|
||||||
|
# 1. Creating workflow record for story
|
||||||
|
# 2. Determining appropriate agent type
|
||||||
|
# 3. Spawning developer agent
|
||||||
|
# 4. Setting up checkpoints based on autonomy level
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "pending",
|
||||||
|
"story_id": story_id,
|
||||||
|
}
|
||||||
324
backend/docs/MCP_CLIENT.md
Normal file
324
backend/docs/MCP_CLIENT.md
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
# MCP Client Infrastructure
|
||||||
|
|
||||||
|
This document describes the Model Context Protocol (MCP) client infrastructure used by Syndarix to communicate with AI agent tools.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The MCP client infrastructure provides a robust, fault-tolerant layer for communicating with MCP servers. It enables AI agents to discover and execute tools provided by various services (LLM Gateway, Knowledge Base, Git Operations, Issue Tracker, etc.).
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ MCPClientManager │
|
||||||
|
│ (Main Facade Class) │
|
||||||
|
├────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ - initialize() / shutdown() │
|
||||||
|
│ - call_tool() / route_tool() │
|
||||||
|
│ - connect() / disconnect() │
|
||||||
|
│ - health_check() / list_tools() │
|
||||||
|
└─────────────┬────────────────────┬─────────────────┬───────────────────┘
|
||||||
|
│ │ │
|
||||||
|
▼ ▼ ▼
|
||||||
|
┌─────────────────────┐ ┌─────────────────┐ ┌──────────────────────────┐
|
||||||
|
│ MCPServerRegistry │ │ ConnectionPool │ │ ToolRouter │
|
||||||
|
│ (Singleton) │ │ │ │ │
|
||||||
|
├─────────────────────┤ ├─────────────────┤ ├──────────────────────────┤
|
||||||
|
│ - Server configs │ │ - Connection │ │ - Tool → Server mapping │
|
||||||
|
│ - Capabilities │ │ management │ │ - Circuit breakers │
|
||||||
|
│ - Tool discovery │ │ - Auto reconnect│ │ - Retry logic │
|
||||||
|
└─────────────────────┘ └─────────────────┘ └──────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## Components
|
||||||
|
|
||||||
|
### MCPClientManager
|
||||||
|
|
||||||
|
The main entry point for all MCP operations. Provides a clean facade over the underlying infrastructure.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from app.services.mcp import get_mcp_client, MCPClientManager
|
||||||
|
|
||||||
|
# In FastAPI dependency injection
|
||||||
|
async def my_route(mcp: MCPClientManager = Depends(get_mcp_client)):
|
||||||
|
result = await mcp.call_tool(
|
||||||
|
server="llm-gateway",
|
||||||
|
tool="chat",
|
||||||
|
args={"prompt": "Hello"}
|
||||||
|
)
|
||||||
|
return result.data
|
||||||
|
|
||||||
|
# Direct usage
|
||||||
|
manager = MCPClientManager()
|
||||||
|
await manager.initialize()
|
||||||
|
|
||||||
|
# Execute a tool
|
||||||
|
result = await manager.call_tool(
|
||||||
|
server="issues",
|
||||||
|
tool="create_issue",
|
||||||
|
args={"title": "New Feature", "body": "Description"}
|
||||||
|
)
|
||||||
|
|
||||||
|
await manager.shutdown()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
|
||||||
|
Configuration is loaded from YAML files and supports environment variable expansion:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# mcp_servers.yaml
|
||||||
|
mcp_servers:
|
||||||
|
llm-gateway:
|
||||||
|
url: ${LLM_GATEWAY_URL:-http://localhost:8001}
|
||||||
|
timeout: 60
|
||||||
|
transport: http
|
||||||
|
enabled: true
|
||||||
|
retry_attempts: 3
|
||||||
|
circuit_breaker_threshold: 5
|
||||||
|
circuit_breaker_timeout: 30.0
|
||||||
|
|
||||||
|
knowledge-base:
|
||||||
|
url: ${KNOWLEDGE_BASE_URL:-http://localhost:8002}
|
||||||
|
timeout: 30
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
default_timeout: 30
|
||||||
|
connection_pool_size: 10
|
||||||
|
health_check_interval: 30
|
||||||
|
```
|
||||||
|
|
||||||
|
**Environment Variable Syntax:**
|
||||||
|
- `${VAR_NAME}` - Uses the environment variable value
|
||||||
|
- `${VAR_NAME:-default}` - Uses default if variable is not set
|
||||||
|
|
||||||
|
### Connection Management
|
||||||
|
|
||||||
|
The `ConnectionPool` manages connections to MCP servers with:
|
||||||
|
|
||||||
|
- **Connection Reuse**: Connections are pooled and reused
|
||||||
|
- **Auto Reconnection**: Failed connections are automatically retried
|
||||||
|
- **Health Checks**: Periodic health checks detect unhealthy servers
|
||||||
|
- **Exponential Backoff**: Retry delays increase exponentially with jitter
|
||||||
|
|
||||||
|
```python
|
||||||
|
from app.services.mcp import ConnectionPool, MCPConnection
|
||||||
|
|
||||||
|
pool = ConnectionPool(max_connections_per_server=5)
|
||||||
|
|
||||||
|
# Get a connection (creates new or reuses existing)
|
||||||
|
conn = await pool.get_connection("server-1", config)
|
||||||
|
|
||||||
|
# Execute request
|
||||||
|
result = await conn.execute_request("POST", "/mcp", data={...})
|
||||||
|
|
||||||
|
# Health check all connections
|
||||||
|
health = await pool.health_check_all()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Circuit Breaker Pattern
|
||||||
|
|
||||||
|
The `AsyncCircuitBreaker` prevents cascade failures:
|
||||||
|
|
||||||
|
| State | Description |
|
||||||
|
|-------|-------------|
|
||||||
|
| CLOSED | Normal operation, calls pass through |
|
||||||
|
| OPEN | Too many failures, calls are rejected immediately |
|
||||||
|
| HALF-OPEN | After timeout, allows one call to test if service recovered |
|
||||||
|
|
||||||
|
```python
|
||||||
|
from app.services.mcp import AsyncCircuitBreaker
|
||||||
|
|
||||||
|
breaker = AsyncCircuitBreaker(
|
||||||
|
fail_max=5, # Open after 5 failures
|
||||||
|
reset_timeout=30, # Try again after 30 seconds
|
||||||
|
name="my-service"
|
||||||
|
)
|
||||||
|
|
||||||
|
if breaker.is_open():
|
||||||
|
raise MCPCircuitOpenError(...)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await call_external_service()
|
||||||
|
await breaker.success()
|
||||||
|
except Exception:
|
||||||
|
await breaker.failure()
|
||||||
|
raise
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tool Routing
|
||||||
|
|
||||||
|
The `ToolRouter` handles:
|
||||||
|
|
||||||
|
- **Tool Discovery**: Automatically discovers tools from connected servers
|
||||||
|
- **Routing**: Routes tool calls to the appropriate server
|
||||||
|
- **Retry Logic**: Retries failed calls with exponential backoff
|
||||||
|
|
||||||
|
```python
|
||||||
|
from app.services.mcp import ToolRouter
|
||||||
|
|
||||||
|
router = ToolRouter(registry, pool)
|
||||||
|
|
||||||
|
# Discover tools from all servers
|
||||||
|
await router.discover_tools()
|
||||||
|
|
||||||
|
# Route to the right server automatically
|
||||||
|
result = await router.route_tool(
|
||||||
|
tool_name="create_issue",
|
||||||
|
arguments={"title": "Bug fix"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Or call a specific server
|
||||||
|
result = await router.call_tool(
|
||||||
|
server_name="issues",
|
||||||
|
tool_name="create_issue",
|
||||||
|
arguments={"title": "Bug fix"}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Exception Hierarchy
|
||||||
|
|
||||||
|
```
|
||||||
|
MCPError
|
||||||
|
├── MCPConnectionError # Connection failures
|
||||||
|
├── MCPTimeoutError # Operation timeouts
|
||||||
|
├── MCPToolError # Tool execution errors
|
||||||
|
├── MCPServerNotFoundError # Unknown server
|
||||||
|
├── MCPToolNotFoundError # Unknown tool
|
||||||
|
├── MCPCircuitOpenError # Circuit breaker open
|
||||||
|
└── MCPValidationError # Invalid configuration
|
||||||
|
```
|
||||||
|
|
||||||
|
All exceptions include rich context:
|
||||||
|
|
||||||
|
```python
|
||||||
|
except MCPServerNotFoundError as e:
|
||||||
|
print(f"Server: {e.server_name}")
|
||||||
|
print(f"Available: {e.available_servers}")
|
||||||
|
print(f"Suggestion: {e.suggestion}")
|
||||||
|
```
|
||||||
|
|
||||||
|
## REST API Endpoints
|
||||||
|
|
||||||
|
| Method | Endpoint | Description | Auth |
|
||||||
|
|--------|----------|-------------|------|
|
||||||
|
| GET | `/api/v1/mcp/servers` | List all MCP servers | No |
|
||||||
|
| GET | `/api/v1/mcp/servers/{name}/tools` | List server tools | No |
|
||||||
|
| GET | `/api/v1/mcp/tools` | List all tools | No |
|
||||||
|
| GET | `/api/v1/mcp/health` | Health check | No |
|
||||||
|
| POST | `/api/v1/mcp/call` | Execute tool | Superuser |
|
||||||
|
| GET | `/api/v1/mcp/circuit-breakers` | List circuit breakers | No |
|
||||||
|
| POST | `/api/v1/mcp/circuit-breakers/{name}/reset` | Reset breaker | Superuser |
|
||||||
|
| POST | `/api/v1/mcp/servers/{name}/reconnect` | Force reconnect | Superuser |
|
||||||
|
|
||||||
|
### Example: Execute a Tool
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /api/v1/mcp/call
|
||||||
|
Authorization: Bearer <token>
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"server": "issues",
|
||||||
|
"tool": "create_issue",
|
||||||
|
"arguments": {
|
||||||
|
"title": "New Feature Request",
|
||||||
|
"body": "Please add dark mode support"
|
||||||
|
},
|
||||||
|
"timeout": 30
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": {
|
||||||
|
"issue_id": "12345",
|
||||||
|
"url": "https://gitea.example.com/org/repo/issues/42"
|
||||||
|
},
|
||||||
|
"tool_name": "create_issue",
|
||||||
|
"server_name": "issues",
|
||||||
|
"execution_time_ms": 234.5,
|
||||||
|
"request_id": "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage in Syndarix Agents
|
||||||
|
|
||||||
|
AI agents use the MCP client to execute tools:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class IssueCreatorAgent:
|
||||||
|
def __init__(self, mcp: MCPClientManager):
|
||||||
|
self.mcp = mcp
|
||||||
|
|
||||||
|
async def create_issue(self, title: str, body: str) -> dict:
|
||||||
|
result = await self.mcp.call_tool(
|
||||||
|
server="issues",
|
||||||
|
tool="create_issue",
|
||||||
|
args={"title": title, "body": body}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result.success:
|
||||||
|
raise AgentError(f"Failed to create issue: {result.error}")
|
||||||
|
|
||||||
|
return result.data
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
The MCP infrastructure is thoroughly tested:
|
||||||
|
|
||||||
|
- **Unit Tests**: `tests/services/mcp/` - Service layer tests
|
||||||
|
- **API Tests**: `tests/api/routes/test_mcp.py` - Endpoint tests
|
||||||
|
|
||||||
|
Run tests:
|
||||||
|
```bash
|
||||||
|
# All MCP tests
|
||||||
|
IS_TEST=True uv run pytest tests/services/mcp/ tests/api/routes/test_mcp.py -v
|
||||||
|
|
||||||
|
# With coverage
|
||||||
|
IS_TEST=True uv run pytest tests/services/mcp/ --cov=app/services/mcp
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Reference
|
||||||
|
|
||||||
|
### MCPServerConfig
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `url` | str | Required | Server URL |
|
||||||
|
| `transport` | str | "http" | Transport type (http, stdio, sse) |
|
||||||
|
| `timeout` | int | 30 | Request timeout (1-600 seconds) |
|
||||||
|
| `retry_attempts` | int | 3 | Max retry attempts (0-10) |
|
||||||
|
| `retry_delay` | float | 1.0 | Initial retry delay (0.1-300 seconds) |
|
||||||
|
| `retry_max_delay` | float | 30.0 | Maximum retry delay |
|
||||||
|
| `circuit_breaker_threshold` | int | 5 | Failures before opening circuit |
|
||||||
|
| `circuit_breaker_timeout` | float | 30.0 | Seconds before trying again |
|
||||||
|
| `enabled` | bool | true | Whether server is enabled |
|
||||||
|
| `description` | str | None | Server description |
|
||||||
|
|
||||||
|
### MCPConfig (Global)
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `mcp_servers` | dict | {} | Server configurations |
|
||||||
|
| `default_timeout` | int | 30 | Default request timeout |
|
||||||
|
| `default_retry_attempts` | int | 3 | Default retry attempts |
|
||||||
|
| `connection_pool_size` | int | 10 | Max connections per server |
|
||||||
|
| `health_check_interval` | int | 30 | Health check interval (seconds) |
|
||||||
|
|
||||||
|
## Files
|
||||||
|
|
||||||
|
| Path | Description |
|
||||||
|
|------|-------------|
|
||||||
|
| `app/services/mcp/__init__.py` | Package exports |
|
||||||
|
| `app/services/mcp/client_manager.py` | Main facade class |
|
||||||
|
| `app/services/mcp/config.py` | Configuration models |
|
||||||
|
| `app/services/mcp/registry.py` | Server registry singleton |
|
||||||
|
| `app/services/mcp/connection.py` | Connection management |
|
||||||
|
| `app/services/mcp/routing.py` | Tool routing and circuit breakers |
|
||||||
|
| `app/services/mcp/exceptions.py` | Exception classes |
|
||||||
|
| `app/api/routes/mcp.py` | REST API endpoints |
|
||||||
|
| `mcp_servers.yaml` | Default configuration |
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
echo "Starting Backend"
|
|
||||||
|
|
||||||
# Ensure the project's virtualenv binaries are on PATH so commands like
|
# Ensure the project's virtualenv binaries are on PATH so commands like
|
||||||
# 'uvicorn' work even when not prefixed by 'uv run'. This matches how uv
|
# 'uvicorn' work even when not prefixed by 'uv run'. This matches how uv
|
||||||
@@ -9,14 +8,23 @@ if [ -d "/app/.venv/bin" ]; then
|
|||||||
export PATH="/app/.venv/bin:$PATH"
|
export PATH="/app/.venv/bin:$PATH"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Apply database migrations
|
# Only the backend service should run migrations and init_db
|
||||||
# Avoid installing the project in editable mode (which tries to write egg-info)
|
# Celery workers should skip this to avoid race conditions
|
||||||
# when running inside a bind-mounted volume with restricted permissions.
|
# Check if the first argument contains 'celery' - if so, skip migrations
|
||||||
# See: https://github.com/astral-sh/uv (use --no-project to skip project build)
|
if [[ "$1" == *"celery"* ]]; then
|
||||||
uv run --no-project alembic upgrade head
|
echo "Starting Celery worker (skipping migrations)"
|
||||||
|
else
|
||||||
|
echo "Starting Backend"
|
||||||
|
|
||||||
# Initialize database (creates first superuser if needed)
|
# Apply database migrations
|
||||||
uv run --no-project python app/init_db.py
|
# Avoid installing the project in editable mode (which tries to write egg-info)
|
||||||
|
# when running inside a bind-mounted volume with restricted permissions.
|
||||||
|
# See: https://github.com/astral-sh/uv (use --no-project to skip project build)
|
||||||
|
uv run --no-project alembic upgrade head
|
||||||
|
|
||||||
|
# Initialize database (creates first superuser if needed)
|
||||||
|
uv run --no-project python app/init_db.py
|
||||||
|
fi
|
||||||
|
|
||||||
# Execute the command passed to docker run
|
# Execute the command passed to docker run
|
||||||
exec "$@"
|
exec "$@"
|
||||||
60
backend/mcp_servers.yaml
Normal file
60
backend/mcp_servers.yaml
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
# MCP Server Configuration
|
||||||
|
#
|
||||||
|
# This file defines the MCP servers that the Syndarix backend connects to.
|
||||||
|
# Environment variables can be used with ${VAR:-default} syntax.
|
||||||
|
#
|
||||||
|
# Example:
|
||||||
|
# url: ${MY_SERVER_URL:-http://localhost:8001}
|
||||||
|
#
|
||||||
|
# For development, these servers typically run as separate Docker containers.
|
||||||
|
# See docker-compose.yml for container definitions.
|
||||||
|
|
||||||
|
mcp_servers:
|
||||||
|
# LLM Gateway - Multi-provider AI interactions
|
||||||
|
llm-gateway:
|
||||||
|
url: ${LLM_GATEWAY_URL:-http://localhost:8001}
|
||||||
|
transport: http
|
||||||
|
timeout: 60
|
||||||
|
retry_attempts: 3
|
||||||
|
retry_delay: 1.0
|
||||||
|
retry_max_delay: 30.0
|
||||||
|
circuit_breaker_threshold: 5
|
||||||
|
circuit_breaker_timeout: 30.0
|
||||||
|
enabled: true
|
||||||
|
description: "LLM Gateway for Anthropic, OpenAI, Ollama, and other providers"
|
||||||
|
|
||||||
|
# Knowledge Base - RAG and document retrieval
|
||||||
|
knowledge-base:
|
||||||
|
url: ${KNOWLEDGE_BASE_URL:-http://localhost:8002}
|
||||||
|
transport: http
|
||||||
|
timeout: 30
|
||||||
|
retry_attempts: 3
|
||||||
|
circuit_breaker_threshold: 5
|
||||||
|
enabled: true
|
||||||
|
description: "Knowledge Base with pgvector for semantic search and RAG"
|
||||||
|
|
||||||
|
# Git Operations - Repository management
|
||||||
|
git-ops:
|
||||||
|
url: ${GIT_OPS_URL:-http://localhost:8003}
|
||||||
|
transport: http
|
||||||
|
timeout: 120
|
||||||
|
retry_attempts: 2
|
||||||
|
circuit_breaker_threshold: 3
|
||||||
|
enabled: true
|
||||||
|
description: "Git Operations for clone, commit, push, and repository management"
|
||||||
|
|
||||||
|
# Issues - Issue tracker integration
|
||||||
|
issues:
|
||||||
|
url: ${ISSUES_URL:-http://localhost:8004}
|
||||||
|
transport: http
|
||||||
|
timeout: 30
|
||||||
|
retry_attempts: 3
|
||||||
|
circuit_breaker_threshold: 5
|
||||||
|
enabled: true
|
||||||
|
description: "Issue Tracker integration for Gitea, GitHub, and GitLab"
|
||||||
|
|
||||||
|
# Global defaults
|
||||||
|
default_timeout: 30
|
||||||
|
default_retry_attempts: 3
|
||||||
|
connection_pool_size: 10
|
||||||
|
health_check_interval: 30
|
||||||
@@ -306,7 +306,7 @@ def show_next_rev_id():
|
|||||||
"""Show the next sequential revision ID."""
|
"""Show the next sequential revision ID."""
|
||||||
next_id = get_next_rev_id()
|
next_id = get_next_rev_id()
|
||||||
print(f"Next revision ID: {next_id}")
|
print(f"Next revision ID: {next_id}")
|
||||||
print(f"\nUsage:")
|
print("\nUsage:")
|
||||||
print(f" python migrate.py --local generate 'your_message' --rev-id {next_id}")
|
print(f" python migrate.py --local generate 'your_message' --rev-id {next_id}")
|
||||||
print(f" python migrate.py --local auto 'your_message' --rev-id {next_id}")
|
print(f" python migrate.py --local auto 'your_message' --rev-id {next_id}")
|
||||||
return next_id
|
return next_id
|
||||||
@@ -416,7 +416,7 @@ def main():
|
|||||||
if args.command == 'auto' and offline:
|
if args.command == 'auto' and offline:
|
||||||
generate_migration(args.message, rev_id=args.rev_id, offline=True)
|
generate_migration(args.message, rev_id=args.rev_id, offline=True)
|
||||||
print("\nOffline migration generated. Apply it later with:")
|
print("\nOffline migration generated. Apply it later with:")
|
||||||
print(f" python migrate.py --local apply")
|
print(" python migrate.py --local apply")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Setup database URL (must be done before importing settings elsewhere)
|
# Setup database URL (must be done before importing settings elsewhere)
|
||||||
|
|||||||
@@ -22,41 +22,43 @@ dependencies = [
|
|||||||
"pydantic-settings>=2.2.1",
|
"pydantic-settings>=2.2.1",
|
||||||
"python-multipart>=0.0.19",
|
"python-multipart>=0.0.19",
|
||||||
"fastapi-utils==0.8.0",
|
"fastapi-utils==0.8.0",
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
"sqlalchemy>=2.0.29",
|
"sqlalchemy>=2.0.29",
|
||||||
"alembic>=1.14.1",
|
"alembic>=1.14.1",
|
||||||
"psycopg2-binary>=2.9.9",
|
"psycopg2-binary>=2.9.9",
|
||||||
"asyncpg>=0.29.0",
|
"asyncpg>=0.29.0",
|
||||||
"aiosqlite==0.21.0",
|
"aiosqlite==0.21.0",
|
||||||
|
|
||||||
# Environment configuration
|
# Environment configuration
|
||||||
"python-dotenv>=1.0.1",
|
"python-dotenv>=1.0.1",
|
||||||
|
|
||||||
# API utilities
|
# API utilities
|
||||||
"email-validator>=2.1.0.post1",
|
"email-validator>=2.1.0.post1",
|
||||||
"ujson>=5.9.0",
|
"ujson>=5.9.0",
|
||||||
|
|
||||||
# CORS and security
|
# CORS and security
|
||||||
"starlette>=0.40.0",
|
"starlette>=0.40.0",
|
||||||
"starlette-csrf>=1.4.5",
|
"starlette-csrf>=1.4.5",
|
||||||
"slowapi>=0.1.9",
|
"slowapi>=0.1.9",
|
||||||
|
|
||||||
# Utilities
|
# Utilities
|
||||||
"httpx>=0.27.0",
|
"httpx>=0.27.0",
|
||||||
"tenacity>=8.2.3",
|
"tenacity>=8.2.3",
|
||||||
"pytz>=2024.1",
|
"pytz>=2024.1",
|
||||||
"pillow>=10.3.0",
|
"pillow>=10.3.0",
|
||||||
"apscheduler==3.11.0",
|
"apscheduler==3.11.0",
|
||||||
|
|
||||||
# Security and authentication (pinned for reproducibility)
|
# Security and authentication (pinned for reproducibility)
|
||||||
"python-jose==3.4.0",
|
"python-jose==3.4.0",
|
||||||
"passlib==1.7.4",
|
"passlib==1.7.4",
|
||||||
"bcrypt==4.2.1",
|
"bcrypt==4.2.1",
|
||||||
"cryptography==44.0.1",
|
"cryptography==44.0.1",
|
||||||
|
|
||||||
# OAuth authentication
|
# OAuth authentication
|
||||||
"authlib>=1.3.0",
|
"authlib>=1.3.0",
|
||||||
|
# Celery for background task processing (Syndarix agent jobs)
|
||||||
|
"celery[redis]>=5.4.0",
|
||||||
|
"sse-starlette>=3.1.1",
|
||||||
|
# MCP (Model Context Protocol) for AI agent tool integration
|
||||||
|
"mcp>=1.0.0",
|
||||||
|
# Circuit breaker pattern for resilient connections
|
||||||
|
"pybreaker>=1.0.0",
|
||||||
|
# YAML configuration support
|
||||||
|
"pyyaml>=6.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Development dependencies
|
# Development dependencies
|
||||||
@@ -155,6 +157,7 @@ unfixable = []
|
|||||||
"app/alembic/env.py" = ["E402", "F403", "F405"] # Alembic requires specific import order
|
"app/alembic/env.py" = ["E402", "F403", "F405"] # Alembic requires specific import order
|
||||||
"app/alembic/versions/*.py" = ["E402"] # Migration files have specific structure
|
"app/alembic/versions/*.py" = ["E402"] # Migration files have specific structure
|
||||||
"tests/**/*.py" = ["S101", "N806", "B017", "N817", "S110", "ASYNC251", "RUF043"] # pytest: asserts, CamelCase fixtures, blind exceptions, try-pass patterns, and async test helpers are intentional
|
"tests/**/*.py" = ["S101", "N806", "B017", "N817", "S110", "ASYNC251", "RUF043"] # pytest: asserts, CamelCase fixtures, blind exceptions, try-pass patterns, and async test helpers are intentional
|
||||||
|
"app/services/mcp/*.py" = ["ASYNC109", "S311", "RUF022"] # timeout is config param not asyncio.timeout; random is ok for jitter; __all__ order is intentional for readability
|
||||||
"app/models/__init__.py" = ["F401"] # __init__ files re-export modules
|
"app/models/__init__.py" = ["F401"] # __init__ files re-export modules
|
||||||
"app/models/base.py" = ["F401"] # Re-exports Base for use by other models
|
"app/models/base.py" = ["F401"] # Re-exports Base for use by other models
|
||||||
"app/utils/test_utils.py" = ["N806"] # SQLAlchemy session factories use CamelCase convention
|
"app/utils/test_utils.py" = ["N806"] # SQLAlchemy session factories use CamelCase convention
|
||||||
@@ -256,6 +259,30 @@ ignore_missing_imports = true
|
|||||||
module = "authlib.*"
|
module = "authlib.*"
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "celery.*"
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "redis.*"
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "sse_starlette.*"
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "httpx.*"
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "pybreaker.*"
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "yaml.*"
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
# SQLAlchemy ORM models - Column descriptors cause type confusion
|
# SQLAlchemy ORM models - Column descriptors cause type confusion
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
module = "app.models.*"
|
module = "app.models.*"
|
||||||
@@ -286,11 +313,43 @@ disable_error_code = ["arg-type"]
|
|||||||
module = "app.services.auth_service"
|
module = "app.services.auth_service"
|
||||||
disable_error_code = ["assignment", "arg-type"]
|
disable_error_code = ["assignment", "arg-type"]
|
||||||
|
|
||||||
|
# OAuth services - SQLAlchemy Column issues and unused type:ignore from library evolution
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "app.services.oauth_provider_service"
|
||||||
|
disable_error_code = ["assignment", "arg-type", "attr-defined", "unused-ignore"]
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "app.services.oauth_service"
|
||||||
|
disable_error_code = ["assignment", "arg-type", "attr-defined"]
|
||||||
|
|
||||||
|
# MCP services - circuit breaker and httpx client handling
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "app.services.mcp.*"
|
||||||
|
disable_error_code = ["attr-defined", "arg-type"]
|
||||||
|
|
||||||
# Test utils - Testing patterns
|
# Test utils - Testing patterns
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
module = "app.utils.auth_test_utils"
|
module = "app.utils.auth_test_utils"
|
||||||
disable_error_code = ["assignment", "arg-type"]
|
disable_error_code = ["assignment", "arg-type"]
|
||||||
|
|
||||||
|
# Test dependencies - ignore missing stubs
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "pytest_asyncio.*"
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "schemathesis.*"
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "testcontainers.*"
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
# Tests directory - relax type checking for test code
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "tests.*"
|
||||||
|
disable_error_code = ["arg-type", "union-attr", "return-value", "call-arg", "unused-ignore", "assignment", "var-annotated", "operator"]
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Pydantic mypy plugin configuration
|
# Pydantic mypy plugin configuration
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
39
backend/tests/api/dependencies/test_event_bus.py
Normal file
39
backend/tests/api/dependencies/test_event_bus.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# tests/api/dependencies/test_event_bus.py
|
||||||
|
"""Tests for the event_bus dependency."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.api.dependencies.event_bus import get_event_bus
|
||||||
|
from app.services.event_bus import EventBus
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestGetEventBusDependency:
|
||||||
|
"""Tests for the get_event_bus FastAPI dependency."""
|
||||||
|
|
||||||
|
async def test_get_event_bus_returns_event_bus(self):
|
||||||
|
"""Test that get_event_bus returns an EventBus instance."""
|
||||||
|
mock_event_bus = AsyncMock(spec=EventBus)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.api.dependencies.event_bus._get_connected_event_bus",
|
||||||
|
return_value=mock_event_bus,
|
||||||
|
):
|
||||||
|
result = await get_event_bus()
|
||||||
|
|
||||||
|
assert result is mock_event_bus
|
||||||
|
|
||||||
|
async def test_get_event_bus_calls_get_connected_event_bus(self):
|
||||||
|
"""Test that get_event_bus calls the underlying function."""
|
||||||
|
mock_event_bus = AsyncMock(spec=EventBus)
|
||||||
|
mock_get_connected = AsyncMock(return_value=mock_event_bus)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.api.dependencies.event_bus._get_connected_event_bus",
|
||||||
|
mock_get_connected,
|
||||||
|
):
|
||||||
|
await get_event_bus()
|
||||||
|
|
||||||
|
mock_get_connected.assert_called_once()
|
||||||
2
backend/tests/api/routes/syndarix/__init__.py
Normal file
2
backend/tests/api/routes/syndarix/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# tests/api/routes/syndarix/__init__.py
|
||||||
|
"""Syndarix API route tests."""
|
||||||
747
backend/tests/api/routes/syndarix/test_agent_types.py
Normal file
747
backend/tests/api/routes/syndarix/test_agent_types.py
Normal file
@@ -0,0 +1,747 @@
|
|||||||
|
# tests/api/routes/syndarix/test_agent_types.py
|
||||||
|
"""
|
||||||
|
Comprehensive tests for the AgentTypes API endpoints.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- CRUD operations (create, read, update, deactivate)
|
||||||
|
- Authorization (superuser vs regular user)
|
||||||
|
- Pagination and filtering
|
||||||
|
- Error handling (not found, validation, duplicates)
|
||||||
|
- Slug lookup functionality
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from fastapi import status
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_agent_type(client, superuser_token):
|
||||||
|
"""Create a test agent type for tests."""
|
||||||
|
unique_slug = f"test-type-{uuid.uuid4().hex[:8]}"
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "Test Agent Type",
|
||||||
|
"slug": unique_slug,
|
||||||
|
"description": "A test agent type for testing",
|
||||||
|
"expertise": ["python", "testing"],
|
||||||
|
"personality_prompt": "You are a helpful test agent.",
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
"fallback_models": ["claude-3-sonnet"],
|
||||||
|
"model_params": {"temperature": 0.7},
|
||||||
|
"mcp_servers": [],
|
||||||
|
"tool_permissions": {"read": True, "write": False},
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def multiple_agent_types(client, superuser_token):
|
||||||
|
"""Create multiple agent types for pagination tests."""
|
||||||
|
types = []
|
||||||
|
for i in range(5):
|
||||||
|
unique_slug = f"multi-type-{i}-{uuid.uuid4().hex[:8]}"
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": f"Agent Type {i}",
|
||||||
|
"slug": unique_slug,
|
||||||
|
"description": f"Description for type {i}",
|
||||||
|
"expertise": ["python"],
|
||||||
|
"personality_prompt": f"Personality prompt {i}",
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
types.append(response.json())
|
||||||
|
return types
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestCreateAgentType:
|
||||||
|
"""Tests for POST /api/v1/agent-types endpoint."""
|
||||||
|
|
||||||
|
async def test_create_agent_type_success(self, client, superuser_token):
|
||||||
|
"""Test successful agent type creation by superuser."""
|
||||||
|
unique_slug = f"created-type-{uuid.uuid4().hex[:8]}"
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "New Agent Type",
|
||||||
|
"slug": unique_slug,
|
||||||
|
"description": "A newly created agent type",
|
||||||
|
"expertise": ["python", "fastapi"],
|
||||||
|
"personality_prompt": "You are a backend developer.",
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
"fallback_models": ["claude-3-sonnet"],
|
||||||
|
"model_params": {"temperature": 0.5},
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert data["name"] == "New Agent Type"
|
||||||
|
assert data["slug"] == unique_slug
|
||||||
|
assert data["description"] == "A newly created agent type"
|
||||||
|
assert data["expertise"] == ["python", "fastapi"]
|
||||||
|
assert data["personality_prompt"] == "You are a backend developer."
|
||||||
|
assert data["primary_model"] == "claude-3-opus"
|
||||||
|
assert data["fallback_models"] == ["claude-3-sonnet"]
|
||||||
|
assert data["model_params"]["temperature"] == 0.5
|
||||||
|
assert data["is_active"] is True
|
||||||
|
assert data["instance_count"] == 0
|
||||||
|
assert "id" in data
|
||||||
|
assert "created_at" in data
|
||||||
|
assert "updated_at" in data
|
||||||
|
|
||||||
|
async def test_create_agent_type_minimal_fields(self, client, superuser_token):
|
||||||
|
"""Test creating agent type with only required fields."""
|
||||||
|
unique_slug = f"minimal-type-{uuid.uuid4().hex[:8]}"
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "Minimal Agent Type",
|
||||||
|
"slug": unique_slug,
|
||||||
|
"expertise": ["general"],
|
||||||
|
"personality_prompt": "You are a general assistant.",
|
||||||
|
"primary_model": "claude-3-sonnet",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Minimal Agent Type"
|
||||||
|
assert data["slug"] == unique_slug
|
||||||
|
assert data["is_active"] is True
|
||||||
|
|
||||||
|
async def test_create_agent_type_duplicate_slug(
|
||||||
|
self, client, superuser_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test that duplicate slugs are rejected."""
|
||||||
|
existing_slug = test_agent_type["slug"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "Another Type",
|
||||||
|
"slug": existing_slug, # Duplicate slug
|
||||||
|
"expertise": ["python"],
|
||||||
|
"personality_prompt": "Prompt",
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
data = response.json()
|
||||||
|
assert data["errors"][0]["code"] == "SYS_005" # ALREADY_EXISTS
|
||||||
|
assert data["errors"][0]["field"] == "slug"
|
||||||
|
|
||||||
|
async def test_create_agent_type_regular_user_forbidden(self, client, user_token):
|
||||||
|
"""Test that regular users cannot create agent types."""
|
||||||
|
unique_slug = f"forbidden-type-{uuid.uuid4().hex[:8]}"
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "Forbidden Type",
|
||||||
|
"slug": unique_slug,
|
||||||
|
"expertise": ["python"],
|
||||||
|
"personality_prompt": "Prompt",
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
async def test_create_agent_type_unauthenticated(self, client):
|
||||||
|
"""Test that unauthenticated users cannot create agent types."""
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "Unauth Type",
|
||||||
|
"slug": "unauth-type",
|
||||||
|
"expertise": ["python"],
|
||||||
|
"personality_prompt": "Prompt",
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
async def test_create_agent_type_validation_missing_name(
|
||||||
|
self, client, superuser_token
|
||||||
|
):
|
||||||
|
"""Test validation error when name is missing."""
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"slug": "no-name-type",
|
||||||
|
"expertise": ["python"],
|
||||||
|
"personality_prompt": "Prompt",
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
async def test_create_agent_type_validation_missing_primary_model(
|
||||||
|
self, client, superuser_token
|
||||||
|
):
|
||||||
|
"""Test validation error when primary_model is missing."""
|
||||||
|
unique_slug = f"no-model-type-{uuid.uuid4().hex[:8]}"
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "No Model Type",
|
||||||
|
"slug": unique_slug,
|
||||||
|
"expertise": ["python"],
|
||||||
|
"personality_prompt": "Prompt",
|
||||||
|
# Missing primary_model
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestListAgentTypes:
|
||||||
|
"""Tests for GET /api/v1/agent-types endpoint."""
|
||||||
|
|
||||||
|
async def test_list_agent_types_success(
|
||||||
|
self, client, user_token, multiple_agent_types
|
||||||
|
):
|
||||||
|
"""Test successful listing of agent types."""
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert "data" in data
|
||||||
|
assert "pagination" in data
|
||||||
|
assert len(data["data"]) >= 5
|
||||||
|
assert data["pagination"]["total"] >= 5
|
||||||
|
assert data["pagination"]["page"] == 1
|
||||||
|
|
||||||
|
async def test_list_agent_types_pagination(
|
||||||
|
self, client, user_token, multiple_agent_types
|
||||||
|
):
|
||||||
|
"""Test pagination of agent types."""
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
params={"page": 1, "limit": 2},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert len(data["data"]) <= 2
|
||||||
|
assert data["pagination"]["page_size"] <= 2
|
||||||
|
assert data["pagination"]["page"] == 1
|
||||||
|
|
||||||
|
async def test_list_agent_types_filter_active(
|
||||||
|
self, client, user_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test filtering by active status."""
|
||||||
|
# Default: only active types
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
params={"is_active": True},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# All returned types should be active
|
||||||
|
for agent_type in data["data"]:
|
||||||
|
assert agent_type["is_active"] is True
|
||||||
|
|
||||||
|
async def test_list_agent_types_search(
|
||||||
|
self, client, user_token, multiple_agent_types
|
||||||
|
):
|
||||||
|
"""Test search functionality."""
|
||||||
|
# Search for a specific type
|
||||||
|
search_term = multiple_agent_types[0]["name"]
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
params={"search": search_term},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["data"]) >= 1
|
||||||
|
|
||||||
|
async def test_list_agent_types_unauthenticated(self, client):
|
||||||
|
"""Test that unauthenticated users cannot list agent types."""
|
||||||
|
response = await client.get("/api/v1/agent-types")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestGetAgentType:
|
||||||
|
"""Tests for GET /api/v1/agent-types/{agent_type_id} endpoint."""
|
||||||
|
|
||||||
|
async def test_get_agent_type_success(self, client, user_token, test_agent_type):
|
||||||
|
"""Test successful retrieval of agent type by ID."""
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert data["id"] == agent_type_id
|
||||||
|
assert data["name"] == test_agent_type["name"]
|
||||||
|
assert data["slug"] == test_agent_type["slug"]
|
||||||
|
assert "instance_count" in data
|
||||||
|
|
||||||
|
async def test_get_agent_type_not_found(self, client, user_token):
|
||||||
|
"""Test retrieval of non-existent agent type."""
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/agent-types/{fake_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
data = response.json()
|
||||||
|
assert data["errors"][0]["code"] == "SYS_002" # NOT_FOUND
|
||||||
|
|
||||||
|
async def test_get_agent_type_invalid_uuid(self, client, user_token):
|
||||||
|
"""Test retrieval with invalid UUID format."""
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/agent-types/not-a-uuid",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
async def test_get_agent_type_unauthenticated(self, client, test_agent_type):
|
||||||
|
"""Test that unauthenticated users cannot get agent types."""
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.get(f"/api/v1/agent-types/{agent_type_id}")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestGetAgentTypeBySlug:
|
||||||
|
"""Tests for GET /api/v1/agent-types/slug/{slug} endpoint."""
|
||||||
|
|
||||||
|
async def test_get_agent_type_by_slug_success(
|
||||||
|
self, client, user_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test successful retrieval of agent type by slug."""
|
||||||
|
slug = test_agent_type["slug"]
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/agent-types/slug/{slug}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert data["slug"] == slug
|
||||||
|
assert data["id"] == test_agent_type["id"]
|
||||||
|
assert data["name"] == test_agent_type["name"]
|
||||||
|
|
||||||
|
async def test_get_agent_type_by_slug_not_found(self, client, user_token):
|
||||||
|
"""Test retrieval of non-existent slug."""
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/agent-types/slug/non-existent-slug",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
data = response.json()
|
||||||
|
assert data["errors"][0]["code"] == "SYS_002" # NOT_FOUND
|
||||||
|
assert "non-existent-slug" in data["errors"][0]["message"]
|
||||||
|
|
||||||
|
async def test_get_agent_type_by_slug_unauthenticated(
|
||||||
|
self, client, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test that unauthenticated users cannot get agent types by slug."""
|
||||||
|
slug = test_agent_type["slug"]
|
||||||
|
|
||||||
|
response = await client.get(f"/api/v1/agent-types/slug/{slug}")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestUpdateAgentType:
|
||||||
|
"""Tests for PATCH /api/v1/agent-types/{agent_type_id} endpoint."""
|
||||||
|
|
||||||
|
async def test_update_agent_type_success(
|
||||||
|
self, client, superuser_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test successful update of agent type."""
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
json={
|
||||||
|
"name": "Updated Agent Type",
|
||||||
|
"description": "Updated description",
|
||||||
|
"expertise": ["python", "fastapi", "testing"],
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert data["id"] == agent_type_id
|
||||||
|
assert data["name"] == "Updated Agent Type"
|
||||||
|
assert data["description"] == "Updated description"
|
||||||
|
assert data["expertise"] == ["python", "fastapi", "testing"]
|
||||||
|
# Slug should remain unchanged
|
||||||
|
assert data["slug"] == test_agent_type["slug"]
|
||||||
|
|
||||||
|
async def test_update_agent_type_partial(
|
||||||
|
self, client, superuser_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test partial update of agent type."""
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
json={"description": "Only description updated"},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert data["description"] == "Only description updated"
|
||||||
|
# Other fields remain unchanged
|
||||||
|
assert data["name"] == test_agent_type["name"]
|
||||||
|
|
||||||
|
async def test_update_agent_type_slug(
|
||||||
|
self, client, superuser_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test updating agent type slug."""
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
new_slug = f"updated-slug-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
json={"slug": new_slug},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["slug"] == new_slug
|
||||||
|
|
||||||
|
async def test_update_agent_type_duplicate_slug(
|
||||||
|
self, client, superuser_token, multiple_agent_types
|
||||||
|
):
|
||||||
|
"""Test that updating to an existing slug fails."""
|
||||||
|
# Try to update first type's slug to second type's slug
|
||||||
|
first_type_id = multiple_agent_types[0]["id"]
|
||||||
|
second_type_slug = multiple_agent_types[1]["slug"]
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/agent-types/{first_type_id}",
|
||||||
|
json={"slug": second_type_slug},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
data = response.json()
|
||||||
|
assert data["errors"][0]["code"] == "SYS_005" # ALREADY_EXISTS
|
||||||
|
|
||||||
|
async def test_update_agent_type_not_found(self, client, superuser_token):
|
||||||
|
"""Test updating non-existent agent type."""
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/agent-types/{fake_id}",
|
||||||
|
json={"name": "Updated Name"},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
data = response.json()
|
||||||
|
assert data["errors"][0]["code"] == "SYS_002" # NOT_FOUND
|
||||||
|
|
||||||
|
async def test_update_agent_type_regular_user_forbidden(
|
||||||
|
self, client, user_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test that regular users cannot update agent types."""
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
json={"name": "Forbidden Update"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
async def test_update_agent_type_unauthenticated(self, client, test_agent_type):
|
||||||
|
"""Test that unauthenticated users cannot update agent types."""
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
json={"name": "Unauth Update"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestDeactivateAgentType:
|
||||||
|
"""Tests for DELETE /api/v1/agent-types/{agent_type_id} endpoint."""
|
||||||
|
|
||||||
|
async def test_deactivate_agent_type_success(self, client, superuser_token):
|
||||||
|
"""Test successful deactivation of agent type."""
|
||||||
|
# Create a type to deactivate
|
||||||
|
unique_slug = f"deactivate-type-{uuid.uuid4().hex[:8]}"
|
||||||
|
create_response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "Type to Deactivate",
|
||||||
|
"slug": unique_slug,
|
||||||
|
"expertise": ["python"],
|
||||||
|
"personality_prompt": "Prompt",
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
assert create_response.status_code == status.HTTP_201_CREATED
|
||||||
|
agent_type_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Deactivate it
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert "deactivated" in data["message"].lower()
|
||||||
|
|
||||||
|
# Verify it's deactivated by checking is_active filter
|
||||||
|
get_response = await client.get(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
assert get_response.status_code == status.HTTP_200_OK
|
||||||
|
assert get_response.json()["is_active"] is False
|
||||||
|
|
||||||
|
async def test_deactivate_agent_type_not_found(self, client, superuser_token):
|
||||||
|
"""Test deactivating non-existent agent type."""
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/agent-types/{fake_id}",
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
data = response.json()
|
||||||
|
assert data["errors"][0]["code"] == "SYS_002" # NOT_FOUND
|
||||||
|
|
||||||
|
async def test_deactivate_agent_type_regular_user_forbidden(
|
||||||
|
self, client, user_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test that regular users cannot deactivate agent types."""
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
async def test_deactivate_agent_type_unauthenticated(self, client, test_agent_type):
|
||||||
|
"""Test that unauthenticated users cannot deactivate agent types."""
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.delete(f"/api/v1/agent-types/{agent_type_id}")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
async def test_deactivate_agent_type_idempotent(self, client, superuser_token):
|
||||||
|
"""Test that deactivating an already deactivated type returns 404."""
|
||||||
|
# Create and deactivate a type
|
||||||
|
unique_slug = f"idempotent-type-{uuid.uuid4().hex[:8]}"
|
||||||
|
create_response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "Type to Deactivate Twice",
|
||||||
|
"slug": unique_slug,
|
||||||
|
"expertise": ["python"],
|
||||||
|
"personality_prompt": "Prompt",
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
agent_type_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# First deactivation
|
||||||
|
await client.delete(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Second deactivation should fail (already deactivated)
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Depending on implementation, this might return 404 or 200
|
||||||
|
# Check implementation for expected behavior
|
||||||
|
assert response.status_code in [
|
||||||
|
status.HTTP_200_OK,
|
||||||
|
status.HTTP_404_NOT_FOUND,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestAgentTypeModelParams:
|
||||||
|
"""Tests for model configuration fields."""
|
||||||
|
|
||||||
|
async def test_create_with_full_model_config(self, client, superuser_token):
|
||||||
|
"""Test creating agent type with complete model configuration."""
|
||||||
|
unique_slug = f"full-config-{uuid.uuid4().hex[:8]}"
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "Full Config Type",
|
||||||
|
"slug": unique_slug,
|
||||||
|
"description": "Type with full model config",
|
||||||
|
"expertise": ["coding", "architecture"],
|
||||||
|
"personality_prompt": "You are an expert architect.",
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
"fallback_models": ["claude-3-sonnet", "claude-3-haiku"],
|
||||||
|
"model_params": {
|
||||||
|
"temperature": 0.3,
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"top_p": 0.9,
|
||||||
|
},
|
||||||
|
"mcp_servers": ["filesystem", "git"], # List of strings, not objects
|
||||||
|
"tool_permissions": {
|
||||||
|
"read_files": True,
|
||||||
|
"write_files": True,
|
||||||
|
"execute_code": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert data["primary_model"] == "claude-3-opus"
|
||||||
|
assert data["fallback_models"] == ["claude-3-sonnet", "claude-3-haiku"]
|
||||||
|
assert data["model_params"]["temperature"] == 0.3
|
||||||
|
assert data["model_params"]["max_tokens"] == 4096
|
||||||
|
assert len(data["mcp_servers"]) == 2
|
||||||
|
assert data["tool_permissions"]["read_files"] is True
|
||||||
|
assert data["tool_permissions"]["execute_code"] is False
|
||||||
|
|
||||||
|
async def test_update_model_params(self, client, superuser_token, test_agent_type):
|
||||||
|
"""Test updating model parameters."""
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
json={
|
||||||
|
"model_params": {"temperature": 0.9, "max_tokens": 2048},
|
||||||
|
"fallback_models": ["claude-3-haiku"],
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert data["model_params"]["temperature"] == 0.9
|
||||||
|
assert data["fallback_models"] == ["claude-3-haiku"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestAgentTypeInstanceCount:
|
||||||
|
"""Tests for instance count tracking."""
|
||||||
|
|
||||||
|
async def test_new_agent_type_has_zero_instances(self, client, superuser_token):
|
||||||
|
"""Test that newly created agent types have zero instances."""
|
||||||
|
unique_slug = f"zero-instances-{uuid.uuid4().hex[:8]}"
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "Zero Instances Type",
|
||||||
|
"slug": unique_slug,
|
||||||
|
"expertise": ["python"],
|
||||||
|
"personality_prompt": "Prompt",
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
data = response.json()
|
||||||
|
assert data["instance_count"] == 0
|
||||||
|
|
||||||
|
async def test_get_agent_type_includes_instance_count(
|
||||||
|
self, client, user_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test that getting an agent type includes instance count."""
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/agent-types/{agent_type_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert "instance_count" in data
|
||||||
|
assert isinstance(data["instance_count"], int)
|
||||||
|
|
||||||
|
async def test_list_agent_types_includes_instance_counts(
|
||||||
|
self, client, user_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test that listing agent types includes instance counts."""
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
for agent_type in data["data"]:
|
||||||
|
assert "instance_count" in agent_type
|
||||||
|
assert isinstance(agent_type["instance_count"], int)
|
||||||
976
backend/tests/api/routes/syndarix/test_agents.py
Normal file
976
backend/tests/api/routes/syndarix/test_agents.py
Normal file
@@ -0,0 +1,976 @@
|
|||||||
|
# tests/api/routes/syndarix/test_agents.py
|
||||||
|
"""Tests for agent instance management endpoints.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- Agent instance CRUD operations
|
||||||
|
- Agent lifecycle management (pause, resume)
|
||||||
|
- Agent status filtering
|
||||||
|
- Agent metrics
|
||||||
|
- Authorization and access control
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from starlette import status
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_project(client, user_token):
|
||||||
|
"""Create a test project for agent tests."""
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={
|
||||||
|
"name": "Agent Test Project",
|
||||||
|
"slug": "agent-test-project",
|
||||||
|
"autonomy_level": "milestone",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_agent_type(client, superuser_token):
|
||||||
|
"""Create a test agent type for spawning agents."""
|
||||||
|
import uuid as uuid_mod
|
||||||
|
|
||||||
|
unique_slug = f"test-developer-agent-{uuid_mod.uuid4().hex[:8]}"
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "Test Developer Agent",
|
||||||
|
"slug": unique_slug,
|
||||||
|
"expertise": ["python", "testing"],
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
"personality_prompt": "You are a helpful developer agent for testing.",
|
||||||
|
"description": "A test developer agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED, f"Failed: {response.json()}"
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestSpawnAgent:
|
||||||
|
"""Tests for POST /api/v1/projects/{project_id}/agents endpoint."""
|
||||||
|
|
||||||
|
async def test_spawn_agent_success(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test successfully spawning a new agent."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "My Developer Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "My Developer Agent"
|
||||||
|
assert data["status"] == "idle"
|
||||||
|
assert data["project_id"] == project_id
|
||||||
|
|
||||||
|
async def test_spawn_agent_with_initial_memory(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test spawning agent with initial short-term memory."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Memory Agent",
|
||||||
|
"short_term_memory": {"context": "test setup"},
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
data = response.json()
|
||||||
|
assert data["short_term_memory"]["context"] == "test setup"
|
||||||
|
|
||||||
|
async def test_spawn_agent_nonexistent_project(
|
||||||
|
self, client, user_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test spawning agent in nonexistent project."""
|
||||||
|
fake_project_id = str(uuid.uuid4())
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{fake_project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": fake_project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Orphan Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
async def test_spawn_agent_nonexistent_type(self, client, user_token, test_project):
|
||||||
|
"""Test spawning agent with nonexistent agent type."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_type_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": fake_type_id,
|
||||||
|
"name": "Invalid Type Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
async def test_spawn_agent_mismatched_project_id(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test spawning agent with mismatched project_id in body."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
different_project_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": different_project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Mismatched Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestListAgents:
|
||||||
|
"""Tests for GET /api/v1/projects/{project_id}/agents endpoint."""
|
||||||
|
|
||||||
|
async def test_list_agents_empty(self, client, user_token, test_project):
|
||||||
|
"""Test listing agents when none exist."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["data"] == []
|
||||||
|
assert data["pagination"]["total"] == 0
|
||||||
|
|
||||||
|
async def test_list_agents_with_data(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test listing agents with data."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agents
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Agent One",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Agent Two",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["data"]) == 2
|
||||||
|
assert data["pagination"]["total"] == 2
|
||||||
|
|
||||||
|
async def test_list_agents_filter_by_status(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test filtering agents by status."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Idle Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter by idle status
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/agents?status=idle",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert all(agent["status"] == "idle" for agent in data["data"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestGetAgent:
|
||||||
|
"""Tests for GET /api/v1/projects/{project_id}/agents/{agent_id} endpoint."""
|
||||||
|
|
||||||
|
async def test_get_agent_success(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test getting agent by ID."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Get Test Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Get agent
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == agent_id
|
||||||
|
assert data["name"] == "Get Test Agent"
|
||||||
|
|
||||||
|
async def test_get_agent_not_found(self, client, user_token, test_project):
|
||||||
|
"""Test getting a nonexistent agent."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_agent_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{fake_agent_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestUpdateAgent:
|
||||||
|
"""Tests for PATCH /api/v1/projects/{project_id}/agents/{agent_id} endpoint."""
|
||||||
|
|
||||||
|
async def test_update_agent_current_task(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test updating agent current_task."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Task Update Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Update current_task
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}",
|
||||||
|
json={"current_task": "Working on feature #123"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["current_task"] == "Working on feature #123"
|
||||||
|
|
||||||
|
async def test_update_agent_memory(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test updating agent short-term memory."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Memory Update Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Update memory
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}",
|
||||||
|
json={"short_term_memory": {"last_context": "updated", "step": 2}},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["short_term_memory"]["last_context"] == "updated"
|
||||||
|
assert data["short_term_memory"]["step"] == 2
|
||||||
|
|
||||||
|
async def test_update_agent_not_found(self, client, user_token, test_project):
|
||||||
|
"""Test updating a nonexistent agent."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_agent_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{fake_agent_id}",
|
||||||
|
json={"current_task": "Some task"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestAgentLifecycle:
|
||||||
|
"""Tests for agent lifecycle management endpoints."""
|
||||||
|
|
||||||
|
async def test_pause_agent(self, client, user_token, test_project, test_agent_type):
|
||||||
|
"""Test pausing an agent."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Pause Test Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Pause agent
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}/pause",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "paused"
|
||||||
|
|
||||||
|
async def test_resume_paused_agent(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test resuming a paused agent."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create and pause agent
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Resume Test Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Pause first
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}/pause",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resume agent
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}/resume",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "idle"
|
||||||
|
|
||||||
|
async def test_pause_nonexistent_agent(self, client, user_token, test_project):
|
||||||
|
"""Test pausing a nonexistent agent."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_agent_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{fake_agent_id}/pause",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestDeleteAgent:
|
||||||
|
"""Tests for DELETE /api/v1/projects/{project_id}/agents/{agent_id} endpoint."""
|
||||||
|
|
||||||
|
async def test_delete_agent_success(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test deleting an agent."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Delete Test Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Delete agent
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.json()["success"] is True
|
||||||
|
|
||||||
|
async def test_delete_agent_not_found(self, client, user_token, test_project):
|
||||||
|
"""Test deleting a nonexistent agent."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_agent_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{fake_agent_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestAgentMetrics:
|
||||||
|
"""Tests for agent metrics endpoints."""
|
||||||
|
|
||||||
|
async def test_get_agent_metrics(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test getting metrics for a single agent."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Metrics Test Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Get metrics
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}/metrics",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
# AgentInstanceMetrics schema
|
||||||
|
assert "total_instances" in data
|
||||||
|
assert "total_tasks_completed" in data
|
||||||
|
assert "total_tokens_used" in data
|
||||||
|
assert "total_cost_incurred" in data
|
||||||
|
|
||||||
|
async def test_get_project_agents_metrics(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test getting metrics for all agents in a project."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agents
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Metrics Agent 1",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Metrics Agent 2",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get project-wide metrics
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/metrics",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestAgentAuthorization:
|
||||||
|
"""Tests for agent authorization."""
|
||||||
|
|
||||||
|
async def test_superuser_can_manage_any_project_agents(
|
||||||
|
self, client, user_token, superuser_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test that superuser can manage agents in any project."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent as superuser in user's project
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Superuser Created Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
|
||||||
|
async def test_user_cannot_access_other_project_agents(
|
||||||
|
self, client, user_token, superuser_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test that user cannot access agents in another user's project."""
|
||||||
|
# Create a project as superuser (not owned by regular user)
|
||||||
|
project_response = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={
|
||||||
|
"name": "Other User Project",
|
||||||
|
"slug": f"other-user-project-{uuid.uuid4().hex[:8]}",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
other_project_id = project_response.json()["id"]
|
||||||
|
|
||||||
|
# Regular user tries to list agents - should fail
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{other_project_id}/agents",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestSpawnAgentEdgeCases:
|
||||||
|
"""Tests for agent spawn edge cases."""
|
||||||
|
|
||||||
|
async def test_spawn_agent_with_inactive_agent_type(
|
||||||
|
self, client, user_token, superuser_token, test_project
|
||||||
|
):
|
||||||
|
"""Test spawning agent with an inactive agent type fails."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create an inactive agent type
|
||||||
|
unique_slug = f"inactive-agent-type-{uuid.uuid4().hex[:8]}"
|
||||||
|
create_response = await client.post(
|
||||||
|
"/api/v1/agent-types",
|
||||||
|
json={
|
||||||
|
"name": "Inactive Agent Type",
|
||||||
|
"slug": unique_slug,
|
||||||
|
"expertise": ["testing"],
|
||||||
|
"primary_model": "claude-3-opus",
|
||||||
|
"personality_prompt": "Test inactive agent.",
|
||||||
|
"description": "An inactive agent type for testing",
|
||||||
|
"is_active": False,
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
assert create_response.status_code == status.HTTP_201_CREATED
|
||||||
|
inactive_type_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Try to spawn agent with inactive type
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": inactive_type_id,
|
||||||
|
"name": "Agent With Inactive Type",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
# Error response uses standardized format with "errors" list
|
||||||
|
data = response.json()
|
||||||
|
assert "errors" in data
|
||||||
|
assert any("inactive" in err["message"].lower() for err in data["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestAgentWrongProject:
|
||||||
|
"""Tests for agent operations when agent belongs to different project."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def two_projects_with_agent(
|
||||||
|
self, client, user_token, superuser_token, test_agent_type
|
||||||
|
):
|
||||||
|
"""Create two projects and an agent in project1."""
|
||||||
|
# Create project1
|
||||||
|
resp1 = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={
|
||||||
|
"name": "Project One",
|
||||||
|
"slug": f"project-one-{uuid.uuid4().hex[:8]}",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
project1 = resp1.json()
|
||||||
|
|
||||||
|
# Create project2
|
||||||
|
resp2 = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={
|
||||||
|
"name": "Project Two",
|
||||||
|
"slug": f"project-two-{uuid.uuid4().hex[:8]}",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
project2 = resp2.json()
|
||||||
|
|
||||||
|
# Create agent in project1
|
||||||
|
agent_resp = await client.post(
|
||||||
|
f"/api/v1/projects/{project1['id']}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project1["id"],
|
||||||
|
"agent_type_id": test_agent_type["id"],
|
||||||
|
"name": "Project1 Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent = agent_resp.json()
|
||||||
|
|
||||||
|
return {"project1": project1, "project2": project2, "agent": agent}
|
||||||
|
|
||||||
|
async def test_get_agent_wrong_project(
|
||||||
|
self, client, user_token, two_projects_with_agent
|
||||||
|
):
|
||||||
|
"""Test getting an agent via wrong project returns 404."""
|
||||||
|
data = two_projects_with_agent
|
||||||
|
agent_id = data["agent"]["id"]
|
||||||
|
wrong_project_id = data["project2"]["id"]
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{wrong_project_id}/agents/{agent_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
async def test_update_agent_wrong_project(
|
||||||
|
self, client, user_token, two_projects_with_agent
|
||||||
|
):
|
||||||
|
"""Test updating an agent via wrong project returns 404."""
|
||||||
|
data = two_projects_with_agent
|
||||||
|
agent_id = data["agent"]["id"]
|
||||||
|
wrong_project_id = data["project2"]["id"]
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/projects/{wrong_project_id}/agents/{agent_id}",
|
||||||
|
json={"current_task": "Test task"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
async def test_pause_agent_wrong_project(
|
||||||
|
self, client, user_token, two_projects_with_agent
|
||||||
|
):
|
||||||
|
"""Test pausing an agent via wrong project returns 404."""
|
||||||
|
data = two_projects_with_agent
|
||||||
|
agent_id = data["agent"]["id"]
|
||||||
|
wrong_project_id = data["project2"]["id"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{wrong_project_id}/agents/{agent_id}/pause",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
async def test_resume_agent_wrong_project(
|
||||||
|
self, client, user_token, two_projects_with_agent
|
||||||
|
):
|
||||||
|
"""Test resuming an agent via wrong project returns 404."""
|
||||||
|
data = two_projects_with_agent
|
||||||
|
project1_id = data["project1"]["id"]
|
||||||
|
agent_id = data["agent"]["id"]
|
||||||
|
wrong_project_id = data["project2"]["id"]
|
||||||
|
|
||||||
|
# First pause the agent using correct project
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project1_id}/agents/{agent_id}/pause",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to resume via wrong project
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{wrong_project_id}/agents/{agent_id}/resume",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
async def test_terminate_agent_wrong_project(
|
||||||
|
self, client, user_token, two_projects_with_agent
|
||||||
|
):
|
||||||
|
"""Test terminating an agent via wrong project returns 404."""
|
||||||
|
data = two_projects_with_agent
|
||||||
|
agent_id = data["agent"]["id"]
|
||||||
|
wrong_project_id = data["project2"]["id"]
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/projects/{wrong_project_id}/agents/{agent_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
async def test_get_agent_metrics_wrong_project(
|
||||||
|
self, client, user_token, two_projects_with_agent
|
||||||
|
):
|
||||||
|
"""Test getting agent metrics via wrong project returns 404."""
|
||||||
|
data = two_projects_with_agent
|
||||||
|
agent_id = data["agent"]["id"]
|
||||||
|
wrong_project_id = data["project2"]["id"]
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{wrong_project_id}/agents/{agent_id}/metrics",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestAgentStatusTransitions:
|
||||||
|
"""Tests for invalid agent status transitions."""
|
||||||
|
|
||||||
|
async def test_terminate_already_terminated_agent(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test terminating an already terminated agent fails."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Double Terminate Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Terminate once
|
||||||
|
first_terminate = await client.delete(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
assert first_terminate.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
# Try to terminate again
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
data = response.json()
|
||||||
|
assert "errors" in data
|
||||||
|
assert any("terminated" in err["message"].lower() for err in data["errors"])
|
||||||
|
|
||||||
|
async def test_resume_idle_agent(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test resuming an agent that is not paused fails."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent (starts in idle state)
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Resume Idle Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Try to resume without pausing first
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}/resume",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail since agent is not paused
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
async def test_pause_already_paused_agent(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test pausing an already paused agent fails."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Double Pause Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Pause once
|
||||||
|
first_pause = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}/pause",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
assert first_pause.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
# Try to pause again
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}/pause",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
async def test_pause_terminated_agent(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test pausing a terminated agent fails."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Pause Terminated Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Terminate agent
|
||||||
|
await client.delete(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to pause terminated agent
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}/pause",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
async def test_resume_terminated_agent(
|
||||||
|
self, client, user_token, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test resuming a terminated agent fails."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
agent_type_id = test_agent_type["id"]
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_type_id": agent_type_id,
|
||||||
|
"name": "Resume Terminated Agent",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
agent_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Terminate agent
|
||||||
|
await client.delete(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to resume terminated agent
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/agents/{agent_id}/resume",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
1095
backend/tests/api/routes/syndarix/test_edge_cases.py
Normal file
1095
backend/tests/api/routes/syndarix/test_edge_cases.py
Normal file
File diff suppressed because it is too large
Load Diff
995
backend/tests/api/routes/syndarix/test_issues.py
Normal file
995
backend/tests/api/routes/syndarix/test_issues.py
Normal file
@@ -0,0 +1,995 @@
|
|||||||
|
# tests/api/routes/syndarix/test_issues.py
|
||||||
|
"""
|
||||||
|
Comprehensive tests for the Issues API endpoints.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- CRUD operations (create, read, update, delete)
|
||||||
|
- Issue filtering and search
|
||||||
|
- Issue assignment
|
||||||
|
- Issue statistics
|
||||||
|
- Authorization checks
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from fastapi import status
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_project(client, user_token):
|
||||||
|
"""Create a test project for issue tests."""
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={"name": "Issue Test Project", "slug": "issue-test-project"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def superuser_project(client, superuser_token):
|
||||||
|
"""Create a project owned by superuser."""
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={"name": "Superuser Project", "slug": "superuser-project"},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestCreateIssue:
|
||||||
|
"""Tests for POST /api/v1/projects/{project_id}/issues endpoint."""
|
||||||
|
|
||||||
|
async def test_create_issue_success(self, client, user_token, test_project):
|
||||||
|
"""Test successful issue creation."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Test Issue",
|
||||||
|
"body": "This is a test issue description",
|
||||||
|
"priority": "medium",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert data["title"] == "Test Issue"
|
||||||
|
assert data["body"] == "This is a test issue description"
|
||||||
|
assert data["priority"] == "medium"
|
||||||
|
assert data["status"] == "open"
|
||||||
|
assert data["project_id"] == project_id
|
||||||
|
assert "id" in data
|
||||||
|
assert "created_at" in data
|
||||||
|
|
||||||
|
async def test_create_issue_minimal_fields(self, client, user_token, test_project):
|
||||||
|
"""Test creating issue with only required fields."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Minimal Issue",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
data = response.json()
|
||||||
|
assert data["title"] == "Minimal Issue"
|
||||||
|
assert data["body"] == "" # Body defaults to empty string
|
||||||
|
assert data["status"] == "open"
|
||||||
|
|
||||||
|
async def test_create_issue_with_labels(self, client, user_token, test_project):
|
||||||
|
"""Test creating issue with labels."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Labeled Issue",
|
||||||
|
"labels": ["bug", "urgent", "frontend"],
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
data = response.json()
|
||||||
|
assert "bug" in data["labels"]
|
||||||
|
assert "urgent" in data["labels"]
|
||||||
|
assert "frontend" in data["labels"]
|
||||||
|
|
||||||
|
async def test_create_issue_with_story_points(
|
||||||
|
self, client, user_token, test_project
|
||||||
|
):
|
||||||
|
"""Test creating issue with story points."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Story Points Issue",
|
||||||
|
"story_points": 5,
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
data = response.json()
|
||||||
|
assert data["story_points"] == 5
|
||||||
|
|
||||||
|
async def test_create_issue_unauthorized_project(
|
||||||
|
self, client, user_token, superuser_project
|
||||||
|
):
|
||||||
|
"""Test that users cannot create issues in others' projects."""
|
||||||
|
project_id = superuser_project["id"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Unauthorized Issue",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
async def test_create_issue_nonexistent_project(self, client, user_token):
|
||||||
|
"""Test creating issue in nonexistent project."""
|
||||||
|
fake_project_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{fake_project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": fake_project_id,
|
||||||
|
"title": "Orphan Issue",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestListIssues:
|
||||||
|
"""Tests for GET /api/v1/projects/{project_id}/issues endpoint."""
|
||||||
|
|
||||||
|
async def test_list_issues_empty(self, client, user_token, test_project):
|
||||||
|
"""Test listing issues when none exist."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["data"] == []
|
||||||
|
assert data["pagination"]["total"] == 0
|
||||||
|
|
||||||
|
async def test_list_issues_with_data(self, client, user_token, test_project):
|
||||||
|
"""Test listing issues returns created issues."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create multiple issues
|
||||||
|
for i in range(3):
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": f"Issue {i + 1}",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["data"]) == 3
|
||||||
|
assert data["pagination"]["total"] == 3
|
||||||
|
|
||||||
|
async def test_list_issues_filter_by_status(self, client, user_token, test_project):
|
||||||
|
"""Test filtering issues by status."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issues with different statuses
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Open Issue",
|
||||||
|
"status": "open",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Closed Issue",
|
||||||
|
"status": "closed",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter by open
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/issues?status=open",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["data"]) == 1
|
||||||
|
assert data["data"][0]["status"] == "open"
|
||||||
|
|
||||||
|
async def test_list_issues_filter_by_priority(
|
||||||
|
self, client, user_token, test_project
|
||||||
|
):
|
||||||
|
"""Test filtering issues by priority."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issues with different priorities
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "High Priority Issue",
|
||||||
|
"priority": "high",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Low Priority Issue",
|
||||||
|
"priority": "low",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter by high priority
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/issues?priority=high",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["data"]) == 1
|
||||||
|
assert data["data"][0]["priority"] == "high"
|
||||||
|
|
||||||
|
async def test_list_issues_search(self, client, user_token, test_project):
|
||||||
|
"""Test searching issues by title/body."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Authentication Bug",
|
||||||
|
"body": "Users cannot login",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "UI Enhancement",
|
||||||
|
"body": "Improve dashboard layout",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search for authentication
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/issues?search=authentication",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["data"]) == 1
|
||||||
|
assert "Authentication" in data["data"][0]["title"]
|
||||||
|
|
||||||
|
async def test_list_issues_pagination(self, client, user_token, test_project):
|
||||||
|
"""Test pagination works correctly."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create 5 issues
|
||||||
|
for i in range(5):
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": f"Issue {i + 1}",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get first page (2 items)
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/issues?page=1&limit=2",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["data"]) == 2
|
||||||
|
assert data["pagination"]["total"] == 5
|
||||||
|
assert data["pagination"]["page"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestGetIssue:
|
||||||
|
"""Tests for GET /api/v1/projects/{project_id}/issues/{issue_id} endpoint."""
|
||||||
|
|
||||||
|
async def test_get_issue_success(self, client, user_token, test_project):
|
||||||
|
"""Test getting an issue by ID."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issue
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Get Test Issue",
|
||||||
|
"body": "Test description",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Get issue
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{issue_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == issue_id
|
||||||
|
assert data["title"] == "Get Test Issue"
|
||||||
|
|
||||||
|
async def test_get_issue_not_found(self, client, user_token, test_project):
|
||||||
|
"""Test getting a nonexistent issue."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_issue_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{fake_issue_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestUpdateIssue:
|
||||||
|
"""Tests for PATCH /api/v1/projects/{project_id}/issues/{issue_id} endpoint."""
|
||||||
|
|
||||||
|
async def test_update_issue_success(self, client, user_token, test_project):
|
||||||
|
"""Test updating an issue."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issue
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Original Title",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Update issue
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{issue_id}",
|
||||||
|
json={"title": "Updated Title", "body": "New description"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["title"] == "Updated Title"
|
||||||
|
assert data["body"] == "New description"
|
||||||
|
|
||||||
|
async def test_update_issue_status(self, client, user_token, test_project):
|
||||||
|
"""Test updating issue status."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issue
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Status Test Issue",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Update status to in_progress
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{issue_id}",
|
||||||
|
json={"status": "in_progress"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.json()["status"] == "in_progress"
|
||||||
|
|
||||||
|
async def test_update_issue_priority(self, client, user_token, test_project):
|
||||||
|
"""Test updating issue priority."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issue with low priority
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Priority Test Issue",
|
||||||
|
"priority": "low",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Update to critical
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{issue_id}",
|
||||||
|
json={"priority": "critical"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.json()["priority"] == "critical"
|
||||||
|
|
||||||
|
async def test_update_issue_not_found(self, client, user_token, test_project):
|
||||||
|
"""Test updating a nonexistent issue."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_issue_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{fake_issue_id}",
|
||||||
|
json={"title": "Updated Title"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestDeleteIssue:
|
||||||
|
"""Tests for DELETE /api/v1/projects/{project_id}/issues/{issue_id} endpoint."""
|
||||||
|
|
||||||
|
async def test_delete_issue_success(self, client, user_token, test_project):
|
||||||
|
"""Test deleting an issue."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issue
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Delete Test Issue",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Delete issue
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{issue_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.json()["success"] is True
|
||||||
|
|
||||||
|
async def test_delete_issue_not_found(self, client, user_token, test_project):
|
||||||
|
"""Test deleting a nonexistent issue."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_issue_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{fake_issue_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestIssueStats:
|
||||||
|
"""Tests for GET /api/v1/projects/{project_id}/issues/stats endpoint."""
|
||||||
|
|
||||||
|
async def test_get_issue_stats_empty(self, client, user_token, test_project):
|
||||||
|
"""Test getting stats when no issues exist."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/stats",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 0
|
||||||
|
assert data["open"] == 0
|
||||||
|
assert data["in_progress"] == 0
|
||||||
|
assert data["in_review"] == 0
|
||||||
|
assert data["blocked"] == 0
|
||||||
|
assert data["closed"] == 0
|
||||||
|
|
||||||
|
async def test_get_issue_stats_with_data(self, client, user_token, test_project):
|
||||||
|
"""Test getting stats with issues."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issues with different statuses and priorities
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Open High Issue",
|
||||||
|
"status": "open",
|
||||||
|
"priority": "high",
|
||||||
|
"story_points": 5,
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Closed Low Issue",
|
||||||
|
"status": "closed",
|
||||||
|
"priority": "low",
|
||||||
|
"story_points": 3,
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/stats",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 2
|
||||||
|
assert data["open"] == 1
|
||||||
|
assert data["closed"] == 1
|
||||||
|
assert data["by_priority"]["high"] == 1
|
||||||
|
assert data["by_priority"]["low"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestIssueAuthorization:
|
||||||
|
"""Tests for issue authorization."""
|
||||||
|
|
||||||
|
async def test_superuser_can_manage_any_project_issues(
|
||||||
|
self, client, user_token, superuser_token, test_project
|
||||||
|
):
|
||||||
|
"""Test that superuser can manage issues in any project."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issue as superuser in user's project
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Superuser Issue",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
|
||||||
|
async def test_user_cannot_access_other_project_issues(
|
||||||
|
self, client, user_token, superuser_project
|
||||||
|
):
|
||||||
|
"""Test that users cannot access issues in others' projects."""
|
||||||
|
project_id = superuser_project["id"]
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestIssueAssignment:
|
||||||
|
"""Tests for issue assignment endpoints."""
|
||||||
|
|
||||||
|
async def test_assign_issue_to_human(self, client, user_token, test_project):
|
||||||
|
"""Test assigning an issue to a human."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issue
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Issue to Assign",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Assign to human
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{issue_id}/assign",
|
||||||
|
json={"human_assignee": "john.doe@example.com"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["human_assignee"] == "john.doe@example.com"
|
||||||
|
|
||||||
|
async def test_unassign_issue(self, client, user_token, test_project):
|
||||||
|
"""Test unassigning an issue."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issue and assign
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Issue to Unassign",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{issue_id}/assign",
|
||||||
|
json={"human_assignee": "john.doe@example.com"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unassign
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{issue_id}/assignment",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
# After unassign, assigned_agent_id should be None
|
||||||
|
# Note: human_assignee may or may not be cleared depending on implementation
|
||||||
|
assert data["assigned_agent_id"] is None
|
||||||
|
|
||||||
|
async def test_assign_issue_not_found(self, client, user_token, test_project):
|
||||||
|
"""Test assigning a nonexistent issue."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_issue_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{fake_issue_id}/assign",
|
||||||
|
json={"human_assignee": "john.doe@example.com"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
async def test_unassign_issue_not_found(self, client, user_token, test_project):
|
||||||
|
"""Test unassigning a nonexistent issue."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_issue_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{fake_issue_id}/assignment",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
async def test_assign_issue_clears_assignment(
|
||||||
|
self, client, user_token, test_project
|
||||||
|
):
|
||||||
|
"""Test that assigning to null clears both assignments."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issue and assign
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Issue to Clear",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{issue_id}/assign",
|
||||||
|
json={"human_assignee": "john.doe@example.com"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clear assignment by sending empty object
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{issue_id}/assign",
|
||||||
|
json={},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestIssueSync:
|
||||||
|
"""Tests for issue sync endpoint."""
|
||||||
|
|
||||||
|
async def test_sync_issue_no_tracker(self, client, user_token, test_project):
|
||||||
|
"""Test syncing an issue without external tracker."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issue without external tracker
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Issue without Tracker",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Try to sync
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{issue_id}/sync",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fail because no external tracker configured
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
async def test_sync_issue_not_found(self, client, user_token, test_project):
|
||||||
|
"""Test syncing a nonexistent issue."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_issue_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{fake_issue_id}/sync",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestIssueCrossProjectValidation:
|
||||||
|
"""Tests for cross-project validation (IDOR prevention)."""
|
||||||
|
|
||||||
|
async def test_issue_not_in_project(self, client, user_token):
|
||||||
|
"""Test accessing issue that exists but not in the specified project."""
|
||||||
|
# Create two projects
|
||||||
|
project1 = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={"name": "Project 1", "slug": "project-1-idor"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
project2 = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={"name": "Project 2", "slug": "project-2-idor"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
project1_id = project1.json()["id"]
|
||||||
|
project2_id = project2.json()["id"]
|
||||||
|
|
||||||
|
# Create issue in project1
|
||||||
|
issue_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project1_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project1_id,
|
||||||
|
"title": "Project 1 Issue",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = issue_response.json()["id"]
|
||||||
|
|
||||||
|
# Try to access issue via project2 (IDOR attempt)
|
||||||
|
response = await client.get(
|
||||||
|
f"/api/v1/projects/{project2_id}/issues/{issue_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
async def test_update_issue_wrong_project(self, client, user_token):
|
||||||
|
"""Test updating issue through wrong project."""
|
||||||
|
# Create two projects
|
||||||
|
project1 = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={"name": "Project A", "slug": "project-a-idor"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
project2 = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={"name": "Project B", "slug": "project-b-idor"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
project1_id = project1.json()["id"]
|
||||||
|
project2_id = project2.json()["id"]
|
||||||
|
|
||||||
|
# Create issue in project1
|
||||||
|
issue_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project1_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project1_id,
|
||||||
|
"title": "Project A Issue",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = issue_response.json()["id"]
|
||||||
|
|
||||||
|
# Try to update issue via project2 (IDOR attempt)
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/projects/{project2_id}/issues/{issue_id}",
|
||||||
|
json={"title": "Hacked Title"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
async def test_delete_issue_wrong_project(self, client, user_token):
|
||||||
|
"""Test deleting issue through wrong project."""
|
||||||
|
# Create two projects
|
||||||
|
project1 = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={"name": "Project X", "slug": "project-x-idor"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
project2 = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={"name": "Project Y", "slug": "project-y-idor"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
project1_id = project1.json()["id"]
|
||||||
|
project2_id = project2.json()["id"]
|
||||||
|
|
||||||
|
# Create issue in project1
|
||||||
|
issue_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project1_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project1_id,
|
||||||
|
"title": "Project X Issue",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = issue_response.json()["id"]
|
||||||
|
|
||||||
|
# Try to delete issue via project2 (IDOR attempt)
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/projects/{project2_id}/issues/{issue_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestIssueValidation:
|
||||||
|
"""Tests for issue validation during create/update."""
|
||||||
|
|
||||||
|
async def test_create_issue_invalid_priority(
|
||||||
|
self, client, user_token, test_project
|
||||||
|
):
|
||||||
|
"""Test creating issue with invalid priority."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Issue with Invalid Priority",
|
||||||
|
"priority": "invalid_priority",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
async def test_create_issue_invalid_status(self, client, user_token, test_project):
|
||||||
|
"""Test creating issue with invalid status."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Issue with Invalid Status",
|
||||||
|
"status": "invalid_status",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
async def test_update_issue_invalid_priority(
|
||||||
|
self, client, user_token, test_project
|
||||||
|
):
|
||||||
|
"""Test updating issue with invalid priority."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
|
||||||
|
# Create issue
|
||||||
|
create_response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Issue to Update",
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
issue_id = create_response.json()["id"]
|
||||||
|
|
||||||
|
# Update with invalid priority
|
||||||
|
response = await client.patch(
|
||||||
|
f"/api/v1/projects/{project_id}/issues/{issue_id}",
|
||||||
|
json={"priority": "invalid_priority"},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
async def test_create_issue_with_nonexistent_sprint(
|
||||||
|
self, client, user_token, test_project
|
||||||
|
):
|
||||||
|
"""Test creating issue with nonexistent sprint ID."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_sprint_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Issue with Fake Sprint",
|
||||||
|
"sprint_id": fake_sprint_id,
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
async def test_create_issue_with_nonexistent_agent(
|
||||||
|
self, client, user_token, test_project
|
||||||
|
):
|
||||||
|
"""Test creating issue with nonexistent agent ID."""
|
||||||
|
project_id = test_project["id"]
|
||||||
|
fake_agent_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
f"/api/v1/projects/{project_id}/issues",
|
||||||
|
json={
|
||||||
|
"project_id": project_id,
|
||||||
|
"title": "Issue with Fake Agent",
|
||||||
|
"assigned_agent_id": fake_agent_id,
|
||||||
|
},
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
1048
backend/tests/api/routes/syndarix/test_projects.py
Normal file
1048
backend/tests/api/routes/syndarix/test_projects.py
Normal file
File diff suppressed because it is too large
Load Diff
1541
backend/tests/api/routes/syndarix/test_sprints.py
Normal file
1541
backend/tests/api/routes/syndarix/test_sprints.py
Normal file
File diff suppressed because it is too large
Load Diff
583
backend/tests/api/routes/test_events.py
Normal file
583
backend/tests/api/routes/test_events.py
Normal file
@@ -0,0 +1,583 @@
|
|||||||
|
"""
|
||||||
|
Tests for the SSE events endpoint.
|
||||||
|
|
||||||
|
This module tests the Server-Sent Events endpoint for project event streaming,
|
||||||
|
including:
|
||||||
|
- Authentication and authorization
|
||||||
|
- SSE stream connection and format
|
||||||
|
- Keepalive mechanism
|
||||||
|
- Reconnection support (Last-Event-ID)
|
||||||
|
- Connection cleanup
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from fastapi import status
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from app.api.dependencies.event_bus import get_event_bus
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.crud.syndarix.project import project as project_crud
|
||||||
|
from app.main import app
|
||||||
|
from app.schemas.events import Event, EventType
|
||||||
|
from app.schemas.syndarix.project import ProjectCreate
|
||||||
|
from app.services.event_bus import EventBus
|
||||||
|
|
||||||
|
|
||||||
|
class MockEventBus:
|
||||||
|
"""Mock EventBus for testing without Redis."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.published_events: list[Event] = []
|
||||||
|
self._should_yield_events = True
|
||||||
|
self._events_to_yield: list[str] = []
|
||||||
|
self._connected = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self._connected
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
self._connected = True
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
|
def get_project_channel(self, project_id: uuid.UUID | str) -> str:
|
||||||
|
"""Get the channel name for a project."""
|
||||||
|
return f"project:{project_id}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_event(
|
||||||
|
event_type: EventType,
|
||||||
|
project_id: uuid.UUID,
|
||||||
|
actor_type: str,
|
||||||
|
payload: dict | None = None,
|
||||||
|
actor_id: uuid.UUID | None = None,
|
||||||
|
event_id: str | None = None,
|
||||||
|
timestamp: datetime | None = None,
|
||||||
|
) -> Event:
|
||||||
|
"""Create a new Event."""
|
||||||
|
return Event(
|
||||||
|
id=event_id or str(uuid.uuid4()),
|
||||||
|
type=event_type,
|
||||||
|
timestamp=timestamp or datetime.now(UTC),
|
||||||
|
project_id=project_id,
|
||||||
|
actor_id=actor_id,
|
||||||
|
actor_type=actor_type,
|
||||||
|
payload=payload or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def publish(self, channel: str, event: Event) -> int:
|
||||||
|
"""Publish an event to a channel."""
|
||||||
|
self.published_events.append(event)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def add_event_to_yield(self, event_json: str) -> None:
|
||||||
|
"""Add an event JSON string to be yielded by subscribe_sse."""
|
||||||
|
self._events_to_yield.append(event_json)
|
||||||
|
|
||||||
|
async def subscribe_sse(
|
||||||
|
self,
|
||||||
|
project_id: str | uuid.UUID,
|
||||||
|
last_event_id: str | None = None,
|
||||||
|
keepalive_interval: int = 30,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Mock subscribe_sse that yields pre-configured events then keepalive."""
|
||||||
|
# First yield any pre-configured events
|
||||||
|
for event_data in self._events_to_yield:
|
||||||
|
yield event_data
|
||||||
|
|
||||||
|
# Then yield keepalive
|
||||||
|
yield ""
|
||||||
|
|
||||||
|
# Then stop to allow test to complete
|
||||||
|
self._should_yield_events = False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def mock_event_bus():
|
||||||
|
"""Create a mock event bus for testing."""
|
||||||
|
return MockEventBus()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def client_with_mock_bus(async_test_db, mock_event_bus):
|
||||||
|
"""
|
||||||
|
Create a FastAPI test client with mocked database and event bus.
|
||||||
|
"""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async def override_get_db():
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def override_get_event_bus():
|
||||||
|
return mock_event_bus
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
app.dependency_overrides[get_event_bus] = override_get_event_bus
|
||||||
|
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as test_client:
|
||||||
|
yield test_client
|
||||||
|
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def user_token_with_mock_bus(client_with_mock_bus, async_test_user):
|
||||||
|
"""Create an access token for the test user."""
|
||||||
|
response = await client_with_mock_bus.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={
|
||||||
|
"email": async_test_user.email,
|
||||||
|
"password": "TestPassword123!",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200, f"Login failed: {response.text}"
|
||||||
|
tokens = response.json()
|
||||||
|
return tokens["access_token"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_project_for_events(async_test_db, async_test_user):
|
||||||
|
"""Create a test project owned by the test user for events testing."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
project_in = ProjectCreate(
|
||||||
|
name="Test Events Project",
|
||||||
|
slug="test-events-project",
|
||||||
|
owner_id=async_test_user.id,
|
||||||
|
)
|
||||||
|
project = await project_crud.create(session, obj_in=project_in)
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEEndpointAuthentication:
|
||||||
|
"""Tests for SSE endpoint authentication."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_events_requires_authentication(self, client_with_mock_bus):
|
||||||
|
"""Test that SSE endpoint requires authentication."""
|
||||||
|
project_id = uuid.uuid4()
|
||||||
|
|
||||||
|
response = await client_with_mock_bus.get(
|
||||||
|
f"/api/v1/projects/{project_id}/events/stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_events_with_invalid_token(self, client_with_mock_bus):
|
||||||
|
"""Test that SSE endpoint rejects invalid tokens."""
|
||||||
|
project_id = uuid.uuid4()
|
||||||
|
|
||||||
|
response = await client_with_mock_bus.get(
|
||||||
|
f"/api/v1/projects/{project_id}/events/stream",
|
||||||
|
headers={"Authorization": "Bearer invalid_token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEEndpointAuthorization:
|
||||||
|
"""Tests for SSE endpoint authorization."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_events_nonexistent_project_returns_403(
|
||||||
|
self, client_with_mock_bus, user_token_with_mock_bus
|
||||||
|
):
|
||||||
|
"""Test that accessing a non-existent project returns 403."""
|
||||||
|
nonexistent_project_id = uuid.uuid4()
|
||||||
|
|
||||||
|
response = await client_with_mock_bus.get(
|
||||||
|
f"/api/v1/projects/{nonexistent_project_id}/events/stream",
|
||||||
|
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return 403 because project doesn't exist (auth check fails)
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_events_other_users_project_returns_403(
|
||||||
|
self, client_with_mock_bus, user_token_with_mock_bus, async_test_db
|
||||||
|
):
|
||||||
|
"""Test that accessing another user's project returns 403."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create a project owned by a different user
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
other_user_id = uuid.uuid4() # Simulated other user
|
||||||
|
project_in = ProjectCreate(
|
||||||
|
name="Other User's Project",
|
||||||
|
slug="other-users-project",
|
||||||
|
owner_id=other_user_id,
|
||||||
|
)
|
||||||
|
other_project = await project_crud.create(session, obj_in=project_in)
|
||||||
|
|
||||||
|
response = await client_with_mock_bus.get(
|
||||||
|
f"/api/v1/projects/{other_project.id}/events/stream",
|
||||||
|
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return 403 because user doesn't own the project
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_test_event_nonexistent_project_returns_403(
|
||||||
|
self, client_with_mock_bus, user_token_with_mock_bus
|
||||||
|
):
|
||||||
|
"""Test that sending event to non-existent project returns 403."""
|
||||||
|
nonexistent_project_id = uuid.uuid4()
|
||||||
|
|
||||||
|
response = await client_with_mock_bus.post(
|
||||||
|
f"/api/v1/projects/{nonexistent_project_id}/events/test",
|
||||||
|
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEEndpointStream:
|
||||||
|
"""Tests for SSE stream functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_events_returns_sse_response(
|
||||||
|
self, client_with_mock_bus, user_token_with_mock_bus, test_project_for_events
|
||||||
|
):
|
||||||
|
"""Test that SSE endpoint returns proper SSE response."""
|
||||||
|
project_id = test_project_for_events.id
|
||||||
|
|
||||||
|
# Make request with a timeout to avoid hanging
|
||||||
|
response = await client_with_mock_bus.get(
|
||||||
|
f"/api/v1/projects/{project_id}/events/stream",
|
||||||
|
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The response should start streaming
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert "text/event-stream" in response.headers.get("content-type", "")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_events_with_events(
|
||||||
|
self,
|
||||||
|
client_with_mock_bus,
|
||||||
|
user_token_with_mock_bus,
|
||||||
|
mock_event_bus,
|
||||||
|
test_project_for_events,
|
||||||
|
):
|
||||||
|
"""Test that SSE endpoint yields events."""
|
||||||
|
project_id = test_project_for_events.id
|
||||||
|
|
||||||
|
# Create a test event and add it to the mock bus
|
||||||
|
test_event = Event(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
type=EventType.AGENT_MESSAGE,
|
||||||
|
timestamp=datetime.now(UTC),
|
||||||
|
project_id=project_id,
|
||||||
|
actor_type="agent",
|
||||||
|
payload={"message": "test"},
|
||||||
|
)
|
||||||
|
mock_event_bus.add_event_to_yield(test_event.model_dump_json())
|
||||||
|
|
||||||
|
# Request the stream
|
||||||
|
response = await client_with_mock_bus.get(
|
||||||
|
f"/api/v1/projects/{project_id}/events/stream",
|
||||||
|
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
# Check response contains event data
|
||||||
|
content = response.text
|
||||||
|
assert "agent.message" in content or "data:" in content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_events_with_last_event_id(
|
||||||
|
self, client_with_mock_bus, user_token_with_mock_bus, test_project_for_events
|
||||||
|
):
|
||||||
|
"""Test that Last-Event-ID header is accepted."""
|
||||||
|
project_id = test_project_for_events.id
|
||||||
|
last_event_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
response = await client_with_mock_bus.get(
|
||||||
|
f"/api/v1/projects/{project_id}/events/stream",
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {user_token_with_mock_bus}",
|
||||||
|
"Last-Event-ID": last_event_id,
|
||||||
|
},
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should accept the header and return OK
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEEndpointHeaders:
|
||||||
|
"""Tests for SSE response headers."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_events_cache_control_header(
|
||||||
|
self, client_with_mock_bus, user_token_with_mock_bus, test_project_for_events
|
||||||
|
):
|
||||||
|
"""Test that SSE response has no-cache header."""
|
||||||
|
project_id = test_project_for_events.id
|
||||||
|
|
||||||
|
response = await client_with_mock_bus.get(
|
||||||
|
f"/api/v1/projects/{project_id}/events/stream",
|
||||||
|
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
cache_control = response.headers.get("cache-control", "")
|
||||||
|
assert "no-cache" in cache_control.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestTestEventEndpoint:
|
||||||
|
"""Tests for the test event endpoint."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_test_event_requires_auth(self, client_with_mock_bus):
|
||||||
|
"""Test that test event endpoint requires authentication."""
|
||||||
|
project_id = uuid.uuid4()
|
||||||
|
|
||||||
|
response = await client_with_mock_bus.post(
|
||||||
|
f"/api/v1/projects/{project_id}/events/test",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_test_event_success(
|
||||||
|
self,
|
||||||
|
client_with_mock_bus,
|
||||||
|
user_token_with_mock_bus,
|
||||||
|
mock_event_bus,
|
||||||
|
test_project_for_events,
|
||||||
|
):
|
||||||
|
"""Test sending a test event."""
|
||||||
|
project_id = test_project_for_events.id
|
||||||
|
|
||||||
|
response = await client_with_mock_bus.post(
|
||||||
|
f"/api/v1/projects/{project_id}/events/test",
|
||||||
|
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert "event_id" in data
|
||||||
|
assert data["event_type"] == "agent.message"
|
||||||
|
|
||||||
|
# Verify event was published
|
||||||
|
assert len(mock_event_bus.published_events) == 1
|
||||||
|
published = mock_event_bus.published_events[0]
|
||||||
|
assert published.type == EventType.AGENT_MESSAGE
|
||||||
|
assert published.project_id == project_id
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventSchema:
|
||||||
|
"""Tests for the Event schema."""
|
||||||
|
|
||||||
|
def test_event_creation(self):
|
||||||
|
"""Test Event creation with required fields."""
|
||||||
|
project_id = uuid.uuid4()
|
||||||
|
event = Event(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
type=EventType.AGENT_MESSAGE,
|
||||||
|
timestamp=datetime.now(UTC),
|
||||||
|
project_id=project_id,
|
||||||
|
actor_type="agent",
|
||||||
|
payload={"message": "test"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.id is not None
|
||||||
|
assert event.type == EventType.AGENT_MESSAGE
|
||||||
|
assert event.project_id == project_id
|
||||||
|
assert event.actor_type == "agent"
|
||||||
|
assert event.payload == {"message": "test"}
|
||||||
|
|
||||||
|
def test_event_json_serialization(self):
|
||||||
|
"""Test Event JSON serialization."""
|
||||||
|
project_id = uuid.uuid4()
|
||||||
|
event = Event(
|
||||||
|
id="test-id",
|
||||||
|
type=EventType.AGENT_STATUS_CHANGED,
|
||||||
|
timestamp=datetime.now(UTC),
|
||||||
|
project_id=project_id,
|
||||||
|
actor_type="system",
|
||||||
|
payload={"status": "running"},
|
||||||
|
)
|
||||||
|
|
||||||
|
json_str = event.model_dump_json()
|
||||||
|
parsed = json.loads(json_str)
|
||||||
|
|
||||||
|
assert parsed["id"] == "test-id"
|
||||||
|
assert parsed["type"] == "agent.status_changed"
|
||||||
|
assert str(parsed["project_id"]) == str(project_id)
|
||||||
|
assert parsed["payload"]["status"] == "running"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventBusUnit:
|
||||||
|
"""Unit tests for EventBus class."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_bus_not_connected_raises(self):
|
||||||
|
"""Test that accessing redis_client before connect raises."""
|
||||||
|
from app.services.event_bus import EventBusConnectionError
|
||||||
|
|
||||||
|
bus = EventBus()
|
||||||
|
|
||||||
|
with pytest.raises(EventBusConnectionError, match="not connected"):
|
||||||
|
_ = bus.redis_client
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_bus_channel_names(self):
|
||||||
|
"""Test channel name generation."""
|
||||||
|
bus = EventBus()
|
||||||
|
project_id = uuid.uuid4()
|
||||||
|
agent_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
|
assert bus.get_project_channel(project_id) == f"project:{project_id}"
|
||||||
|
assert bus.get_agent_channel(agent_id) == f"agent:{agent_id}"
|
||||||
|
assert bus.get_user_channel(user_id) == f"user:{user_id}"
|
||||||
|
|
||||||
|
def test_event_bus_create_event(self):
|
||||||
|
"""Test EventBus.create_event factory method."""
|
||||||
|
project_id = uuid.uuid4()
|
||||||
|
actor_id = uuid.uuid4()
|
||||||
|
|
||||||
|
event = EventBus.create_event(
|
||||||
|
event_type=EventType.ISSUE_CREATED,
|
||||||
|
project_id=project_id,
|
||||||
|
actor_type="user",
|
||||||
|
actor_id=actor_id,
|
||||||
|
payload={"title": "Test Issue"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.type == EventType.ISSUE_CREATED
|
||||||
|
assert event.project_id == project_id
|
||||||
|
assert event.actor_id == actor_id
|
||||||
|
assert event.actor_type == "user"
|
||||||
|
assert event.payload == {"title": "Test Issue"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventBusIntegration:
|
||||||
|
"""Integration tests for EventBus with mocked Redis."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_bus_connect_disconnect(self):
|
||||||
|
"""Test EventBus connect and disconnect."""
|
||||||
|
with patch("app.services.event_bus.redis.from_url") as mock_redis:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_redis.return_value = mock_client
|
||||||
|
mock_client.ping = AsyncMock()
|
||||||
|
mock_client.pubsub = lambda: AsyncMock()
|
||||||
|
|
||||||
|
bus = EventBus(redis_url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
# Connect
|
||||||
|
await bus.connect()
|
||||||
|
mock_client.ping.assert_called_once()
|
||||||
|
assert bus._redis_client is not None
|
||||||
|
assert bus.is_connected
|
||||||
|
|
||||||
|
# Disconnect
|
||||||
|
await bus.disconnect()
|
||||||
|
mock_client.aclose.assert_called_once()
|
||||||
|
assert bus._redis_client is None
|
||||||
|
assert not bus.is_connected
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_bus_publish(self):
|
||||||
|
"""Test EventBus event publishing."""
|
||||||
|
with patch("app.services.event_bus.redis.from_url") as mock_redis:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_redis.return_value = mock_client
|
||||||
|
mock_client.ping = AsyncMock()
|
||||||
|
mock_client.publish = AsyncMock(return_value=1)
|
||||||
|
mock_client.pubsub = lambda: AsyncMock()
|
||||||
|
|
||||||
|
bus = EventBus()
|
||||||
|
await bus.connect()
|
||||||
|
|
||||||
|
project_id = uuid.uuid4()
|
||||||
|
event = EventBus.create_event(
|
||||||
|
event_type=EventType.AGENT_SPAWNED,
|
||||||
|
project_id=project_id,
|
||||||
|
actor_type="system",
|
||||||
|
payload={"agent_name": "test-agent"},
|
||||||
|
)
|
||||||
|
|
||||||
|
channel = bus.get_project_channel(project_id)
|
||||||
|
result = await bus.publish(channel, event)
|
||||||
|
|
||||||
|
# Verify publish was called
|
||||||
|
mock_client.publish.assert_called_once()
|
||||||
|
call_args = mock_client.publish.call_args
|
||||||
|
|
||||||
|
# Check channel name
|
||||||
|
assert call_args[0][0] == f"project:{project_id}"
|
||||||
|
|
||||||
|
# Check result
|
||||||
|
assert result == 1
|
||||||
|
|
||||||
|
await bus.disconnect()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_bus_connect_failure(self):
|
||||||
|
"""Test EventBus handles connection failure."""
|
||||||
|
from app.services.event_bus import EventBusConnectionError
|
||||||
|
|
||||||
|
with patch("app.services.event_bus.redis.from_url") as mock_redis:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_redis.return_value = mock_client
|
||||||
|
|
||||||
|
import redis.asyncio as redis_async
|
||||||
|
|
||||||
|
mock_client.ping = AsyncMock(
|
||||||
|
side_effect=redis_async.ConnectionError("Connection refused")
|
||||||
|
)
|
||||||
|
|
||||||
|
bus = EventBus()
|
||||||
|
|
||||||
|
with pytest.raises(EventBusConnectionError, match="Failed to connect"):
|
||||||
|
await bus.connect()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_bus_already_connected(self):
|
||||||
|
"""Test EventBus connect when already connected is a no-op."""
|
||||||
|
with patch("app.services.event_bus.redis.from_url") as mock_redis:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_redis.return_value = mock_client
|
||||||
|
mock_client.ping = AsyncMock()
|
||||||
|
mock_client.pubsub = lambda: AsyncMock()
|
||||||
|
|
||||||
|
bus = EventBus()
|
||||||
|
|
||||||
|
# First connect
|
||||||
|
await bus.connect()
|
||||||
|
assert mock_client.ping.call_count == 1
|
||||||
|
|
||||||
|
# Second connect should be a no-op
|
||||||
|
await bus.connect()
|
||||||
|
assert mock_client.ping.call_count == 1
|
||||||
|
|
||||||
|
await bus.disconnect()
|
||||||
491
backend/tests/api/routes/test_mcp.py
Normal file
491
backend/tests/api/routes/test_mcp.py
Normal file
@@ -0,0 +1,491 @@
|
|||||||
|
"""
|
||||||
|
Tests for MCP API Routes
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import status
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.main import app
|
||||||
|
from app.models.user import User
|
||||||
|
from app.services.mcp import (
|
||||||
|
MCPCircuitOpenError,
|
||||||
|
MCPClientManager,
|
||||||
|
MCPConnectionError,
|
||||||
|
MCPServerNotFoundError,
|
||||||
|
MCPTimeoutError,
|
||||||
|
MCPToolNotFoundError,
|
||||||
|
ServerHealth,
|
||||||
|
)
|
||||||
|
from app.services.mcp.config import MCPServerConfig, TransportType
|
||||||
|
from app.services.mcp.routing import ToolInfo, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_mcp_client():
|
||||||
|
"""Create a mock MCP client manager."""
|
||||||
|
client = MagicMock(spec=MCPClientManager)
|
||||||
|
client.is_initialized = True
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_superuser():
|
||||||
|
"""Create a mock superuser."""
|
||||||
|
user = MagicMock(spec=User)
|
||||||
|
user.id = "00000000-0000-0000-0000-000000000001"
|
||||||
|
user.is_superuser = True
|
||||||
|
user.email = "admin@example.com"
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(mock_mcp_client, mock_superuser):
|
||||||
|
"""Create a FastAPI test client with mocked dependencies."""
|
||||||
|
from app.api.routes.mcp import get_mcp_client
|
||||||
|
from app.api.dependencies.permissions import require_superuser
|
||||||
|
|
||||||
|
# Override dependencies
|
||||||
|
async def override_get_mcp_client():
|
||||||
|
return mock_mcp_client
|
||||||
|
|
||||||
|
async def override_require_superuser():
|
||||||
|
return mock_superuser
|
||||||
|
|
||||||
|
app.dependency_overrides[get_mcp_client] = override_get_mcp_client
|
||||||
|
app.dependency_overrides[require_superuser] = override_require_superuser
|
||||||
|
|
||||||
|
with patch("app.main.check_database_health", return_value=True):
|
||||||
|
yield TestClient(app)
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestListServers:
|
||||||
|
"""Tests for GET /mcp/servers endpoint."""
|
||||||
|
|
||||||
|
def test_list_servers_success(self, client, mock_mcp_client):
|
||||||
|
"""Test listing MCP servers returns correct data."""
|
||||||
|
# Setup mock
|
||||||
|
mock_mcp_client.list_servers.return_value = ["server-1", "server-2"]
|
||||||
|
mock_mcp_client.get_server_config.side_effect = [
|
||||||
|
MCPServerConfig(
|
||||||
|
url="http://server1:8000",
|
||||||
|
timeout=30,
|
||||||
|
enabled=True,
|
||||||
|
transport=TransportType.HTTP,
|
||||||
|
description="Server 1",
|
||||||
|
),
|
||||||
|
MCPServerConfig(
|
||||||
|
url="http://server2:8000",
|
||||||
|
timeout=60,
|
||||||
|
enabled=True,
|
||||||
|
transport=TransportType.SSE,
|
||||||
|
description="Server 2",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = client.get("/api/v1/mcp/servers")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 2
|
||||||
|
assert len(data["servers"]) == 2
|
||||||
|
assert data["servers"][0]["name"] == "server-1"
|
||||||
|
assert data["servers"][0]["url"] == "http://server1:8000"
|
||||||
|
assert data["servers"][1]["name"] == "server-2"
|
||||||
|
assert data["servers"][1]["transport"] == "sse"
|
||||||
|
|
||||||
|
def test_list_servers_empty(self, client, mock_mcp_client):
|
||||||
|
"""Test listing servers when none are registered."""
|
||||||
|
mock_mcp_client.list_servers.return_value = []
|
||||||
|
|
||||||
|
response = client.get("/api/v1/mcp/servers")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 0
|
||||||
|
assert data["servers"] == []
|
||||||
|
|
||||||
|
def test_list_servers_handles_not_found(self, client, mock_mcp_client):
|
||||||
|
"""Test that missing server configs are skipped gracefully."""
|
||||||
|
mock_mcp_client.list_servers.return_value = ["server-1", "missing"]
|
||||||
|
mock_mcp_client.get_server_config.side_effect = [
|
||||||
|
MCPServerConfig(url="http://server1:8000"),
|
||||||
|
MCPServerNotFoundError(server_name="missing"),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = client.get("/api/v1/mcp/servers")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
# Should only include the successfully retrieved server
|
||||||
|
assert data["total"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestListServerTools:
|
||||||
|
"""Tests for GET /mcp/servers/{server_name}/tools endpoint."""
|
||||||
|
|
||||||
|
def test_list_server_tools_success(self, client, mock_mcp_client):
|
||||||
|
"""Test listing tools for a specific server."""
|
||||||
|
mock_mcp_client.list_tools = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
ToolInfo(name="tool1", description="Tool 1", server_name="server-1"),
|
||||||
|
ToolInfo(name="tool2", description="Tool 2", server_name="server-1"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/mcp/servers/server-1/tools")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 2
|
||||||
|
assert data["tools"][0]["name"] == "tool1"
|
||||||
|
assert data["tools"][1]["name"] == "tool2"
|
||||||
|
|
||||||
|
def test_list_server_tools_not_found(self, client, mock_mcp_client):
|
||||||
|
"""Test listing tools for non-existent server."""
|
||||||
|
mock_mcp_client.list_tools = AsyncMock(
|
||||||
|
side_effect=MCPServerNotFoundError(server_name="unknown")
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/mcp/servers/unknown/tools")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
class TestListAllTools:
|
||||||
|
"""Tests for GET /mcp/tools endpoint."""
|
||||||
|
|
||||||
|
def test_list_all_tools_success(self, client, mock_mcp_client):
|
||||||
|
"""Test listing all tools from all servers."""
|
||||||
|
mock_mcp_client.list_all_tools = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
ToolInfo(name="tool1", server_name="server-1"),
|
||||||
|
ToolInfo(name="tool2", server_name="server-1"),
|
||||||
|
ToolInfo(name="tool3", server_name="server-2"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/mcp/tools")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 3
|
||||||
|
|
||||||
|
def test_list_all_tools_empty(self, client, mock_mcp_client):
|
||||||
|
"""Test listing tools when none are available."""
|
||||||
|
mock_mcp_client.list_all_tools = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
response = client.get("/api/v1/mcp/tools")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthCheck:
|
||||||
|
"""Tests for GET /mcp/health endpoint."""
|
||||||
|
|
||||||
|
def test_health_check_success(self, client, mock_mcp_client):
|
||||||
|
"""Test health check returns correct data."""
|
||||||
|
mock_mcp_client.health_check = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"server-1": ServerHealth(
|
||||||
|
name="server-1",
|
||||||
|
healthy=True,
|
||||||
|
state="connected",
|
||||||
|
url="http://server1:8000",
|
||||||
|
tools_count=5,
|
||||||
|
),
|
||||||
|
"server-2": ServerHealth(
|
||||||
|
name="server-2",
|
||||||
|
healthy=False,
|
||||||
|
state="error",
|
||||||
|
url="http://server2:8000",
|
||||||
|
error="Connection refused",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/mcp/health")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 2
|
||||||
|
assert data["healthy_count"] == 1
|
||||||
|
assert data["unhealthy_count"] == 1
|
||||||
|
assert data["servers"]["server-1"]["healthy"] is True
|
||||||
|
assert data["servers"]["server-2"]["healthy"] is False
|
||||||
|
|
||||||
|
def test_health_check_all_healthy(self, client, mock_mcp_client):
|
||||||
|
"""Test health check when all servers are healthy."""
|
||||||
|
mock_mcp_client.health_check = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"server-1": ServerHealth(
|
||||||
|
name="server-1",
|
||||||
|
healthy=True,
|
||||||
|
state="connected",
|
||||||
|
url="http://server1:8000",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/mcp/health")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["healthy_count"] == 1
|
||||||
|
assert data["unhealthy_count"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallTool:
|
||||||
|
"""Tests for POST /mcp/call endpoint."""
|
||||||
|
|
||||||
|
def test_call_tool_success(self, client, mock_mcp_client):
|
||||||
|
"""Test successful tool execution."""
|
||||||
|
mock_mcp_client.call_tool = AsyncMock(
|
||||||
|
return_value=ToolResult(
|
||||||
|
success=True,
|
||||||
|
data={"result": "ok"},
|
||||||
|
tool_name="test-tool",
|
||||||
|
server_name="server-1",
|
||||||
|
execution_time_ms=123.45,
|
||||||
|
request_id="test-request-id",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/mcp/call",
|
||||||
|
json={
|
||||||
|
"server": "server-1",
|
||||||
|
"tool": "test-tool",
|
||||||
|
"arguments": {"key": "value"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["data"] == {"result": "ok"}
|
||||||
|
assert data["tool_name"] == "test-tool"
|
||||||
|
assert data["server_name"] == "server-1"
|
||||||
|
|
||||||
|
def test_call_tool_with_timeout(self, client, mock_mcp_client):
|
||||||
|
"""Test tool execution with custom timeout."""
|
||||||
|
mock_mcp_client.call_tool = AsyncMock(
|
||||||
|
return_value=ToolResult(success=True, data={})
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/mcp/call",
|
||||||
|
json={
|
||||||
|
"server": "server-1",
|
||||||
|
"tool": "test-tool",
|
||||||
|
"timeout": 60.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
mock_mcp_client.call_tool.assert_called_once()
|
||||||
|
call_args = mock_mcp_client.call_tool.call_args
|
||||||
|
assert call_args.kwargs["timeout"] == 60.0
|
||||||
|
|
||||||
|
def test_call_tool_server_not_found(self, client, mock_mcp_client):
|
||||||
|
"""Test tool execution with non-existent server."""
|
||||||
|
mock_mcp_client.call_tool = AsyncMock(
|
||||||
|
side_effect=MCPServerNotFoundError(server_name="unknown")
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/mcp/call",
|
||||||
|
json={"server": "unknown", "tool": "test-tool"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
def test_call_tool_not_found(self, client, mock_mcp_client):
|
||||||
|
"""Test tool execution with non-existent tool."""
|
||||||
|
mock_mcp_client.call_tool = AsyncMock(
|
||||||
|
side_effect=MCPToolNotFoundError(tool_name="unknown")
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/mcp/call",
|
||||||
|
json={"server": "server-1", "tool": "unknown"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
def test_call_tool_timeout(self, client, mock_mcp_client):
|
||||||
|
"""Test tool execution timeout."""
|
||||||
|
mock_mcp_client.call_tool = AsyncMock(
|
||||||
|
side_effect=MCPTimeoutError(
|
||||||
|
"Request timed out",
|
||||||
|
server_name="server-1",
|
||||||
|
timeout_seconds=30.0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/mcp/call",
|
||||||
|
json={"server": "server-1", "tool": "slow-tool"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_504_GATEWAY_TIMEOUT
|
||||||
|
|
||||||
|
def test_call_tool_connection_error(self, client, mock_mcp_client):
|
||||||
|
"""Test tool execution with connection failure."""
|
||||||
|
mock_mcp_client.call_tool = AsyncMock(
|
||||||
|
side_effect=MCPConnectionError(
|
||||||
|
"Connection refused",
|
||||||
|
server_name="server-1",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/mcp/call",
|
||||||
|
json={"server": "server-1", "tool": "test-tool"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_502_BAD_GATEWAY
|
||||||
|
|
||||||
|
def test_call_tool_circuit_open(self, client, mock_mcp_client):
|
||||||
|
"""Test tool execution with open circuit breaker."""
|
||||||
|
mock_mcp_client.call_tool = AsyncMock(
|
||||||
|
side_effect=MCPCircuitOpenError(
|
||||||
|
server_name="server-1",
|
||||||
|
failure_count=5,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/mcp/call",
|
||||||
|
json={"server": "server-1", "tool": "test-tool"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||||
|
|
||||||
|
|
||||||
|
class TestCircuitBreakers:
|
||||||
|
"""Tests for circuit breaker endpoints."""
|
||||||
|
|
||||||
|
def test_list_circuit_breakers(self, client, mock_mcp_client):
|
||||||
|
"""Test listing circuit breaker statuses."""
|
||||||
|
mock_mcp_client.get_circuit_breaker_status.return_value = {
|
||||||
|
"server-1": {"state": "closed", "failure_count": 0},
|
||||||
|
"server-2": {"state": "open", "failure_count": 5},
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.get("/api/v1/mcp/circuit-breakers")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["circuit_breakers"]) == 2
|
||||||
|
|
||||||
|
def test_list_circuit_breakers_empty(self, client, mock_mcp_client):
|
||||||
|
"""Test listing when no circuit breakers exist."""
|
||||||
|
mock_mcp_client.get_circuit_breaker_status.return_value = {}
|
||||||
|
|
||||||
|
response = client.get("/api/v1/mcp/circuit-breakers")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
assert data["circuit_breakers"] == []
|
||||||
|
|
||||||
|
def test_reset_circuit_breaker_success(self, client, mock_mcp_client):
|
||||||
|
"""Test successfully resetting a circuit breaker."""
|
||||||
|
mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
response = client.post("/api/v1/mcp/circuit-breakers/server-1/reset")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
def test_reset_circuit_breaker_not_found(self, client, mock_mcp_client):
|
||||||
|
"""Test resetting non-existent circuit breaker."""
|
||||||
|
mock_mcp_client.reset_circuit_breaker = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
response = client.post("/api/v1/mcp/circuit-breakers/unknown/reset")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
class TestReconnectServer:
|
||||||
|
"""Tests for POST /mcp/servers/{server_name}/reconnect endpoint."""
|
||||||
|
|
||||||
|
def test_reconnect_success(self, client, mock_mcp_client):
|
||||||
|
"""Test successful server reconnection."""
|
||||||
|
mock_mcp_client.disconnect = AsyncMock()
|
||||||
|
mock_mcp_client.connect = AsyncMock()
|
||||||
|
|
||||||
|
response = client.post("/api/v1/mcp/servers/server-1/reconnect")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
mock_mcp_client.disconnect.assert_called_once_with("server-1")
|
||||||
|
mock_mcp_client.connect.assert_called_once_with("server-1")
|
||||||
|
|
||||||
|
def test_reconnect_server_not_found(self, client, mock_mcp_client):
|
||||||
|
"""Test reconnecting to non-existent server."""
|
||||||
|
mock_mcp_client.disconnect = AsyncMock(
|
||||||
|
side_effect=MCPServerNotFoundError(server_name="unknown")
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/api/v1/mcp/servers/unknown/reconnect")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
def test_reconnect_connection_failure(self, client, mock_mcp_client):
|
||||||
|
"""Test reconnection failure."""
|
||||||
|
mock_mcp_client.disconnect = AsyncMock()
|
||||||
|
mock_mcp_client.connect = AsyncMock(
|
||||||
|
side_effect=MCPConnectionError(
|
||||||
|
"Connection refused",
|
||||||
|
server_name="server-1",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/api/v1/mcp/servers/server-1/reconnect")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_502_BAD_GATEWAY
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPEndpointsEdgeCases:
|
||||||
|
"""Edge case tests for MCP endpoints."""
|
||||||
|
|
||||||
|
def test_servers_content_type(self, client, mock_mcp_client):
|
||||||
|
"""Test that endpoints return JSON content type."""
|
||||||
|
mock_mcp_client.list_servers.return_value = []
|
||||||
|
|
||||||
|
response = client.get("/api/v1/mcp/servers")
|
||||||
|
|
||||||
|
assert "application/json" in response.headers["content-type"]
|
||||||
|
|
||||||
|
def test_call_tool_validation_error(self, client):
|
||||||
|
"""Test that invalid request body returns validation error."""
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/mcp/call",
|
||||||
|
json={}, # Missing required fields
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
def test_call_tool_missing_server(self, client):
|
||||||
|
"""Test that missing server field returns validation error."""
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/mcp/call",
|
||||||
|
json={"tool": "test-tool"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
|
def test_call_tool_missing_tool(self, client):
|
||||||
|
"""Test that missing tool field returns validation error."""
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/mcp/call",
|
||||||
|
json={"server": "server-1"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
@@ -437,3 +437,197 @@ class TestOAuthProviderEndpoints:
|
|||||||
)
|
)
|
||||||
# Missing client_id returns 401 (invalid_client)
|
# Missing client_id returns 401 (invalid_client)
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthProviderAdminEndpoints:
|
||||||
|
"""Tests for OAuth provider admin endpoints."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_clients_admin_only(self, client, user_token):
|
||||||
|
"""Test that listing clients requires superuser."""
|
||||||
|
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_PROVIDER_ENABLED = True
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/oauth/provider/clients",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
# Regular user should be forbidden
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_clients_success(self, client, superuser_token):
|
||||||
|
"""Test listing OAuth clients as superuser."""
|
||||||
|
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_PROVIDER_ENABLED = True
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/oauth/provider/clients",
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert isinstance(response.json(), list)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_client_not_found(self, client, superuser_token):
|
||||||
|
"""Test deleting non-existent OAuth client."""
|
||||||
|
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_PROVIDER_ENABLED = True
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
"/api/v1/oauth/provider/clients/non_existent_client_id",
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_client_success(self, client, superuser_token, async_test_db):
|
||||||
|
"""Test successfully deleting an OAuth client."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
from app.crud.oauth import oauth_client
|
||||||
|
from app.schemas.oauth import OAuthClientCreate
|
||||||
|
|
||||||
|
# Create a test client to delete
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
client_data = OAuthClientCreate(
|
||||||
|
client_name="Delete Test Client",
|
||||||
|
redirect_uris=["http://localhost:3000/callback"],
|
||||||
|
allowed_scopes=["read:users"],
|
||||||
|
)
|
||||||
|
test_client, _ = await oauth_client.create_client(
|
||||||
|
session, obj_in=client_data
|
||||||
|
)
|
||||||
|
test_client_id = test_client.client_id
|
||||||
|
|
||||||
|
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_PROVIDER_ENABLED = True
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/oauth/provider/clients/{test_client_id}",
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 204
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthProviderConsentEndpoints:
|
||||||
|
"""Tests for OAuth provider consent management endpoints."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_consents_unauthenticated(self, client):
|
||||||
|
"""Test listing consents without authentication."""
|
||||||
|
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_PROVIDER_ENABLED = True
|
||||||
|
|
||||||
|
response = await client.get("/api/v1/oauth/provider/consents")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_consents_empty(self, client, user_token):
|
||||||
|
"""Test listing consents when user has none."""
|
||||||
|
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_PROVIDER_ENABLED = True
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/oauth/provider/consents",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_consents_with_data(
|
||||||
|
self, client, user_token, async_test_user, async_test_db
|
||||||
|
):
|
||||||
|
"""Test listing consents when user has granted some."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
from app.crud.oauth import oauth_client
|
||||||
|
from app.models.oauth_provider_token import OAuthConsent
|
||||||
|
from app.schemas.oauth import OAuthClientCreate
|
||||||
|
|
||||||
|
# Create a test client and grant consent
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
client_data = OAuthClientCreate(
|
||||||
|
client_name="Consented App",
|
||||||
|
redirect_uris=["http://localhost:3000/callback"],
|
||||||
|
allowed_scopes=["read:users", "write:users"],
|
||||||
|
)
|
||||||
|
test_client, _ = await oauth_client.create_client(
|
||||||
|
session, obj_in=client_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create consent record
|
||||||
|
consent = OAuthConsent(
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
client_id=test_client.client_id,
|
||||||
|
granted_scopes="read:users write:users",
|
||||||
|
)
|
||||||
|
session.add(consent)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_PROVIDER_ENABLED = True
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/oauth/provider/consents",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data) == 1
|
||||||
|
assert data[0]["client_name"] == "Consented App"
|
||||||
|
assert "read:users" in data[0]["granted_scopes"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_consent_not_found(self, client, user_token):
|
||||||
|
"""Test revoking consent that doesn't exist."""
|
||||||
|
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_PROVIDER_ENABLED = True
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
"/api/v1/oauth/provider/consents/non_existent_client",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_consent_success(
|
||||||
|
self, client, user_token, async_test_user, async_test_db
|
||||||
|
):
|
||||||
|
"""Test successfully revoking consent."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
from app.crud.oauth import oauth_client
|
||||||
|
from app.models.oauth_provider_token import OAuthConsent
|
||||||
|
from app.schemas.oauth import OAuthClientCreate
|
||||||
|
|
||||||
|
# Create a test client and grant consent
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
client_data = OAuthClientCreate(
|
||||||
|
client_name="Revoke Test App",
|
||||||
|
redirect_uris=["http://localhost:3000/callback"],
|
||||||
|
allowed_scopes=["read:users"],
|
||||||
|
)
|
||||||
|
test_client, _ = await oauth_client.create_client(
|
||||||
|
session, obj_in=client_data
|
||||||
|
)
|
||||||
|
test_client_id = test_client.client_id
|
||||||
|
|
||||||
|
# Create consent record
|
||||||
|
consent = OAuthConsent(
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
client_id=test_client.client_id,
|
||||||
|
granted_scopes="read:users",
|
||||||
|
)
|
||||||
|
session.add(consent)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
with patch("app.api.routes.oauth_provider.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_PROVIDER_ENABLED = True
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/oauth/provider/consents/{test_client_id}",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 204
|
||||||
|
|||||||
784
backend/tests/core/test_redis.py
Normal file
784
backend/tests/core/test_redis.py
Normal file
@@ -0,0 +1,784 @@
|
|||||||
|
"""
|
||||||
|
Tests for Redis client utility functions (app/core/redis.py).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- Cache operations (get, set, delete, expire)
|
||||||
|
- JSON serialization helpers
|
||||||
|
- Pub/sub operations
|
||||||
|
- Health check
|
||||||
|
- Connection pooling
|
||||||
|
- Error handling
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||||
|
|
||||||
|
from app.core.redis import (
|
||||||
|
DEFAULT_CACHE_TTL,
|
||||||
|
POOL_MAX_CONNECTIONS,
|
||||||
|
RedisClient,
|
||||||
|
check_redis_health,
|
||||||
|
close_redis,
|
||||||
|
get_redis,
|
||||||
|
redis_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedisClientInit:
|
||||||
|
"""Test RedisClient initialization."""
|
||||||
|
|
||||||
|
def test_default_url_from_settings(self):
|
||||||
|
"""Test that default URL comes from settings."""
|
||||||
|
with patch("app.core.redis.settings") as mock_settings:
|
||||||
|
mock_settings.REDIS_URL = "redis://test:6379/0"
|
||||||
|
client = RedisClient()
|
||||||
|
assert client._url == "redis://test:6379/0"
|
||||||
|
|
||||||
|
def test_custom_url_override(self):
|
||||||
|
"""Test that custom URL overrides settings."""
|
||||||
|
client = RedisClient(url="redis://custom:6379/1")
|
||||||
|
assert client._url == "redis://custom:6379/1"
|
||||||
|
|
||||||
|
def test_initial_state(self):
|
||||||
|
"""Test initial client state."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
assert client._pool is None
|
||||||
|
assert client._client is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheOperations:
|
||||||
|
"""Test cache get/set/delete operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_set_success(self):
|
||||||
|
"""Test setting a cache value."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.set = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_set("test-key", "test-value", ttl=60)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_redis.set.assert_called_once_with("test-key", "test-value", ex=60)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_set_default_ttl(self):
|
||||||
|
"""Test setting a cache value with default TTL."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.set = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_set("test-key", "test-value")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_redis.set.assert_called_once_with(
|
||||||
|
"test-key", "test-value", ex=DEFAULT_CACHE_TTL
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_set_connection_error(self):
|
||||||
|
"""Test cache_set handles connection errors."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.set = AsyncMock(side_effect=ConnectionError("Connection refused"))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_set("test-key", "test-value")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_set_timeout_error(self):
|
||||||
|
"""Test cache_set handles timeout errors."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.set = AsyncMock(side_effect=TimeoutError("Timeout"))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_set("test-key", "test-value")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_set_redis_error(self):
|
||||||
|
"""Test cache_set handles generic Redis errors."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.set = AsyncMock(side_effect=RedisError("Unknown error"))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_set("test-key", "test-value")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_get_success(self):
|
||||||
|
"""Test getting a cached value."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(return_value="cached-value")
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_get("test-key")
|
||||||
|
|
||||||
|
assert result == "cached-value"
|
||||||
|
mock_redis.get.assert_called_once_with("test-key")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_get_miss(self):
|
||||||
|
"""Test cache miss returns None."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_get("nonexistent-key")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_get_connection_error(self):
|
||||||
|
"""Test cache_get handles connection errors."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(side_effect=ConnectionError("Connection refused"))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_get("test-key")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_delete_success(self):
|
||||||
|
"""Test deleting a cache key."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.delete = AsyncMock(return_value=1)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_delete("test-key")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_redis.delete.assert_called_once_with("test-key")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_delete_nonexistent_key(self):
|
||||||
|
"""Test deleting a nonexistent key returns False."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.delete = AsyncMock(return_value=0)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_delete("nonexistent-key")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_delete_connection_error(self):
|
||||||
|
"""Test cache_delete handles connection errors."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.delete = AsyncMock(side_effect=ConnectionError("Connection refused"))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_delete("test-key")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheDeletePattern:
|
||||||
|
"""Test cache_delete_pattern operation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_delete_pattern_success(self):
|
||||||
|
"""Test deleting keys by pattern."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.delete = AsyncMock(return_value=1)
|
||||||
|
|
||||||
|
# Create async iterator for scan_iter
|
||||||
|
async def mock_scan_iter(pattern):
|
||||||
|
for key in ["user:1", "user:2", "user:3"]:
|
||||||
|
yield key
|
||||||
|
|
||||||
|
mock_redis.scan_iter = mock_scan_iter
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_delete_pattern("user:*")
|
||||||
|
|
||||||
|
assert result == 3
|
||||||
|
assert mock_redis.delete.call_count == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_delete_pattern_no_matches(self):
|
||||||
|
"""Test deleting pattern with no matches."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
|
||||||
|
async def mock_scan_iter(pattern):
|
||||||
|
if False: # Empty iterator
|
||||||
|
yield
|
||||||
|
|
||||||
|
mock_redis.scan_iter = mock_scan_iter
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_delete_pattern("nonexistent:*")
|
||||||
|
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_delete_pattern_error(self):
|
||||||
|
"""Test cache_delete_pattern handles errors."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
|
||||||
|
async def mock_scan_iter(pattern):
|
||||||
|
raise ConnectionError("Connection lost")
|
||||||
|
if False: # Make it a generator
|
||||||
|
yield
|
||||||
|
|
||||||
|
mock_redis.scan_iter = mock_scan_iter
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_delete_pattern("user:*")
|
||||||
|
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheExpire:
|
||||||
|
"""Test cache_expire operation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_expire_success(self):
|
||||||
|
"""Test setting TTL on existing key."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.expire = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_expire("test-key", 120)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_redis.expire.assert_called_once_with("test-key", 120)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_expire_nonexistent_key(self):
|
||||||
|
"""Test setting TTL on nonexistent key."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.expire = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_expire("nonexistent-key", 120)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_expire_error(self):
|
||||||
|
"""Test cache_expire handles errors."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.expire = AsyncMock(side_effect=ConnectionError("Error"))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_expire("test-key", 120)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheHelpers:
|
||||||
|
"""Test cache helper methods (exists, ttl)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_exists_true(self):
|
||||||
|
"""Test cache_exists returns True for existing key."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.exists = AsyncMock(return_value=1)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_exists("test-key")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_exists_false(self):
|
||||||
|
"""Test cache_exists returns False for nonexistent key."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.exists = AsyncMock(return_value=0)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_exists("nonexistent-key")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_exists_error(self):
|
||||||
|
"""Test cache_exists handles errors."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.exists = AsyncMock(side_effect=ConnectionError("Error"))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_exists("test-key")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_ttl_with_ttl(self):
|
||||||
|
"""Test cache_ttl returns remaining TTL."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.ttl = AsyncMock(return_value=300)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_ttl("test-key")
|
||||||
|
|
||||||
|
assert result == 300
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_ttl_no_ttl(self):
|
||||||
|
"""Test cache_ttl returns -1 for key without TTL."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.ttl = AsyncMock(return_value=-1)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_ttl("test-key")
|
||||||
|
|
||||||
|
assert result == -1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_ttl_nonexistent_key(self):
|
||||||
|
"""Test cache_ttl returns -2 for nonexistent key."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.ttl = AsyncMock(return_value=-2)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_ttl("nonexistent-key")
|
||||||
|
|
||||||
|
assert result == -2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_ttl_error(self):
|
||||||
|
"""Test cache_ttl handles errors."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.ttl = AsyncMock(side_effect=ConnectionError("Error"))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_ttl("test-key")
|
||||||
|
|
||||||
|
assert result == -2
|
||||||
|
|
||||||
|
|
||||||
|
class TestJsonOperations:
|
||||||
|
"""Test JSON serialization cache operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_set_json_success(self):
|
||||||
|
"""Test setting a JSON value in cache."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.set = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
data = {"user": "test", "count": 42}
|
||||||
|
result = await client.cache_set_json("test-key", data, ttl=60)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_redis.set.assert_called_once()
|
||||||
|
# Verify JSON was serialized
|
||||||
|
call_args = mock_redis.set.call_args
|
||||||
|
assert call_args[0][1] == json.dumps(data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_set_json_serialization_error(self):
|
||||||
|
"""Test cache_set_json handles serialization errors."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
# Object that can't be serialized
|
||||||
|
class NonSerializable:
|
||||||
|
pass
|
||||||
|
|
||||||
|
result = await client.cache_set_json("test-key", NonSerializable())
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_get_json_success(self):
|
||||||
|
"""Test getting a JSON value from cache."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
data = {"user": "test", "count": 42}
|
||||||
|
mock_redis.get = AsyncMock(return_value=json.dumps(data))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_get_json("test-key")
|
||||||
|
|
||||||
|
assert result == data
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_get_json_miss(self):
|
||||||
|
"""Test cache_get_json returns None on cache miss."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_get_json("nonexistent-key")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_get_json_invalid_json(self):
|
||||||
|
"""Test cache_get_json handles invalid JSON."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get = AsyncMock(return_value="not valid json {{{")
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.cache_get_json("test-key")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestPubSubOperations:
|
||||||
|
"""Test pub/sub operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publish_string_message(self):
|
||||||
|
"""Test publishing a string message."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.publish = AsyncMock(return_value=2)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.publish("test-channel", "hello world")
|
||||||
|
|
||||||
|
assert result == 2
|
||||||
|
mock_redis.publish.assert_called_once_with("test-channel", "hello world")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publish_dict_message(self):
|
||||||
|
"""Test publishing a dict message (JSON serialized)."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.publish = AsyncMock(return_value=1)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
data = {"event": "user_created", "user_id": 123}
|
||||||
|
result = await client.publish("events", data)
|
||||||
|
|
||||||
|
assert result == 1
|
||||||
|
mock_redis.publish.assert_called_once_with("events", json.dumps(data))
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publish_connection_error(self):
|
||||||
|
"""Test publish handles connection errors."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.publish = AsyncMock(side_effect=ConnectionError("Connection lost"))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.publish("test-channel", "hello")
|
||||||
|
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscribe_context_manager(self):
|
||||||
|
"""Test subscribe context manager."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_pubsub = AsyncMock()
|
||||||
|
mock_pubsub.subscribe = AsyncMock()
|
||||||
|
mock_pubsub.unsubscribe = AsyncMock()
|
||||||
|
mock_pubsub.close = AsyncMock()
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
async with client.subscribe("channel1", "channel2") as pubsub:
|
||||||
|
assert pubsub is mock_pubsub
|
||||||
|
mock_pubsub.subscribe.assert_called_once_with("channel1", "channel2")
|
||||||
|
|
||||||
|
# After exiting context, should unsubscribe and close
|
||||||
|
mock_pubsub.unsubscribe.assert_called_once_with("channel1", "channel2")
|
||||||
|
mock_pubsub.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_psubscribe_context_manager(self):
|
||||||
|
"""Test pattern subscribe context manager."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_pubsub = AsyncMock()
|
||||||
|
mock_pubsub.psubscribe = AsyncMock()
|
||||||
|
mock_pubsub.punsubscribe = AsyncMock()
|
||||||
|
mock_pubsub.close = AsyncMock()
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
async with client.psubscribe("user:*", "event:*") as pubsub:
|
||||||
|
assert pubsub is mock_pubsub
|
||||||
|
mock_pubsub.psubscribe.assert_called_once_with("user:*", "event:*")
|
||||||
|
|
||||||
|
mock_pubsub.punsubscribe.assert_called_once_with("user:*", "event:*")
|
||||||
|
mock_pubsub.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthCheck:
|
||||||
|
"""Test health check functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_success(self):
|
||||||
|
"""Test health check returns True when Redis is healthy."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.ping = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.health_check()
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_redis.ping.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_connection_error(self):
|
||||||
|
"""Test health check returns False on connection error."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.ping = AsyncMock(side_effect=ConnectionError("Connection refused"))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.health_check()
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_timeout_error(self):
|
||||||
|
"""Test health check returns False on timeout."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.ping = AsyncMock(side_effect=TimeoutError("Timeout"))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.health_check()
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_redis_error(self):
|
||||||
|
"""Test health check returns False on Redis error."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.ping = AsyncMock(side_effect=RedisError("Unknown error"))
|
||||||
|
|
||||||
|
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||||
|
result = await client.health_check()
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectionPooling:
|
||||||
|
"""Test connection pooling functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pool_initialization(self):
|
||||||
|
"""Test that pool is lazily initialized."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
assert client._pool is None
|
||||||
|
|
||||||
|
with patch("app.core.redis.ConnectionPool") as MockPool:
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
MockPool.from_url = MagicMock(return_value=mock_pool)
|
||||||
|
|
||||||
|
pool = await client._ensure_pool()
|
||||||
|
|
||||||
|
assert pool is mock_pool
|
||||||
|
MockPool.from_url.assert_called_once_with(
|
||||||
|
"redis://localhost:6379/0",
|
||||||
|
max_connections=POOL_MAX_CONNECTIONS,
|
||||||
|
socket_timeout=10,
|
||||||
|
socket_connect_timeout=10,
|
||||||
|
decode_responses=True,
|
||||||
|
health_check_interval=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pool_reuses_existing(self):
|
||||||
|
"""Test that pool is reused after initialization."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
client._pool = mock_pool
|
||||||
|
|
||||||
|
pool = await client._ensure_pool()
|
||||||
|
|
||||||
|
assert pool is mock_pool
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_disposes_resources(self):
|
||||||
|
"""Test that close() disposes pool and client."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_pool = AsyncMock()
|
||||||
|
mock_pool.disconnect = AsyncMock()
|
||||||
|
|
||||||
|
client._client = mock_client
|
||||||
|
client._pool = mock_pool
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
mock_client.close.assert_called_once()
|
||||||
|
mock_pool.disconnect.assert_called_once()
|
||||||
|
assert client._client is None
|
||||||
|
assert client._pool is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_handles_none(self):
|
||||||
|
"""Test that close() handles None client and pool gracefully."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
assert client._client is None
|
||||||
|
assert client._pool is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_pool_info_not_initialized(self):
|
||||||
|
"""Test pool info when not initialized."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
info = await client.get_pool_info()
|
||||||
|
|
||||||
|
assert info == {"status": "not_initialized"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_pool_info_active(self):
|
||||||
|
"""Test pool info when active."""
|
||||||
|
client = RedisClient(url="redis://user:pass@localhost:6379/0")
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
client._pool = mock_pool
|
||||||
|
|
||||||
|
info = await client.get_pool_info()
|
||||||
|
|
||||||
|
assert info["status"] == "active"
|
||||||
|
assert info["max_connections"] == POOL_MAX_CONNECTIONS
|
||||||
|
# Password should be hidden
|
||||||
|
assert "pass" not in info["url"]
|
||||||
|
assert "localhost:6379/0" in info["url"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestModuleLevelFunctions:
|
||||||
|
"""Test module-level convenience functions."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_redis_dependency(self):
|
||||||
|
"""Test get_redis FastAPI dependency."""
|
||||||
|
redis_gen = get_redis()
|
||||||
|
|
||||||
|
client = await redis_gen.__anext__()
|
||||||
|
assert client is redis_client
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
with pytest.raises(StopAsyncIteration):
|
||||||
|
await redis_gen.__anext__()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_redis_health(self):
|
||||||
|
"""Test module-level check_redis_health function."""
|
||||||
|
with patch.object(redis_client, "health_check", return_value=True) as mock:
|
||||||
|
result = await check_redis_health()
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_redis(self):
|
||||||
|
"""Test module-level close_redis function."""
|
||||||
|
with patch.object(redis_client, "close") as mock:
|
||||||
|
await close_redis()
|
||||||
|
|
||||||
|
mock.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadSafety:
|
||||||
|
"""Test thread-safety of pool initialization."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_pool_initialization(self):
|
||||||
|
"""Test that concurrent _ensure_pool calls create only one pool."""
|
||||||
|
client = RedisClient(url="redis://localhost:6379/0")
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
|
||||||
|
def counting_from_url(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return mock_pool
|
||||||
|
|
||||||
|
with patch("app.core.redis.ConnectionPool") as MockPool:
|
||||||
|
MockPool.from_url = MagicMock(side_effect=counting_from_url)
|
||||||
|
|
||||||
|
# Start multiple concurrent _ensure_pool calls
|
||||||
|
results = await asyncio.gather(
|
||||||
|
client._ensure_pool(),
|
||||||
|
client._ensure_pool(),
|
||||||
|
client._ensure_pool(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# All results should be the same pool instance
|
||||||
|
assert results[0] is results[1] is results[2]
|
||||||
|
assert results[0] is mock_pool
|
||||||
|
# Pool should only be created once despite concurrent calls
|
||||||
|
assert call_count == 1
|
||||||
2
backend/tests/crud/syndarix/__init__.py
Normal file
2
backend/tests/crud/syndarix/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# tests/crud/syndarix/__init__.py
|
||||||
|
"""Syndarix CRUD operation tests."""
|
||||||
217
backend/tests/crud/syndarix/conftest.py
Normal file
217
backend/tests/crud/syndarix/conftest.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
# tests/crud/syndarix/conftest.py
|
||||||
|
"""
|
||||||
|
Shared fixtures for Syndarix CRUD tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import date, timedelta
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.models.syndarix import (
|
||||||
|
AgentInstance,
|
||||||
|
AgentStatus,
|
||||||
|
AgentType,
|
||||||
|
AutonomyLevel,
|
||||||
|
Issue,
|
||||||
|
IssuePriority,
|
||||||
|
IssueStatus,
|
||||||
|
Project,
|
||||||
|
ProjectStatus,
|
||||||
|
Sprint,
|
||||||
|
SprintStatus,
|
||||||
|
)
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.syndarix import (
|
||||||
|
AgentTypeCreate,
|
||||||
|
ProjectCreate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def project_create_data():
|
||||||
|
"""Return data for creating a project via schema."""
|
||||||
|
return ProjectCreate(
|
||||||
|
name="Test Project",
|
||||||
|
slug="test-project-crud",
|
||||||
|
description="A test project for CRUD testing",
|
||||||
|
autonomy_level=AutonomyLevel.MILESTONE,
|
||||||
|
status=ProjectStatus.ACTIVE,
|
||||||
|
settings={"mcp_servers": ["gitea"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent_type_create_data():
|
||||||
|
"""Return data for creating an agent type via schema."""
|
||||||
|
return AgentTypeCreate(
|
||||||
|
name="Backend Engineer",
|
||||||
|
slug="backend-engineer-crud",
|
||||||
|
description="Specialized in backend development",
|
||||||
|
expertise=["python", "fastapi", "postgresql"],
|
||||||
|
personality_prompt="You are an expert backend engineer with deep knowledge of Python and FastAPI.",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
fallback_models=["claude-sonnet-4-20250514"],
|
||||||
|
model_params={"temperature": 0.7, "max_tokens": 4096},
|
||||||
|
mcp_servers=["gitea", "file-system"],
|
||||||
|
tool_permissions={"allowed": ["*"], "denied": []},
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sprint_create_data():
|
||||||
|
"""Return data for creating a sprint via schema."""
|
||||||
|
today = date.today()
|
||||||
|
return {
|
||||||
|
"name": "Sprint 1",
|
||||||
|
"number": 1,
|
||||||
|
"goal": "Complete initial setup and core features",
|
||||||
|
"start_date": today,
|
||||||
|
"end_date": today + timedelta(days=14),
|
||||||
|
"status": SprintStatus.PLANNED,
|
||||||
|
"planned_points": 21,
|
||||||
|
"velocity": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def issue_create_data():
|
||||||
|
"""Return data for creating an issue via schema."""
|
||||||
|
return {
|
||||||
|
"title": "Implement user authentication",
|
||||||
|
"body": "As a user, I want to log in securely so that I can access my account.",
|
||||||
|
"status": IssueStatus.OPEN,
|
||||||
|
"priority": IssuePriority.HIGH,
|
||||||
|
"labels": ["backend", "security"],
|
||||||
|
"story_points": 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_owner_crud(async_test_db):
|
||||||
|
"""Create a test user to be used as project owner in CRUD tests."""
|
||||||
|
from app.core.auth import get_password_hash
|
||||||
|
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="crud-owner@example.com",
|
||||||
|
password_hash=get_password_hash("TestPassword123!"),
|
||||||
|
first_name="CRUD",
|
||||||
|
last_name="Owner",
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_project_crud(async_test_db, test_owner_crud, project_create_data):
|
||||||
|
"""Create a test project in the database for CRUD tests."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name=project_create_data.name,
|
||||||
|
slug=project_create_data.slug,
|
||||||
|
description=project_create_data.description,
|
||||||
|
autonomy_level=project_create_data.autonomy_level,
|
||||||
|
status=project_create_data.status,
|
||||||
|
settings=project_create_data.settings,
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
session.add(project)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(project)
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_agent_type_crud(async_test_db, agent_type_create_data):
|
||||||
|
"""Create a test agent type in the database for CRUD tests."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name=agent_type_create_data.name,
|
||||||
|
slug=agent_type_create_data.slug,
|
||||||
|
description=agent_type_create_data.description,
|
||||||
|
expertise=agent_type_create_data.expertise,
|
||||||
|
personality_prompt=agent_type_create_data.personality_prompt,
|
||||||
|
primary_model=agent_type_create_data.primary_model,
|
||||||
|
fallback_models=agent_type_create_data.fallback_models,
|
||||||
|
model_params=agent_type_create_data.model_params,
|
||||||
|
mcp_servers=agent_type_create_data.mcp_servers,
|
||||||
|
tool_permissions=agent_type_create_data.tool_permissions,
|
||||||
|
is_active=agent_type_create_data.is_active,
|
||||||
|
)
|
||||||
|
session.add(agent_type)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(agent_type)
|
||||||
|
return agent_type
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_agent_instance_crud(
|
||||||
|
async_test_db, test_project_crud, test_agent_type_crud
|
||||||
|
):
|
||||||
|
"""Create a test agent instance in the database for CRUD tests."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=test_agent_type_crud.id,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="TestAgent",
|
||||||
|
status=AgentStatus.IDLE,
|
||||||
|
current_task=None,
|
||||||
|
short_term_memory={},
|
||||||
|
long_term_memory_ref=None,
|
||||||
|
session_id=None,
|
||||||
|
tasks_completed=0,
|
||||||
|
tokens_used=0,
|
||||||
|
cost_incurred=Decimal("0.0000"),
|
||||||
|
)
|
||||||
|
session.add(agent_instance)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(agent_instance)
|
||||||
|
return agent_instance
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_sprint_crud(async_test_db, test_project_crud, sprint_create_data):
|
||||||
|
"""Create a test sprint in the database for CRUD tests."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
**sprint_create_data,
|
||||||
|
)
|
||||||
|
session.add(sprint)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(sprint)
|
||||||
|
return sprint
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_issue_crud(async_test_db, test_project_crud, issue_create_data):
|
||||||
|
"""Create a test issue in the database for CRUD tests."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
**issue_create_data,
|
||||||
|
)
|
||||||
|
session.add(issue)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(issue)
|
||||||
|
return issue
|
||||||
473
backend/tests/crud/syndarix/test_agent_instance.py
Normal file
473
backend/tests/crud/syndarix/test_agent_instance.py
Normal file
@@ -0,0 +1,473 @@
|
|||||||
|
# tests/crud/syndarix/test_agent_instance.py
|
||||||
|
"""Tests for AgentInstance CRUD operations."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from decimal import Decimal
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
|
||||||
|
from app.crud.syndarix.agent_instance import agent_instance
|
||||||
|
from app.models.syndarix import AgentInstance, AgentType, Project
|
||||||
|
from app.models.syndarix.enums import (
|
||||||
|
AgentStatus,
|
||||||
|
ProjectStatus,
|
||||||
|
)
|
||||||
|
from app.schemas.syndarix import AgentInstanceCreate
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db_session(async_test_db):
|
||||||
|
"""Create a database session for tests."""
|
||||||
|
_, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_project(db_session):
|
||||||
|
"""Create a test project."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Project",
|
||||||
|
slug=f"test-project-{uuid.uuid4().hex[:8]}",
|
||||||
|
status=ProjectStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(project)
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_agent_type(db_session):
|
||||||
|
"""Create a test agent type."""
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Agent Type",
|
||||||
|
slug=f"test-agent-type-{uuid.uuid4().hex[:8]}",
|
||||||
|
primary_model="claude-3-opus",
|
||||||
|
personality_prompt="You are a helpful test agent.",
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(agent_type)
|
||||||
|
return agent_type
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_agent_instance(db_session, test_project, test_agent_type):
|
||||||
|
"""Create a test agent instance."""
|
||||||
|
instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=test_agent_type.id,
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Test Agent",
|
||||||
|
status=AgentStatus.IDLE,
|
||||||
|
)
|
||||||
|
db_session.add(instance)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(instance)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceCreate:
|
||||||
|
"""Tests for agent instance creation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_instance_success(
|
||||||
|
self, db_session, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test successful agent instance creation."""
|
||||||
|
instance_data = AgentInstanceCreate(
|
||||||
|
agent_type_id=test_agent_type.id,
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="New Agent",
|
||||||
|
)
|
||||||
|
created = await agent_instance.create(db_session, obj_in=instance_data)
|
||||||
|
assert created.name == "New Agent"
|
||||||
|
assert created.status == AgentStatus.IDLE
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_instance_with_all_fields(
|
||||||
|
self, db_session, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test agent instance creation with all optional fields."""
|
||||||
|
instance_data = AgentInstanceCreate(
|
||||||
|
agent_type_id=test_agent_type.id,
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Full Agent",
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
current_task="Processing request",
|
||||||
|
short_term_memory={"context": "test context", "history": []},
|
||||||
|
long_term_memory_ref="ref-123",
|
||||||
|
session_id="session-456",
|
||||||
|
)
|
||||||
|
created = await agent_instance.create(db_session, obj_in=instance_data)
|
||||||
|
assert created.current_task == "Processing request"
|
||||||
|
assert created.status == AgentStatus.WORKING
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_instance_integrity_error(
|
||||||
|
self, db_session, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test agent instance creation with integrity error."""
|
||||||
|
instance_data = AgentInstanceCreate(
|
||||||
|
agent_type_id=test_agent_type.id,
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Test Agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=IntegrityError("", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="Database integrity error"):
|
||||||
|
await agent_instance.create(db_session, obj_in=instance_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_instance_unexpected_error(
|
||||||
|
self, db_session, test_project, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test agent instance creation with unexpected error."""
|
||||||
|
instance_data = AgentInstanceCreate(
|
||||||
|
agent_type_id=test_agent_type.id,
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Test Agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=RuntimeError("Unexpected error"),
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="Unexpected error"):
|
||||||
|
await agent_instance.create(db_session, obj_in=instance_data)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceGetWithDetails:
|
||||||
|
"""Tests for getting agent instance with details."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_details_not_found(self, db_session):
|
||||||
|
"""Test getting non-existent agent instance with details."""
|
||||||
|
result = await agent_instance.get_with_details(
|
||||||
|
db_session, instance_id=uuid.uuid4()
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_details_success(self, db_session, test_agent_instance):
|
||||||
|
"""Test getting agent instance with details."""
|
||||||
|
result = await agent_instance.get_with_details(
|
||||||
|
db_session, instance_id=test_agent_instance.id
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result["instance"].id == test_agent_instance.id
|
||||||
|
assert "agent_type_name" in result
|
||||||
|
assert "assigned_issues_count" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_details_db_error(self, db_session, test_agent_instance):
|
||||||
|
"""Test getting agent instance with details when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_instance.get_with_details(
|
||||||
|
db_session, instance_id=test_agent_instance.id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceGetByProject:
|
||||||
|
"""Tests for getting agent instances by project."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_success(
|
||||||
|
self, db_session, test_project, test_agent_instance
|
||||||
|
):
|
||||||
|
"""Test getting agent instances by project."""
|
||||||
|
instances, total = await agent_instance.get_by_project(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
|
assert len(instances) == 1
|
||||||
|
assert total == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_with_status_filter(
|
||||||
|
self, db_session, test_project, test_agent_instance
|
||||||
|
):
|
||||||
|
"""Test getting agent instances with status filter."""
|
||||||
|
instances, _total = await agent_instance.get_by_project(
|
||||||
|
db_session,
|
||||||
|
project_id=test_project.id,
|
||||||
|
status=AgentStatus.IDLE,
|
||||||
|
)
|
||||||
|
assert len(instances) == 1
|
||||||
|
assert instances[0].status == AgentStatus.IDLE
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_db_error(self, db_session, test_project):
|
||||||
|
"""Test getting agent instances when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_instance.get_by_project(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceGetByAgentType:
|
||||||
|
"""Tests for getting agent instances by agent type."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_agent_type_success(
|
||||||
|
self, db_session, test_agent_type, test_agent_instance
|
||||||
|
):
|
||||||
|
"""Test getting agent instances by agent type."""
|
||||||
|
instances = await agent_instance.get_by_agent_type(
|
||||||
|
db_session, agent_type_id=test_agent_type.id
|
||||||
|
)
|
||||||
|
assert len(instances) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_agent_type_with_status_filter(
|
||||||
|
self, db_session, test_agent_type, test_agent_instance
|
||||||
|
):
|
||||||
|
"""Test getting agent instances by agent type with status filter."""
|
||||||
|
instances = await agent_instance.get_by_agent_type(
|
||||||
|
db_session,
|
||||||
|
agent_type_id=test_agent_type.id,
|
||||||
|
status=AgentStatus.IDLE,
|
||||||
|
)
|
||||||
|
assert len(instances) == 1
|
||||||
|
assert instances[0].status == AgentStatus.IDLE
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_agent_type_db_error(self, db_session, test_agent_type):
|
||||||
|
"""Test getting agent instances by agent type when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_instance.get_by_agent_type(
|
||||||
|
db_session, agent_type_id=test_agent_type.id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceStatusOperations:
|
||||||
|
"""Tests for agent instance status operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_status_not_found(self, db_session):
|
||||||
|
"""Test updating status for non-existent agent instance."""
|
||||||
|
result = await agent_instance.update_status(
|
||||||
|
db_session,
|
||||||
|
instance_id=uuid.uuid4(),
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_status_success(self, db_session, test_agent_instance):
|
||||||
|
"""Test successfully updating agent instance status."""
|
||||||
|
result = await agent_instance.update_status(
|
||||||
|
db_session,
|
||||||
|
instance_id=test_agent_instance.id,
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
current_task="Processing task",
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == AgentStatus.WORKING
|
||||||
|
assert result.current_task == "Processing task"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_status_db_error(self, db_session, test_agent_instance):
|
||||||
|
"""Test updating status when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_instance.update_status(
|
||||||
|
db_session,
|
||||||
|
instance_id=test_agent_instance.id,
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceTerminate:
|
||||||
|
"""Tests for agent instance termination."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_terminate_not_found(self, db_session):
|
||||||
|
"""Test terminating non-existent agent instance."""
|
||||||
|
result = await agent_instance.terminate(db_session, instance_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_terminate_success(self, db_session, test_agent_instance):
|
||||||
|
"""Test successfully terminating agent instance."""
|
||||||
|
result = await agent_instance.terminate(
|
||||||
|
db_session, instance_id=test_agent_instance.id
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == AgentStatus.TERMINATED
|
||||||
|
assert result.terminated_at is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_terminate_db_error(self, db_session, test_agent_instance):
|
||||||
|
"""Test terminating agent instance when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_instance.terminate(
|
||||||
|
db_session, instance_id=test_agent_instance.id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceTaskCompletion:
|
||||||
|
"""Tests for recording task completion."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_task_completion_not_found(self, db_session):
|
||||||
|
"""Test recording task completion for non-existent agent instance."""
|
||||||
|
result = await agent_instance.record_task_completion(
|
||||||
|
db_session,
|
||||||
|
instance_id=uuid.uuid4(),
|
||||||
|
tokens_used=100,
|
||||||
|
cost_incurred=Decimal("0.01"),
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_task_completion_success(
|
||||||
|
self, db_session, test_agent_instance
|
||||||
|
):
|
||||||
|
"""Test successfully recording task completion."""
|
||||||
|
result = await agent_instance.record_task_completion(
|
||||||
|
db_session,
|
||||||
|
instance_id=test_agent_instance.id,
|
||||||
|
tokens_used=1000,
|
||||||
|
cost_incurred=Decimal("0.05"),
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.tasks_completed == 1
|
||||||
|
assert result.tokens_used == 1000
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_task_completion_db_error(
|
||||||
|
self, db_session, test_agent_instance
|
||||||
|
):
|
||||||
|
"""Test recording task completion when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_instance.record_task_completion(
|
||||||
|
db_session,
|
||||||
|
instance_id=test_agent_instance.id,
|
||||||
|
tokens_used=100,
|
||||||
|
cost_incurred=Decimal("0.01"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceMetrics:
|
||||||
|
"""Tests for agent instance metrics."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_project_metrics_empty(self, db_session, test_project):
|
||||||
|
"""Test getting project metrics with no agent instances."""
|
||||||
|
result = await agent_instance.get_project_metrics(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
|
assert result["total_instances"] == 0
|
||||||
|
assert result["active_instances"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_project_metrics_with_data(
|
||||||
|
self, db_session, test_project, test_agent_instance
|
||||||
|
):
|
||||||
|
"""Test getting project metrics with agent instances."""
|
||||||
|
result = await agent_instance.get_project_metrics(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
|
assert result["total_instances"] == 1
|
||||||
|
assert result["idle_instances"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_project_metrics_db_error(self, db_session, test_project):
|
||||||
|
"""Test getting project metrics when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_instance.get_project_metrics(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceBulkTerminate:
|
||||||
|
"""Tests for bulk termination."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_terminate_by_project_empty(self, db_session, test_project):
|
||||||
|
"""Test bulk terminating with no agent instances."""
|
||||||
|
count = await agent_instance.bulk_terminate_by_project(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_terminate_by_project_success(
|
||||||
|
self, db_session, test_project, test_agent_instance, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test successfully bulk terminating agent instances."""
|
||||||
|
# Create another active instance
|
||||||
|
instance2 = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=test_agent_type.id,
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Test Agent 2",
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
)
|
||||||
|
db_session.add(instance2)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
count = await agent_instance.bulk_terminate_by_project(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_terminate_by_project_db_error(
|
||||||
|
self, db_session, test_project, test_agent_instance
|
||||||
|
):
|
||||||
|
"""Test bulk terminating when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_instance.bulk_terminate_by_project(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
436
backend/tests/crud/syndarix/test_agent_instance_crud.py
Normal file
436
backend/tests/crud/syndarix/test_agent_instance_crud.py
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
# tests/crud/syndarix/test_agent_instance_crud.py
|
||||||
|
"""
|
||||||
|
Tests for AgentInstance CRUD operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.crud.syndarix import agent_instance as agent_instance_crud
|
||||||
|
from app.models.syndarix import AgentStatus
|
||||||
|
from app.schemas.syndarix import AgentInstanceCreate, AgentInstanceUpdate
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceCreate:
|
||||||
|
"""Tests for agent instance creation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent_instance_success(
|
||||||
|
self, async_test_db, test_project_crud, test_agent_type_crud
|
||||||
|
):
|
||||||
|
"""Test successfully creating an agent instance."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
instance_data = AgentInstanceCreate(
|
||||||
|
agent_type_id=test_agent_type_crud.id,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="TestBot",
|
||||||
|
status=AgentStatus.IDLE,
|
||||||
|
current_task=None,
|
||||||
|
short_term_memory={"context": "initial"},
|
||||||
|
long_term_memory_ref="project-123/agent-456",
|
||||||
|
session_id="session-abc",
|
||||||
|
)
|
||||||
|
result = await agent_instance_crud.create(session, obj_in=instance_data)
|
||||||
|
|
||||||
|
assert result.id is not None
|
||||||
|
assert result.agent_type_id == test_agent_type_crud.id
|
||||||
|
assert result.project_id == test_project_crud.id
|
||||||
|
assert result.status == AgentStatus.IDLE
|
||||||
|
assert result.short_term_memory == {"context": "initial"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent_instance_minimal(
|
||||||
|
self, async_test_db, test_project_crud, test_agent_type_crud
|
||||||
|
):
|
||||||
|
"""Test creating agent instance with minimal fields."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
instance_data = AgentInstanceCreate(
|
||||||
|
agent_type_id=test_agent_type_crud.id,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="MinimalBot",
|
||||||
|
)
|
||||||
|
result = await agent_instance_crud.create(session, obj_in=instance_data)
|
||||||
|
|
||||||
|
assert result.status == AgentStatus.IDLE # Default
|
||||||
|
assert result.tasks_completed == 0
|
||||||
|
assert result.tokens_used == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceRead:
|
||||||
|
"""Tests for agent instance read operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_agent_instance_by_id(
|
||||||
|
self, async_test_db, test_agent_instance_crud
|
||||||
|
):
|
||||||
|
"""Test getting agent instance by ID."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_instance_crud.get(
|
||||||
|
session, id=str(test_agent_instance_crud.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == test_agent_instance_crud.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_agent_instance_by_id_not_found(self, async_test_db):
|
||||||
|
"""Test getting non-existent agent instance returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_instance_crud.get(session, id=str(uuid.uuid4()))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_details(self, async_test_db, test_agent_instance_crud):
|
||||||
|
"""Test getting agent instance with related details."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_instance_crud.get_with_details(
|
||||||
|
session,
|
||||||
|
instance_id=test_agent_instance_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["instance"].id == test_agent_instance_crud.id
|
||||||
|
assert result["agent_type_name"] is not None
|
||||||
|
assert result["project_name"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceUpdate:
|
||||||
|
"""Tests for agent instance update operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_agent_instance_status(
|
||||||
|
self, async_test_db, test_agent_instance_crud
|
||||||
|
):
|
||||||
|
"""Test updating agent instance status."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
instance = await agent_instance_crud.get(
|
||||||
|
session, id=str(test_agent_instance_crud.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
update_data = AgentInstanceUpdate(
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
current_task="Processing feature request",
|
||||||
|
)
|
||||||
|
result = await agent_instance_crud.update(
|
||||||
|
session, db_obj=instance, obj_in=update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == AgentStatus.WORKING
|
||||||
|
assert result.current_task == "Processing feature request"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_agent_instance_memory(
|
||||||
|
self, async_test_db, test_agent_instance_crud
|
||||||
|
):
|
||||||
|
"""Test updating agent instance short-term memory."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
instance = await agent_instance_crud.get(
|
||||||
|
session, id=str(test_agent_instance_crud.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
new_memory = {
|
||||||
|
"conversation": ["msg1", "msg2"],
|
||||||
|
"decisions": {"key": "value"},
|
||||||
|
}
|
||||||
|
update_data = AgentInstanceUpdate(short_term_memory=new_memory)
|
||||||
|
result = await agent_instance_crud.update(
|
||||||
|
session, db_obj=instance, obj_in=update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.short_term_memory == new_memory
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceStatusUpdate:
|
||||||
|
"""Tests for agent instance status update method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_status(self, async_test_db, test_agent_instance_crud):
|
||||||
|
"""Test updating agent instance status via dedicated method."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_instance_crud.update_status(
|
||||||
|
session,
|
||||||
|
instance_id=test_agent_instance_crud.id,
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
current_task="Working on feature X",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == AgentStatus.WORKING
|
||||||
|
assert result.current_task == "Working on feature X"
|
||||||
|
assert result.last_activity_at is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_status_nonexistent(self, async_test_db):
|
||||||
|
"""Test updating status of non-existent instance returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_instance_crud.update_status(
|
||||||
|
session,
|
||||||
|
instance_id=uuid.uuid4(),
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceTerminate:
|
||||||
|
"""Tests for agent instance termination."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_terminate_agent_instance(
|
||||||
|
self, async_test_db, test_project_crud, test_agent_type_crud
|
||||||
|
):
|
||||||
|
"""Test terminating an agent instance."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create an instance to terminate
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
instance_data = AgentInstanceCreate(
|
||||||
|
agent_type_id=test_agent_type_crud.id,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="TerminateBot",
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
)
|
||||||
|
created = await agent_instance_crud.create(session, obj_in=instance_data)
|
||||||
|
instance_id = created.id
|
||||||
|
|
||||||
|
# Terminate
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_instance_crud.terminate(
|
||||||
|
session, instance_id=instance_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == AgentStatus.TERMINATED
|
||||||
|
assert result.terminated_at is not None
|
||||||
|
assert result.current_task is None
|
||||||
|
assert result.session_id is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_terminate_nonexistent_instance(self, async_test_db):
|
||||||
|
"""Test terminating non-existent instance returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_instance_crud.terminate(
|
||||||
|
session, instance_id=uuid.uuid4()
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceMetrics:
|
||||||
|
"""Tests for agent instance metrics operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_task_completion(
|
||||||
|
self, async_test_db, test_agent_instance_crud
|
||||||
|
):
|
||||||
|
"""Test recording task completion with metrics."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_instance_crud.record_task_completion(
|
||||||
|
session,
|
||||||
|
instance_id=test_agent_instance_crud.id,
|
||||||
|
tokens_used=1500,
|
||||||
|
cost_incurred=Decimal("0.0150"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.tasks_completed == 1
|
||||||
|
assert result.tokens_used == 1500
|
||||||
|
assert result.cost_incurred == Decimal("0.0150")
|
||||||
|
assert result.last_activity_at is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_multiple_task_completions(
|
||||||
|
self, async_test_db, test_project_crud, test_agent_type_crud
|
||||||
|
):
|
||||||
|
"""Test recording multiple task completions accumulates metrics."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create fresh instance
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
instance_data = AgentInstanceCreate(
|
||||||
|
agent_type_id=test_agent_type_crud.id,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="MetricsBot",
|
||||||
|
)
|
||||||
|
created = await agent_instance_crud.create(session, obj_in=instance_data)
|
||||||
|
instance_id = created.id
|
||||||
|
|
||||||
|
# Record first task
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await agent_instance_crud.record_task_completion(
|
||||||
|
session,
|
||||||
|
instance_id=instance_id,
|
||||||
|
tokens_used=1000,
|
||||||
|
cost_incurred=Decimal("0.0100"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record second task
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_instance_crud.record_task_completion(
|
||||||
|
session,
|
||||||
|
instance_id=instance_id,
|
||||||
|
tokens_used=2000,
|
||||||
|
cost_incurred=Decimal("0.0200"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.tasks_completed == 2
|
||||||
|
assert result.tokens_used == 3000
|
||||||
|
assert result.cost_incurred == Decimal("0.0300")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_project_metrics(
|
||||||
|
self, async_test_db, test_project_crud, test_agent_instance_crud
|
||||||
|
):
|
||||||
|
"""Test getting aggregated metrics for a project."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_instance_crud.get_project_metrics(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert "total_instances" in result
|
||||||
|
assert "active_instances" in result
|
||||||
|
assert "idle_instances" in result
|
||||||
|
assert "total_tasks_completed" in result
|
||||||
|
assert "total_tokens_used" in result
|
||||||
|
assert "total_cost_incurred" in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceByProject:
|
||||||
|
"""Tests for getting instances by project."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project(
|
||||||
|
self, async_test_db, test_project_crud, test_agent_instance_crud
|
||||||
|
):
|
||||||
|
"""Test getting instances by project."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
instances, total = await agent_instance_crud.get_by_project(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert total >= 1
|
||||||
|
assert all(i.project_id == test_project_crud.id for i in instances)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_with_status(
|
||||||
|
self, async_test_db, test_project_crud, test_agent_type_crud
|
||||||
|
):
|
||||||
|
"""Test getting instances by project filtered by status."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create instances with different statuses
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
idle_instance = AgentInstanceCreate(
|
||||||
|
agent_type_id=test_agent_type_crud.id,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="IdleBot",
|
||||||
|
status=AgentStatus.IDLE,
|
||||||
|
)
|
||||||
|
await agent_instance_crud.create(session, obj_in=idle_instance)
|
||||||
|
|
||||||
|
working_instance = AgentInstanceCreate(
|
||||||
|
agent_type_id=test_agent_type_crud.id,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="WorkerBot",
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
)
|
||||||
|
await agent_instance_crud.create(session, obj_in=working_instance)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
instances, _total = await agent_instance_crud.get_by_project(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all(i.status == AgentStatus.WORKING for i in instances)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceByAgentType:
|
||||||
|
"""Tests for getting instances by agent type."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_agent_type(
|
||||||
|
self, async_test_db, test_agent_type_crud, test_agent_instance_crud
|
||||||
|
):
|
||||||
|
"""Test getting instances by agent type."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
instances = await agent_instance_crud.get_by_agent_type(
|
||||||
|
session,
|
||||||
|
agent_type_id=test_agent_type_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(instances) >= 1
|
||||||
|
assert all(i.agent_type_id == test_agent_type_crud.id for i in instances)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBulkTerminate:
|
||||||
|
"""Tests for bulk termination of instances."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bulk_terminate_by_project(
|
||||||
|
self, async_test_db, test_project_crud, test_agent_type_crud
|
||||||
|
):
|
||||||
|
"""Test bulk terminating all instances in a project."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create multiple instances
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
for i in range(3):
|
||||||
|
instance_data = AgentInstanceCreate(
|
||||||
|
agent_type_id=test_agent_type_crud.id,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name=f"BulkBot-{i}",
|
||||||
|
status=AgentStatus.WORKING if i < 2 else AgentStatus.IDLE,
|
||||||
|
)
|
||||||
|
await agent_instance_crud.create(session, obj_in=instance_data)
|
||||||
|
|
||||||
|
# Bulk terminate
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
count = await agent_instance_crud.bulk_terminate_by_project(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert count >= 3
|
||||||
|
|
||||||
|
# Verify all are terminated
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
instances, _ = await agent_instance_crud.get_by_project(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
for instance in instances:
|
||||||
|
assert instance.status == AgentStatus.TERMINATED
|
||||||
312
backend/tests/crud/syndarix/test_agent_type.py
Normal file
312
backend/tests/crud/syndarix/test_agent_type.py
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
# tests/crud/syndarix/test_agent_type.py
|
||||||
|
"""Tests for AgentType CRUD operations."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
|
||||||
|
from app.crud.syndarix.agent_type import agent_type
|
||||||
|
from app.models.syndarix import AgentType
|
||||||
|
from app.schemas.syndarix import AgentTypeCreate
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db_session(async_test_db):
|
||||||
|
"""Create a database session for tests."""
|
||||||
|
_, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_agent_type(db_session):
|
||||||
|
"""Create a test agent type."""
|
||||||
|
at = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Agent Type",
|
||||||
|
slug=f"test-agent-type-{uuid.uuid4().hex[:8]}",
|
||||||
|
primary_model="claude-3-opus",
|
||||||
|
personality_prompt="You are a helpful test agent.",
|
||||||
|
expertise=["python", "testing"],
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db_session.add(at)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(at)
|
||||||
|
return at
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeGetBySlug:
|
||||||
|
"""Tests for getting agent type by slug."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_slug_not_found(self, db_session):
|
||||||
|
"""Test getting non-existent agent type by slug."""
|
||||||
|
result = await agent_type.get_by_slug(db_session, slug="nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_slug_success(self, db_session, test_agent_type):
|
||||||
|
"""Test successfully getting agent type by slug."""
|
||||||
|
result = await agent_type.get_by_slug(db_session, slug=test_agent_type.slug)
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == test_agent_type.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_slug_db_error(self, db_session):
|
||||||
|
"""Test getting agent type by slug when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_type.get_by_slug(db_session, slug="test")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeCreate:
|
||||||
|
"""Tests for agent type creation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent_type_success(self, db_session):
|
||||||
|
"""Test successful agent type creation."""
|
||||||
|
agent_type_data = AgentTypeCreate(
|
||||||
|
name="New Agent Type",
|
||||||
|
slug=f"new-agent-type-{uuid.uuid4().hex[:8]}",
|
||||||
|
primary_model="claude-3-opus",
|
||||||
|
personality_prompt="You are a new agent.",
|
||||||
|
)
|
||||||
|
created = await agent_type.create(db_session, obj_in=agent_type_data)
|
||||||
|
assert created.name == "New Agent Type"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent_type_duplicate_slug(self, db_session, test_agent_type):
|
||||||
|
"""Test agent type creation with duplicate slug."""
|
||||||
|
agent_type_data = AgentTypeCreate(
|
||||||
|
name="Another Agent Type",
|
||||||
|
slug=test_agent_type.slug, # Use existing slug
|
||||||
|
primary_model="claude-3-opus",
|
||||||
|
personality_prompt="You are another agent.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock IntegrityError with slug in the message
|
||||||
|
mock_orig = MagicMock()
|
||||||
|
mock_orig.__str__ = (
|
||||||
|
lambda self: "duplicate key value violates unique constraint on slug"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=IntegrityError("", {}, mock_orig),
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
await agent_type.create(db_session, obj_in=agent_type_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent_type_integrity_error(self, db_session):
|
||||||
|
"""Test agent type creation with general integrity error."""
|
||||||
|
agent_type_data = AgentTypeCreate(
|
||||||
|
name="Test Agent Type",
|
||||||
|
slug=f"test-{uuid.uuid4().hex[:8]}",
|
||||||
|
primary_model="claude-3-opus",
|
||||||
|
personality_prompt="You are a test agent.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock IntegrityError without slug in the message
|
||||||
|
mock_orig = MagicMock()
|
||||||
|
mock_orig.__str__ = lambda self: "foreign key constraint violation"
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=IntegrityError("", {}, mock_orig),
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="Database integrity error"):
|
||||||
|
await agent_type.create(db_session, obj_in=agent_type_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent_type_unexpected_error(self, db_session):
|
||||||
|
"""Test agent type creation with unexpected error."""
|
||||||
|
agent_type_data = AgentTypeCreate(
|
||||||
|
name="Test Agent Type",
|
||||||
|
slug=f"test-{uuid.uuid4().hex[:8]}",
|
||||||
|
primary_model="claude-3-opus",
|
||||||
|
personality_prompt="You are a test agent.",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=RuntimeError("Unexpected error"),
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="Unexpected error"):
|
||||||
|
await agent_type.create(db_session, obj_in=agent_type_data)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeGetMultiWithFilters:
|
||||||
|
"""Tests for getting agent types with filters."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_success(self, db_session, test_agent_type):
|
||||||
|
"""Test successfully getting agent types with filters."""
|
||||||
|
_results, total = await agent_type.get_multi_with_filters(db_session)
|
||||||
|
assert total >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_sort_asc(self, db_session, test_agent_type):
|
||||||
|
"""Test getting agent types with ascending sort order."""
|
||||||
|
_results, total = await agent_type.get_multi_with_filters(
|
||||||
|
db_session,
|
||||||
|
sort_by="created_at",
|
||||||
|
sort_order="asc",
|
||||||
|
)
|
||||||
|
assert total >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_db_error(self, db_session):
|
||||||
|
"""Test getting agent types when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_type.get_multi_with_filters(db_session)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeGetWithInstanceCount:
|
||||||
|
"""Tests for getting agent type with instance count."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_instance_count_not_found(self, db_session):
|
||||||
|
"""Test getting non-existent agent type with instance count."""
|
||||||
|
result = await agent_type.get_with_instance_count(
|
||||||
|
db_session, agent_type_id=uuid.uuid4()
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_instance_count_success(self, db_session, test_agent_type):
|
||||||
|
"""Test successfully getting agent type with instance count."""
|
||||||
|
result = await agent_type.get_with_instance_count(
|
||||||
|
db_session, agent_type_id=test_agent_type.id
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result["agent_type"].id == test_agent_type.id
|
||||||
|
assert result["instance_count"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_instance_count_db_error(self, db_session, test_agent_type):
|
||||||
|
"""Test getting agent type with instance count when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_type.get_with_instance_count(
|
||||||
|
db_session, agent_type_id=test_agent_type.id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeGetMultiWithInstanceCounts:
|
||||||
|
"""Tests for getting agent types with instance counts."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_instance_counts_empty(self, db_session):
|
||||||
|
"""Test getting agent types with instance counts when none exist."""
|
||||||
|
# Create a separate project to ensure isolation
|
||||||
|
results, total = await agent_type.get_multi_with_instance_counts(
|
||||||
|
db_session,
|
||||||
|
is_active=None,
|
||||||
|
search="nonexistent-xyz-query",
|
||||||
|
)
|
||||||
|
assert results == []
|
||||||
|
assert total == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_instance_counts_success(
|
||||||
|
self, db_session, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test successfully getting agent types with instance counts."""
|
||||||
|
results, total = await agent_type.get_multi_with_instance_counts(db_session)
|
||||||
|
assert total >= 1
|
||||||
|
assert len(results) >= 1
|
||||||
|
assert "agent_type" in results[0]
|
||||||
|
assert "instance_count" in results[0]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_instance_counts_db_error(
|
||||||
|
self, db_session, test_agent_type
|
||||||
|
):
|
||||||
|
"""Test getting agent types with instance counts when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_type.get_multi_with_instance_counts(db_session)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeGetByExpertise:
|
||||||
|
"""Tests for getting agent types by expertise."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="Uses PostgreSQL JSONB contains operator, not available in SQLite"
|
||||||
|
)
|
||||||
|
async def test_get_by_expertise_success(self, db_session, test_agent_type):
|
||||||
|
"""Test successfully getting agent types by expertise."""
|
||||||
|
results = await agent_type.get_by_expertise(db_session, expertise="python")
|
||||||
|
assert len(results) >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="Uses PostgreSQL JSONB contains operator, not available in SQLite"
|
||||||
|
)
|
||||||
|
async def test_get_by_expertise_db_error(self, db_session):
|
||||||
|
"""Test getting agent types by expertise when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_type.get_by_expertise(db_session, expertise="python")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeDeactivate:
|
||||||
|
"""Tests for deactivating agent types."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deactivate_not_found(self, db_session):
|
||||||
|
"""Test deactivating non-existent agent type."""
|
||||||
|
result = await agent_type.deactivate(db_session, agent_type_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deactivate_success(self, db_session, test_agent_type):
|
||||||
|
"""Test successfully deactivating agent type."""
|
||||||
|
result = await agent_type.deactivate(
|
||||||
|
db_session, agent_type_id=test_agent_type.id
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.is_active is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deactivate_db_error(self, db_session, test_agent_type):
|
||||||
|
"""Test deactivating agent type when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await agent_type.deactivate(
|
||||||
|
db_session, agent_type_id=test_agent_type.id
|
||||||
|
)
|
||||||
383
backend/tests/crud/syndarix/test_agent_type_crud.py
Normal file
383
backend/tests/crud/syndarix/test_agent_type_crud.py
Normal file
@@ -0,0 +1,383 @@
|
|||||||
|
# tests/crud/syndarix/test_agent_type_crud.py
|
||||||
|
"""
|
||||||
|
Tests for AgentType CRUD operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.crud.syndarix import agent_type as agent_type_crud
|
||||||
|
from app.schemas.syndarix import AgentTypeCreate, AgentTypeUpdate
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeCreate:
|
||||||
|
"""Tests for agent type creation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent_type_success(self, async_test_db):
|
||||||
|
"""Test successfully creating an agent type."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_type_data = AgentTypeCreate(
|
||||||
|
name="QA Engineer",
|
||||||
|
slug="qa-engineer",
|
||||||
|
description="Specialized in testing and quality assurance",
|
||||||
|
expertise=["testing", "pytest", "playwright"],
|
||||||
|
personality_prompt="You are an expert QA engineer...",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
fallback_models=["claude-sonnet-4-20250514"],
|
||||||
|
model_params={"temperature": 0.5},
|
||||||
|
mcp_servers=["gitea"],
|
||||||
|
tool_permissions={"allowed": ["*"]},
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
result = await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||||
|
|
||||||
|
assert result.id is not None
|
||||||
|
assert result.name == "QA Engineer"
|
||||||
|
assert result.slug == "qa-engineer"
|
||||||
|
assert result.expertise == ["testing", "pytest", "playwright"]
|
||||||
|
assert result.is_active is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent_type_duplicate_slug_fails(
|
||||||
|
self, async_test_db, test_agent_type_crud
|
||||||
|
):
|
||||||
|
"""Test creating agent type with duplicate slug raises ValueError."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_type_data = AgentTypeCreate(
|
||||||
|
name="Duplicate Agent",
|
||||||
|
slug=test_agent_type_crud.slug, # Duplicate slug
|
||||||
|
personality_prompt="Duplicate",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||||
|
|
||||||
|
assert "already exists" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent_type_minimal_fields(self, async_test_db):
|
||||||
|
"""Test creating agent type with minimal required fields."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_type_data = AgentTypeCreate(
|
||||||
|
name="Minimal Agent",
|
||||||
|
slug="minimal-agent",
|
||||||
|
personality_prompt="You are an assistant.",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
result = await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||||
|
|
||||||
|
assert result.name == "Minimal Agent"
|
||||||
|
assert result.expertise == [] # Default
|
||||||
|
assert result.fallback_models == [] # Default
|
||||||
|
assert result.is_active is True # Default
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeRead:
|
||||||
|
"""Tests for agent type read operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_agent_type_by_id(self, async_test_db, test_agent_type_crud):
|
||||||
|
"""Test getting agent type by ID."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_type_crud.get(session, id=str(test_agent_type_crud.id))
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == test_agent_type_crud.id
|
||||||
|
assert result.name == test_agent_type_crud.name
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_agent_type_by_id_not_found(self, async_test_db):
|
||||||
|
"""Test getting non-existent agent type returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_type_crud.get(session, id=str(uuid.uuid4()))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_agent_type_by_slug(self, async_test_db, test_agent_type_crud):
|
||||||
|
"""Test getting agent type by slug."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_type_crud.get_by_slug(
|
||||||
|
session, slug=test_agent_type_crud.slug
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.slug == test_agent_type_crud.slug
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_agent_type_by_slug_not_found(self, async_test_db):
|
||||||
|
"""Test getting non-existent slug returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_type_crud.get_by_slug(
|
||||||
|
session, slug="non-existent-agent"
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeUpdate:
|
||||||
|
"""Tests for agent type update operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_agent_type_basic_fields(
|
||||||
|
self, async_test_db, test_agent_type_crud
|
||||||
|
):
|
||||||
|
"""Test updating basic agent type fields."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_type = await agent_type_crud.get(
|
||||||
|
session, id=str(test_agent_type_crud.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
update_data = AgentTypeUpdate(
|
||||||
|
name="Updated Agent Name",
|
||||||
|
description="Updated description",
|
||||||
|
)
|
||||||
|
result = await agent_type_crud.update(
|
||||||
|
session, db_obj=agent_type, obj_in=update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.name == "Updated Agent Name"
|
||||||
|
assert result.description == "Updated description"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_agent_type_expertise(
|
||||||
|
self, async_test_db, test_agent_type_crud
|
||||||
|
):
|
||||||
|
"""Test updating agent type expertise."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_type = await agent_type_crud.get(
|
||||||
|
session, id=str(test_agent_type_crud.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
update_data = AgentTypeUpdate(
|
||||||
|
expertise=["new-skill", "another-skill"],
|
||||||
|
)
|
||||||
|
result = await agent_type_crud.update(
|
||||||
|
session, db_obj=agent_type, obj_in=update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "new-skill" in result.expertise
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_agent_type_model_params(
|
||||||
|
self, async_test_db, test_agent_type_crud
|
||||||
|
):
|
||||||
|
"""Test updating agent type model parameters."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_type = await agent_type_crud.get(
|
||||||
|
session, id=str(test_agent_type_crud.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
new_params = {"temperature": 0.9, "max_tokens": 8192}
|
||||||
|
update_data = AgentTypeUpdate(model_params=new_params)
|
||||||
|
result = await agent_type_crud.update(
|
||||||
|
session, db_obj=agent_type, obj_in=update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.model_params == new_params
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeDelete:
|
||||||
|
"""Tests for agent type delete operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_agent_type(self, async_test_db):
|
||||||
|
"""Test deleting an agent type."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create an agent type to delete
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_type_data = AgentTypeCreate(
|
||||||
|
name="Delete Me Agent",
|
||||||
|
slug="delete-me-agent",
|
||||||
|
personality_prompt="Delete test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
created = await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||||
|
agent_type_id = created.id
|
||||||
|
|
||||||
|
# Delete the agent type
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_type_crud.remove(session, id=str(agent_type_id))
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
# Verify deletion
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
deleted = await agent_type_crud.get(session, id=str(agent_type_id))
|
||||||
|
assert deleted is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeFilters:
|
||||||
|
"""Tests for agent type filtering and search."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_active(self, async_test_db):
|
||||||
|
"""Test filtering agent types by is_active."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create active and inactive agent types
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
active_type = AgentTypeCreate(
|
||||||
|
name="Active Agent Type",
|
||||||
|
slug="active-agent-type-filter",
|
||||||
|
personality_prompt="Active",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
await agent_type_crud.create(session, obj_in=active_type)
|
||||||
|
|
||||||
|
inactive_type = AgentTypeCreate(
|
||||||
|
name="Inactive Agent Type",
|
||||||
|
slug="inactive-agent-type-filter",
|
||||||
|
personality_prompt="Inactive",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
is_active=False,
|
||||||
|
)
|
||||||
|
await agent_type_crud.create(session, obj_in=inactive_type)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
active_types, _ = await agent_type_crud.get_multi_with_filters(
|
||||||
|
session,
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all(at.is_active for at in active_types)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_search(self, async_test_db):
|
||||||
|
"""Test searching agent types by name."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_type_data = AgentTypeCreate(
|
||||||
|
name="Searchable Agent Type",
|
||||||
|
slug="searchable-agent-type",
|
||||||
|
description="This is searchable",
|
||||||
|
personality_prompt="Searchable",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_types, total = await agent_type_crud.get_multi_with_filters(
|
||||||
|
session,
|
||||||
|
search="Searchable",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert total >= 1
|
||||||
|
assert any(at.name == "Searchable Agent Type" for at in agent_types)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_pagination(self, async_test_db):
|
||||||
|
"""Test pagination of agent type results."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
for i in range(5):
|
||||||
|
agent_type_data = AgentTypeCreate(
|
||||||
|
name=f"Page Agent Type {i}",
|
||||||
|
slug=f"page-agent-type-{i}",
|
||||||
|
personality_prompt=f"Page {i}",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
page1, _total = await agent_type_crud.get_multi_with_filters(
|
||||||
|
session,
|
||||||
|
skip=0,
|
||||||
|
limit=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(page1) <= 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeSpecialMethods:
|
||||||
|
"""Tests for special agent type CRUD methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deactivate_agent_type(self, async_test_db):
|
||||||
|
"""Test deactivating an agent type."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create an active agent type
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_type_data = AgentTypeCreate(
|
||||||
|
name="Deactivate Me",
|
||||||
|
slug="deactivate-me-agent",
|
||||||
|
personality_prompt="Deactivate",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
created = await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||||
|
agent_type_id = created.id
|
||||||
|
|
||||||
|
# Deactivate
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_type_crud.deactivate(
|
||||||
|
session, agent_type_id=agent_type_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.is_active is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deactivate_nonexistent_agent_type(self, async_test_db):
|
||||||
|
"""Test deactivating non-existent agent type returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_type_crud.deactivate(
|
||||||
|
session, agent_type_id=uuid.uuid4()
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_instance_count(
|
||||||
|
self, async_test_db, test_agent_type_crud, test_agent_instance_crud
|
||||||
|
):
|
||||||
|
"""Test getting agent type with instance count."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_type_crud.get_with_instance_count(
|
||||||
|
session,
|
||||||
|
agent_type_id=test_agent_type_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["agent_type"].id == test_agent_type_crud.id
|
||||||
|
assert result["instance_count"] >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_instance_count_not_found(self, async_test_db):
|
||||||
|
"""Test getting non-existent agent type with count returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await agent_type_crud.get_with_instance_count(
|
||||||
|
session,
|
||||||
|
agent_type_id=uuid.uuid4(),
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
682
backend/tests/crud/syndarix/test_issue.py
Normal file
682
backend/tests/crud/syndarix/test_issue.py
Normal file
@@ -0,0 +1,682 @@
|
|||||||
|
# tests/crud/syndarix/test_issue.py
|
||||||
|
"""Tests for Issue CRUD operations."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
|
||||||
|
from app.crud.syndarix.issue import issue
|
||||||
|
from app.models.syndarix import Issue, Project, Sprint
|
||||||
|
from app.models.syndarix.enums import (
|
||||||
|
IssuePriority,
|
||||||
|
IssueStatus,
|
||||||
|
ProjectStatus,
|
||||||
|
SprintStatus,
|
||||||
|
SyncStatus,
|
||||||
|
)
|
||||||
|
from app.schemas.syndarix import IssueCreate
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db_session(async_test_db):
|
||||||
|
"""Create a database session for tests."""
|
||||||
|
_, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_project(db_session):
|
||||||
|
"""Create a test project for issues."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Project",
|
||||||
|
slug=f"test-project-{uuid.uuid4().hex[:8]}",
|
||||||
|
status=ProjectStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(project)
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_sprint(db_session, test_project):
|
||||||
|
"""Create a test sprint."""
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Test Sprint",
|
||||||
|
number=1,
|
||||||
|
status=SprintStatus.PLANNED,
|
||||||
|
start_date=date.today(),
|
||||||
|
end_date=date.today(),
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(sprint)
|
||||||
|
return sprint
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_issue(db_session, test_project):
|
||||||
|
"""Create a test issue."""
|
||||||
|
issue_obj = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
title="Test Issue",
|
||||||
|
body="Test issue body",
|
||||||
|
status=IssueStatus.OPEN,
|
||||||
|
priority=IssuePriority.MEDIUM,
|
||||||
|
labels=["bug", "backend"],
|
||||||
|
)
|
||||||
|
db_session.add(issue_obj)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(issue_obj)
|
||||||
|
return issue_obj
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueCreate:
|
||||||
|
"""Tests for issue creation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_issue_success(self, db_session, test_project):
|
||||||
|
"""Test successful issue creation."""
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project.id,
|
||||||
|
title="New Issue",
|
||||||
|
body="Issue description",
|
||||||
|
status=IssueStatus.OPEN,
|
||||||
|
priority=IssuePriority.HIGH,
|
||||||
|
labels=["feature"],
|
||||||
|
)
|
||||||
|
created = await issue.create(db_session, obj_in=issue_data)
|
||||||
|
assert created.title == "New Issue"
|
||||||
|
assert created.priority == IssuePriority.HIGH
|
||||||
|
assert created.sync_status == SyncStatus.SYNCED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_issue_with_external_tracker(self, db_session, test_project):
|
||||||
|
"""Test issue creation with external tracker info."""
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project.id,
|
||||||
|
title="External Issue",
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id="ext-123",
|
||||||
|
remote_url="https://gitea.example.com/issues/123",
|
||||||
|
external_issue_number=123,
|
||||||
|
)
|
||||||
|
created = await issue.create(db_session, obj_in=issue_data)
|
||||||
|
assert created.external_tracker_type == "gitea"
|
||||||
|
assert created.external_issue_id == "ext-123"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_issue_integrity_error(self, db_session, test_project):
|
||||||
|
"""Test issue creation with integrity error."""
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project.id,
|
||||||
|
title="Test Issue",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock commit to raise IntegrityError
|
||||||
|
mock_orig = MagicMock()
|
||||||
|
mock_orig.__str__ = lambda self: "UNIQUE constraint failed"
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=IntegrityError("", {}, mock_orig),
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="Database integrity error"):
|
||||||
|
await issue.create(db_session, obj_in=issue_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_issue_unexpected_error(self, db_session, test_project):
|
||||||
|
"""Test issue creation with unexpected error."""
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project.id,
|
||||||
|
title="Test Issue",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=RuntimeError("Unexpected error"),
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="Unexpected error"):
|
||||||
|
await issue.create(db_session, obj_in=issue_data)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueGetWithDetails:
|
||||||
|
"""Tests for getting issue with details."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_details_not_found(self, db_session):
|
||||||
|
"""Test getting non-existent issue with details."""
|
||||||
|
result = await issue.get_with_details(db_session, issue_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_details_success(self, db_session, test_issue):
|
||||||
|
"""Test getting issue with details."""
|
||||||
|
result = await issue.get_with_details(db_session, issue_id=test_issue.id)
|
||||||
|
assert result is not None
|
||||||
|
assert result["issue"].id == test_issue.id
|
||||||
|
assert "project_name" in result
|
||||||
|
assert "project_slug" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_details_db_error(self, db_session, test_issue):
|
||||||
|
"""Test getting issue with details when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.get_with_details(db_session, issue_id=test_issue.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueGetByProject:
|
||||||
|
"""Tests for getting issues by project."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_with_filters(
|
||||||
|
self, db_session, test_project, test_issue
|
||||||
|
):
|
||||||
|
"""Test getting issues with various filters."""
|
||||||
|
# Create issue with specific labels
|
||||||
|
issue2 = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
title="Filtered Issue",
|
||||||
|
status=IssueStatus.IN_PROGRESS,
|
||||||
|
priority=IssuePriority.HIGH,
|
||||||
|
labels=["frontend"],
|
||||||
|
)
|
||||||
|
db_session.add(issue2)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Test status filter
|
||||||
|
issues, _total = await issue.get_by_project(
|
||||||
|
db_session,
|
||||||
|
project_id=test_project.id,
|
||||||
|
status=IssueStatus.IN_PROGRESS,
|
||||||
|
)
|
||||||
|
assert len(issues) == 1
|
||||||
|
assert issues[0].status == IssueStatus.IN_PROGRESS
|
||||||
|
|
||||||
|
# Test priority filter
|
||||||
|
issues, _total = await issue.get_by_project(
|
||||||
|
db_session,
|
||||||
|
project_id=test_project.id,
|
||||||
|
priority=IssuePriority.HIGH,
|
||||||
|
)
|
||||||
|
assert len(issues) == 1
|
||||||
|
assert issues[0].priority == IssuePriority.HIGH
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="Labels filter uses PostgreSQL @> operator, not available in SQLite"
|
||||||
|
)
|
||||||
|
async def test_get_by_project_with_labels_filter(
|
||||||
|
self, db_session, test_project, test_issue
|
||||||
|
):
|
||||||
|
"""Test getting issues filtered by labels."""
|
||||||
|
issues, _total = await issue.get_by_project(
|
||||||
|
db_session,
|
||||||
|
project_id=test_project.id,
|
||||||
|
labels=["bug"],
|
||||||
|
)
|
||||||
|
assert len(issues) == 1
|
||||||
|
assert "bug" in issues[0].labels
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_sort_order_asc(
|
||||||
|
self, db_session, test_project, test_issue
|
||||||
|
):
|
||||||
|
"""Test getting issues with ascending sort order."""
|
||||||
|
# Create another issue
|
||||||
|
issue2 = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
title="Second Issue",
|
||||||
|
status=IssueStatus.OPEN,
|
||||||
|
)
|
||||||
|
db_session.add(issue2)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
issues, _total = await issue.get_by_project(
|
||||||
|
db_session,
|
||||||
|
project_id=test_project.id,
|
||||||
|
sort_by="created_at",
|
||||||
|
sort_order="asc",
|
||||||
|
)
|
||||||
|
assert len(issues) == 2
|
||||||
|
# Compare without timezone info since DB may strip it
|
||||||
|
first_time = (
|
||||||
|
issues[0].created_at.replace(tzinfo=None)
|
||||||
|
if issues[0].created_at.tzinfo
|
||||||
|
else issues[0].created_at
|
||||||
|
)
|
||||||
|
second_time = (
|
||||||
|
issues[1].created_at.replace(tzinfo=None)
|
||||||
|
if issues[1].created_at.tzinfo
|
||||||
|
else issues[1].created_at
|
||||||
|
)
|
||||||
|
assert first_time <= second_time
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_db_error(self, db_session, test_project):
|
||||||
|
"""Test getting issues when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.get_by_project(db_session, project_id=test_project.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueGetBySprint:
|
||||||
|
"""Tests for getting issues by sprint."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_sprint_with_status(
|
||||||
|
self, db_session, test_project, test_sprint
|
||||||
|
):
|
||||||
|
"""Test getting issues by sprint with status filter."""
|
||||||
|
# Create issues in sprint
|
||||||
|
issue1 = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
sprint_id=test_sprint.id,
|
||||||
|
title="Sprint Issue 1",
|
||||||
|
status=IssueStatus.OPEN,
|
||||||
|
)
|
||||||
|
issue2 = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
sprint_id=test_sprint.id,
|
||||||
|
title="Sprint Issue 2",
|
||||||
|
status=IssueStatus.CLOSED,
|
||||||
|
)
|
||||||
|
db_session.add_all([issue1, issue2])
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Test status filter
|
||||||
|
issues = await issue.get_by_sprint(
|
||||||
|
db_session,
|
||||||
|
sprint_id=test_sprint.id,
|
||||||
|
status=IssueStatus.OPEN,
|
||||||
|
)
|
||||||
|
assert len(issues) == 1
|
||||||
|
assert issues[0].status == IssueStatus.OPEN
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_sprint_db_error(self, db_session, test_sprint):
|
||||||
|
"""Test getting issues by sprint when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.get_by_sprint(db_session, sprint_id=test_sprint.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueAssignment:
|
||||||
|
"""Tests for issue assignment operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assign_to_agent_not_found(self, db_session):
|
||||||
|
"""Test assigning non-existent issue to agent."""
|
||||||
|
result = await issue.assign_to_agent(
|
||||||
|
db_session, issue_id=uuid.uuid4(), agent_id=uuid.uuid4()
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assign_to_agent_db_error(self, db_session, test_issue):
|
||||||
|
"""Test assigning issue to agent when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.assign_to_agent(
|
||||||
|
db_session, issue_id=test_issue.id, agent_id=uuid.uuid4()
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assign_to_human_not_found(self, db_session):
|
||||||
|
"""Test assigning non-existent issue to human."""
|
||||||
|
result = await issue.assign_to_human(
|
||||||
|
db_session, issue_id=uuid.uuid4(), human_assignee="john@example.com"
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assign_to_human_db_error(self, db_session, test_issue):
|
||||||
|
"""Test assigning issue to human when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.assign_to_human(
|
||||||
|
db_session,
|
||||||
|
issue_id=test_issue.id,
|
||||||
|
human_assignee="john@example.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unassign_not_found(self, db_session):
|
||||||
|
"""Test unassigning non-existent issue."""
|
||||||
|
result = await issue.unassign(db_session, issue_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unassign_db_error(self, db_session, test_issue):
|
||||||
|
"""Test unassigning issue when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.unassign(db_session, issue_id=test_issue.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueStatusChanges:
|
||||||
|
"""Tests for issue status change operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_issue_not_found(self, db_session):
|
||||||
|
"""Test closing non-existent issue."""
|
||||||
|
result = await issue.close_issue(db_session, issue_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_issue_db_error(self, db_session, test_issue):
|
||||||
|
"""Test closing issue when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.close_issue(db_session, issue_id=test_issue.id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reopen_issue_not_found(self, db_session):
|
||||||
|
"""Test reopening non-existent issue."""
|
||||||
|
result = await issue.reopen_issue(db_session, issue_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reopen_issue_db_error(self, db_session, test_issue):
|
||||||
|
"""Test reopening issue when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.reopen_issue(db_session, issue_id=test_issue.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueSyncStatus:
|
||||||
|
"""Tests for issue sync status operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_sync_status_not_found(self, db_session):
|
||||||
|
"""Test updating sync status for non-existent issue."""
|
||||||
|
result = await issue.update_sync_status(
|
||||||
|
db_session,
|
||||||
|
issue_id=uuid.uuid4(),
|
||||||
|
sync_status=SyncStatus.SYNCED,
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_sync_status_with_timestamps(self, db_session, test_issue):
|
||||||
|
"""Test updating sync status with timestamps."""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
result = await issue.update_sync_status(
|
||||||
|
db_session,
|
||||||
|
issue_id=test_issue.id,
|
||||||
|
sync_status=SyncStatus.SYNCED,
|
||||||
|
last_synced_at=now,
|
||||||
|
external_updated_at=now,
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.sync_status == SyncStatus.SYNCED
|
||||||
|
# Compare without timezone info since DB may strip it
|
||||||
|
assert result.last_synced_at.replace(tzinfo=None) == now.replace(tzinfo=None)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_sync_status_db_error(self, db_session, test_issue):
|
||||||
|
"""Test updating sync status when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.update_sync_status(
|
||||||
|
db_session,
|
||||||
|
issue_id=test_issue.id,
|
||||||
|
sync_status=SyncStatus.ERROR,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueStats:
|
||||||
|
"""Tests for issue statistics."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_project_stats(self, db_session, test_project, test_issue):
|
||||||
|
"""Test getting project issue statistics."""
|
||||||
|
stats = await issue.get_project_stats(db_session, project_id=test_project.id)
|
||||||
|
assert stats["total"] >= 1
|
||||||
|
assert "open" in stats
|
||||||
|
assert "by_priority" in stats
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_project_stats_db_error(self, db_session, test_project):
|
||||||
|
"""Test getting project stats when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.get_project_stats(db_session, project_id=test_project.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueExternalTracker:
|
||||||
|
"""Tests for external tracker operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_external_id_not_found(self, db_session):
|
||||||
|
"""Test getting issue by non-existent external ID."""
|
||||||
|
result = await issue.get_by_external_id(
|
||||||
|
db_session,
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id="nonexistent",
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_external_id_success(self, db_session, test_project):
|
||||||
|
"""Test getting issue by external ID."""
|
||||||
|
# Create issue with external tracker
|
||||||
|
issue_obj = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
title="External Issue",
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id="ext-456",
|
||||||
|
)
|
||||||
|
db_session.add(issue_obj)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await issue.get_by_external_id(
|
||||||
|
db_session,
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id="ext-456",
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.external_issue_id == "ext-456"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_external_id_db_error(self, db_session):
|
||||||
|
"""Test getting issue by external ID when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.get_by_external_id(
|
||||||
|
db_session,
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_pending_sync(self, db_session, test_project):
|
||||||
|
"""Test getting issues pending sync."""
|
||||||
|
# Create issue with pending sync
|
||||||
|
issue_obj = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
title="Pending Sync Issue",
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id="ext-789",
|
||||||
|
sync_status=SyncStatus.PENDING,
|
||||||
|
)
|
||||||
|
db_session.add(issue_obj)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Test without project filter
|
||||||
|
issues = await issue.get_pending_sync(db_session)
|
||||||
|
assert len(issues) >= 1
|
||||||
|
|
||||||
|
# Test with project filter
|
||||||
|
issues = await issue.get_pending_sync(db_session, project_id=test_project.id)
|
||||||
|
assert len(issues) >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_pending_sync_db_error(self, db_session):
|
||||||
|
"""Test getting pending sync issues when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.get_pending_sync(db_session)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueSprintOperations:
|
||||||
|
"""Tests for sprint-related issue operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_sprint_from_issues(
|
||||||
|
self, db_session, test_project, test_sprint
|
||||||
|
):
|
||||||
|
"""Test removing sprint from all issues."""
|
||||||
|
# Create issues in sprint
|
||||||
|
issue1 = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
sprint_id=test_sprint.id,
|
||||||
|
title="Sprint Issue 1",
|
||||||
|
)
|
||||||
|
issue2 = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
sprint_id=test_sprint.id,
|
||||||
|
title="Sprint Issue 2",
|
||||||
|
)
|
||||||
|
db_session.add_all([issue1, issue2])
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
count = await issue.remove_sprint_from_issues(
|
||||||
|
db_session, sprint_id=test_sprint.id
|
||||||
|
)
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
# Verify issues no longer in sprint
|
||||||
|
await db_session.refresh(issue1)
|
||||||
|
await db_session.refresh(issue2)
|
||||||
|
assert issue1.sprint_id is None
|
||||||
|
assert issue2.sprint_id is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_sprint_from_issues_db_error(self, db_session, test_sprint):
|
||||||
|
"""Test removing sprint from issues when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.remove_sprint_from_issues(
|
||||||
|
db_session, sprint_id=test_sprint.id
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_from_sprint_not_found(self, db_session):
|
||||||
|
"""Test removing non-existent issue from sprint."""
|
||||||
|
result = await issue.remove_from_sprint(db_session, issue_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_from_sprint_success(
|
||||||
|
self, db_session, test_project, test_sprint
|
||||||
|
):
|
||||||
|
"""Test removing issue from sprint."""
|
||||||
|
issue_obj = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
sprint_id=test_sprint.id,
|
||||||
|
title="Issue in Sprint",
|
||||||
|
)
|
||||||
|
db_session.add(issue_obj)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await issue.remove_from_sprint(db_session, issue_id=issue_obj.id)
|
||||||
|
assert result is not None
|
||||||
|
assert result.sprint_id is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_from_sprint_db_error(
|
||||||
|
self, db_session, test_project, test_sprint
|
||||||
|
):
|
||||||
|
"""Test removing issue from sprint when DB error occurs."""
|
||||||
|
issue_obj = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
sprint_id=test_sprint.id,
|
||||||
|
title="Issue in Sprint",
|
||||||
|
)
|
||||||
|
db_session.add(issue_obj)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await issue.remove_from_sprint(db_session, issue_id=issue_obj.id)
|
||||||
572
backend/tests/crud/syndarix/test_issue_crud.py
Normal file
572
backend/tests/crud/syndarix/test_issue_crud.py
Normal file
@@ -0,0 +1,572 @@
|
|||||||
|
# tests/crud/syndarix/test_issue_crud.py
|
||||||
|
"""
|
||||||
|
Tests for Issue CRUD operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.crud.syndarix import issue as issue_crud
|
||||||
|
from app.models.syndarix import IssuePriority, IssueStatus, SyncStatus
|
||||||
|
from app.schemas.syndarix import IssueCreate, IssueUpdate
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueCreate:
|
||||||
|
"""Tests for issue creation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_issue_success(self, async_test_db, test_project_crud):
|
||||||
|
"""Test successfully creating an issue."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title="Test Issue",
|
||||||
|
body="This is a test issue body",
|
||||||
|
status=IssueStatus.OPEN,
|
||||||
|
priority=IssuePriority.HIGH,
|
||||||
|
labels=["bug", "security"],
|
||||||
|
story_points=5,
|
||||||
|
)
|
||||||
|
result = await issue_crud.create(session, obj_in=issue_data)
|
||||||
|
|
||||||
|
assert result.id is not None
|
||||||
|
assert result.title == "Test Issue"
|
||||||
|
assert result.body == "This is a test issue body"
|
||||||
|
assert result.status == IssueStatus.OPEN
|
||||||
|
assert result.priority == IssuePriority.HIGH
|
||||||
|
assert result.labels == ["bug", "security"]
|
||||||
|
assert result.story_points == 5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_issue_with_external_tracker(
|
||||||
|
self, async_test_db, test_project_crud
|
||||||
|
):
|
||||||
|
"""Test creating issue with external tracker info."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title="External Issue",
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id="gitea-123",
|
||||||
|
remote_url="https://gitea.example.com/issues/123",
|
||||||
|
external_issue_number=123,
|
||||||
|
)
|
||||||
|
result = await issue_crud.create(session, obj_in=issue_data)
|
||||||
|
|
||||||
|
assert result.external_tracker_type == "gitea"
|
||||||
|
assert result.external_issue_id == "gitea-123"
|
||||||
|
assert result.external_issue_number == 123
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_issue_minimal(self, async_test_db, test_project_crud):
|
||||||
|
"""Test creating issue with minimal fields."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title="Minimal Issue",
|
||||||
|
)
|
||||||
|
result = await issue_crud.create(session, obj_in=issue_data)
|
||||||
|
|
||||||
|
assert result.title == "Minimal Issue"
|
||||||
|
assert result.body == "" # Default
|
||||||
|
assert result.status == IssueStatus.OPEN # Default
|
||||||
|
assert result.priority == IssuePriority.MEDIUM # Default
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueRead:
|
||||||
|
"""Tests for issue read operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_issue_by_id(self, async_test_db, test_issue_crud):
|
||||||
|
"""Test getting issue by ID."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await issue_crud.get(session, id=str(test_issue_crud.id))
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == test_issue_crud.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_issue_by_id_not_found(self, async_test_db):
|
||||||
|
"""Test getting non-existent issue returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await issue_crud.get(session, id=str(uuid.uuid4()))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_details(self, async_test_db, test_issue_crud):
|
||||||
|
"""Test getting issue with related details."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await issue_crud.get_with_details(
|
||||||
|
session,
|
||||||
|
issue_id=test_issue_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["issue"].id == test_issue_crud.id
|
||||||
|
assert result["project_name"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueUpdate:
|
||||||
|
"""Tests for issue update operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_issue_basic_fields(self, async_test_db, test_issue_crud):
|
||||||
|
"""Test updating basic issue fields."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue = await issue_crud.get(session, id=str(test_issue_crud.id))
|
||||||
|
|
||||||
|
update_data = IssueUpdate(
|
||||||
|
title="Updated Title",
|
||||||
|
body="Updated body content",
|
||||||
|
)
|
||||||
|
result = await issue_crud.update(session, db_obj=issue, obj_in=update_data)
|
||||||
|
|
||||||
|
assert result.title == "Updated Title"
|
||||||
|
assert result.body == "Updated body content"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_issue_status(self, async_test_db, test_issue_crud):
|
||||||
|
"""Test updating issue status."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue = await issue_crud.get(session, id=str(test_issue_crud.id))
|
||||||
|
|
||||||
|
update_data = IssueUpdate(status=IssueStatus.IN_PROGRESS)
|
||||||
|
result = await issue_crud.update(session, db_obj=issue, obj_in=update_data)
|
||||||
|
|
||||||
|
assert result.status == IssueStatus.IN_PROGRESS
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_issue_priority(self, async_test_db, test_issue_crud):
|
||||||
|
"""Test updating issue priority."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue = await issue_crud.get(session, id=str(test_issue_crud.id))
|
||||||
|
|
||||||
|
update_data = IssueUpdate(priority=IssuePriority.CRITICAL)
|
||||||
|
result = await issue_crud.update(session, db_obj=issue, obj_in=update_data)
|
||||||
|
|
||||||
|
assert result.priority == IssuePriority.CRITICAL
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_issue_labels(self, async_test_db, test_issue_crud):
|
||||||
|
"""Test updating issue labels."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue = await issue_crud.get(session, id=str(test_issue_crud.id))
|
||||||
|
|
||||||
|
update_data = IssueUpdate(labels=["new-label", "updated"])
|
||||||
|
result = await issue_crud.update(session, db_obj=issue, obj_in=update_data)
|
||||||
|
|
||||||
|
assert "new-label" in result.labels
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueAssignment:
|
||||||
|
"""Tests for issue assignment operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assign_to_agent(
|
||||||
|
self, async_test_db, test_issue_crud, test_agent_instance_crud
|
||||||
|
):
|
||||||
|
"""Test assigning issue to an agent."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await issue_crud.assign_to_agent(
|
||||||
|
session,
|
||||||
|
issue_id=test_issue_crud.id,
|
||||||
|
agent_id=test_agent_instance_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.assigned_agent_id == test_agent_instance_crud.id
|
||||||
|
assert result.human_assignee is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unassign_agent(
|
||||||
|
self, async_test_db, test_issue_crud, test_agent_instance_crud
|
||||||
|
):
|
||||||
|
"""Test unassigning agent from issue."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# First assign
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await issue_crud.assign_to_agent(
|
||||||
|
session,
|
||||||
|
issue_id=test_issue_crud.id,
|
||||||
|
agent_id=test_agent_instance_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then unassign
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await issue_crud.assign_to_agent(
|
||||||
|
session,
|
||||||
|
issue_id=test_issue_crud.id,
|
||||||
|
agent_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.assigned_agent_id is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assign_to_human(self, async_test_db, test_issue_crud):
|
||||||
|
"""Test assigning issue to a human."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await issue_crud.assign_to_human(
|
||||||
|
session,
|
||||||
|
issue_id=test_issue_crud.id,
|
||||||
|
human_assignee="developer@example.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.human_assignee == "developer@example.com"
|
||||||
|
assert result.assigned_agent_id is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assign_to_human_clears_agent(
|
||||||
|
self, async_test_db, test_issue_crud, test_agent_instance_crud
|
||||||
|
):
|
||||||
|
"""Test assigning to human clears agent assignment."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# First assign to agent
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
await issue_crud.assign_to_agent(
|
||||||
|
session,
|
||||||
|
issue_id=test_issue_crud.id,
|
||||||
|
agent_id=test_agent_instance_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then assign to human
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await issue_crud.assign_to_human(
|
||||||
|
session,
|
||||||
|
issue_id=test_issue_crud.id,
|
||||||
|
human_assignee="developer@example.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.human_assignee == "developer@example.com"
|
||||||
|
assert result.assigned_agent_id is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueLifecycle:
|
||||||
|
"""Tests for issue lifecycle operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_issue(self, async_test_db, test_issue_crud):
|
||||||
|
"""Test closing an issue."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await issue_crud.close_issue(session, issue_id=test_issue_crud.id)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == IssueStatus.CLOSED
|
||||||
|
assert result.closed_at is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reopen_issue(self, async_test_db, test_project_crud):
|
||||||
|
"""Test reopening a closed issue."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create and close an issue
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title="Issue to Reopen",
|
||||||
|
)
|
||||||
|
created = await issue_crud.create(session, obj_in=issue_data)
|
||||||
|
await issue_crud.close_issue(session, issue_id=created.id)
|
||||||
|
issue_id = created.id
|
||||||
|
|
||||||
|
# Reopen
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await issue_crud.reopen_issue(session, issue_id=issue_id)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == IssueStatus.OPEN
|
||||||
|
assert result.closed_at is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueByProject:
|
||||||
|
"""Tests for getting issues by project."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project(
|
||||||
|
self, async_test_db, test_project_crud, test_issue_crud
|
||||||
|
):
|
||||||
|
"""Test getting issues by project."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issues, total = await issue_crud.get_by_project(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert total >= 1
|
||||||
|
assert all(i.project_id == test_project_crud.id for i in issues)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_with_status(self, async_test_db, test_project_crud):
|
||||||
|
"""Test filtering issues by status."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create issues with different statuses
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
open_issue = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title="Open Issue Filter",
|
||||||
|
status=IssueStatus.OPEN,
|
||||||
|
)
|
||||||
|
await issue_crud.create(session, obj_in=open_issue)
|
||||||
|
|
||||||
|
closed_issue = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title="Closed Issue Filter",
|
||||||
|
status=IssueStatus.CLOSED,
|
||||||
|
)
|
||||||
|
await issue_crud.create(session, obj_in=closed_issue)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issues, _ = await issue_crud.get_by_project(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
status=IssueStatus.OPEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all(i.status == IssueStatus.OPEN for i in issues)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_with_priority(self, async_test_db, test_project_crud):
|
||||||
|
"""Test filtering issues by priority."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
high_issue = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title="High Priority Issue",
|
||||||
|
priority=IssuePriority.HIGH,
|
||||||
|
)
|
||||||
|
await issue_crud.create(session, obj_in=high_issue)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issues, _ = await issue_crud.get_by_project(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
priority=IssuePriority.HIGH,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all(i.priority == IssuePriority.HIGH for i in issues)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_with_search(self, async_test_db, test_project_crud):
|
||||||
|
"""Test searching issues by title/body."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
searchable_issue = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title="Searchable Unique Title",
|
||||||
|
body="This body contains searchable content",
|
||||||
|
)
|
||||||
|
await issue_crud.create(session, obj_in=searchable_issue)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issues, total = await issue_crud.get_by_project(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
search="Searchable Unique",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert total >= 1
|
||||||
|
assert any(i.title == "Searchable Unique Title" for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueBySprint:
|
||||||
|
"""Tests for getting issues by sprint."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_sprint(
|
||||||
|
self, async_test_db, test_project_crud, test_sprint_crud
|
||||||
|
):
|
||||||
|
"""Test getting issues by sprint."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create issue in sprint
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title="Sprint Issue",
|
||||||
|
sprint_id=test_sprint_crud.id,
|
||||||
|
)
|
||||||
|
await issue_crud.create(session, obj_in=issue_data)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issues = await issue_crud.get_by_sprint(
|
||||||
|
session,
|
||||||
|
sprint_id=test_sprint_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(issues) >= 1
|
||||||
|
assert all(i.sprint_id == test_sprint_crud.id for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueSyncStatus:
|
||||||
|
"""Tests for issue sync status operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_sync_status(self, async_test_db, test_project_crud):
|
||||||
|
"""Test updating issue sync status."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create issue with external tracker
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title="Sync Status Issue",
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id="gitea-456",
|
||||||
|
)
|
||||||
|
created = await issue_crud.create(session, obj_in=issue_data)
|
||||||
|
issue_id = created.id
|
||||||
|
|
||||||
|
# Update sync status
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await issue_crud.update_sync_status(
|
||||||
|
session,
|
||||||
|
issue_id=issue_id,
|
||||||
|
sync_status=SyncStatus.PENDING,
|
||||||
|
last_synced_at=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.sync_status == SyncStatus.PENDING
|
||||||
|
assert result.last_synced_at is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_pending_sync(self, async_test_db, test_project_crud):
|
||||||
|
"""Test getting issues pending sync."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create issue with pending sync
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title="Pending Sync Issue",
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id="gitea-789",
|
||||||
|
)
|
||||||
|
created = await issue_crud.create(session, obj_in=issue_data)
|
||||||
|
|
||||||
|
# Set to pending
|
||||||
|
await issue_crud.update_sync_status(
|
||||||
|
session,
|
||||||
|
issue_id=created.id,
|
||||||
|
sync_status=SyncStatus.PENDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issues = await issue_crud.get_pending_sync(session)
|
||||||
|
|
||||||
|
assert any(i.sync_status == SyncStatus.PENDING for i in issues)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueExternalTracker:
|
||||||
|
"""Tests for external tracker operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_external_id(self, async_test_db, test_project_crud):
|
||||||
|
"""Test getting issue by external tracker ID."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create issue with external ID
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title="External ID Issue",
|
||||||
|
external_tracker_type="github",
|
||||||
|
external_issue_id="github-unique-123",
|
||||||
|
)
|
||||||
|
await issue_crud.create(session, obj_in=issue_data)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await issue_crud.get_by_external_id(
|
||||||
|
session,
|
||||||
|
external_tracker_type="github",
|
||||||
|
external_issue_id="github-unique-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.external_issue_id == "github-unique-123"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_external_id_not_found(self, async_test_db):
|
||||||
|
"""Test getting non-existent external ID returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await issue_crud.get_by_external_id(
|
||||||
|
session,
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id="non-existent",
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueStats:
|
||||||
|
"""Tests for issue statistics."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_project_stats(self, async_test_db, test_project_crud):
|
||||||
|
"""Test getting issue statistics for a project."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create issues with various statuses and priorities
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
for status in [
|
||||||
|
IssueStatus.OPEN,
|
||||||
|
IssueStatus.IN_PROGRESS,
|
||||||
|
IssueStatus.CLOSED,
|
||||||
|
]:
|
||||||
|
issue_data = IssueCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
title=f"Stats Issue {status.value}",
|
||||||
|
status=status,
|
||||||
|
story_points=3,
|
||||||
|
)
|
||||||
|
await issue_crud.create(session, obj_in=issue_data)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
stats = await issue_crud.get_project_stats(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "total" in stats
|
||||||
|
assert "open" in stats
|
||||||
|
assert "in_progress" in stats
|
||||||
|
assert "closed" in stats
|
||||||
|
assert "by_priority" in stats
|
||||||
|
assert "total_story_points" in stats
|
||||||
272
backend/tests/crud/syndarix/test_project.py
Normal file
272
backend/tests/crud/syndarix/test_project.py
Normal file
@@ -0,0 +1,272 @@
|
|||||||
|
# tests/crud/syndarix/test_project.py
|
||||||
|
"""Tests for Project CRUD operations."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
|
||||||
|
from app.crud.syndarix.project import project
|
||||||
|
from app.models.syndarix import Project
|
||||||
|
from app.models.syndarix.enums import ProjectStatus
|
||||||
|
from app.schemas.syndarix import ProjectCreate
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db_session(async_test_db):
|
||||||
|
"""Create a database session for tests."""
|
||||||
|
_, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_project(db_session):
|
||||||
|
"""Create a test project."""
|
||||||
|
proj = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Project",
|
||||||
|
slug=f"test-project-{uuid.uuid4().hex[:8]}",
|
||||||
|
status=ProjectStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
db_session.add(proj)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(proj)
|
||||||
|
return proj
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectGetBySlug:
|
||||||
|
"""Tests for getting project by slug."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_slug_not_found(self, db_session):
|
||||||
|
"""Test getting non-existent project by slug."""
|
||||||
|
result = await project.get_by_slug(db_session, slug="nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_slug_success(self, db_session, test_project):
|
||||||
|
"""Test successfully getting project by slug."""
|
||||||
|
result = await project.get_by_slug(db_session, slug=test_project.slug)
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == test_project.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_slug_db_error(self, db_session):
|
||||||
|
"""Test getting project by slug when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await project.get_by_slug(db_session, slug="test")
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectCreate:
|
||||||
|
"""Tests for project creation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_project_success(self, db_session):
|
||||||
|
"""Test successful project creation."""
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name="New Project",
|
||||||
|
slug=f"new-project-{uuid.uuid4().hex[:8]}",
|
||||||
|
)
|
||||||
|
created = await project.create(db_session, obj_in=project_data)
|
||||||
|
assert created.name == "New Project"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_project_duplicate_slug(self, db_session, test_project):
|
||||||
|
"""Test project creation with duplicate slug."""
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name="Another Project",
|
||||||
|
slug=test_project.slug, # Use existing slug
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock IntegrityError with slug in the message
|
||||||
|
mock_orig = MagicMock()
|
||||||
|
mock_orig.__str__ = (
|
||||||
|
lambda self: "duplicate key value violates unique constraint on slug"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=IntegrityError("", {}, mock_orig),
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
await project.create(db_session, obj_in=project_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_project_integrity_error(self, db_session):
|
||||||
|
"""Test project creation with general integrity error."""
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name="Test Project",
|
||||||
|
slug=f"test-{uuid.uuid4().hex[:8]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock IntegrityError without slug in the message
|
||||||
|
mock_orig = MagicMock()
|
||||||
|
mock_orig.__str__ = lambda self: "foreign key constraint violation"
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=IntegrityError("", {}, mock_orig),
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="Database integrity error"):
|
||||||
|
await project.create(db_session, obj_in=project_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_project_unexpected_error(self, db_session):
|
||||||
|
"""Test project creation with unexpected error."""
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name="Test Project",
|
||||||
|
slug=f"test-{uuid.uuid4().hex[:8]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=RuntimeError("Unexpected error"),
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="Unexpected error"):
|
||||||
|
await project.create(db_session, obj_in=project_data)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectGetMultiWithFilters:
|
||||||
|
"""Tests for getting projects with filters."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_success(self, db_session, test_project):
|
||||||
|
"""Test successfully getting projects with filters."""
|
||||||
|
_results, total = await project.get_multi_with_filters(db_session)
|
||||||
|
assert total >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_db_error(self, db_session):
|
||||||
|
"""Test getting projects when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await project.get_multi_with_filters(db_session)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectGetWithCounts:
|
||||||
|
"""Tests for getting project with counts."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_counts_not_found(self, db_session):
|
||||||
|
"""Test getting non-existent project with counts."""
|
||||||
|
result = await project.get_with_counts(db_session, project_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_counts_success(self, db_session, test_project):
|
||||||
|
"""Test successfully getting project with counts."""
|
||||||
|
result = await project.get_with_counts(db_session, project_id=test_project.id)
|
||||||
|
assert result is not None
|
||||||
|
assert result["project"].id == test_project.id
|
||||||
|
assert result["agent_count"] == 0
|
||||||
|
assert result["issue_count"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_counts_db_error(self, db_session, test_project):
|
||||||
|
"""Test getting project with counts when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await project.get_with_counts(db_session, project_id=test_project.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectGetMultiWithCounts:
|
||||||
|
"""Tests for getting projects with counts."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_counts_empty(self, db_session):
|
||||||
|
"""Test getting projects with counts when none match."""
|
||||||
|
results, total = await project.get_multi_with_counts(
|
||||||
|
db_session,
|
||||||
|
search="nonexistent-xyz-query",
|
||||||
|
)
|
||||||
|
assert results == []
|
||||||
|
assert total == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_counts_success(self, db_session, test_project):
|
||||||
|
"""Test successfully getting projects with counts."""
|
||||||
|
results, total = await project.get_multi_with_counts(db_session)
|
||||||
|
assert total >= 1
|
||||||
|
assert len(results) >= 1
|
||||||
|
assert "project" in results[0]
|
||||||
|
assert "agent_count" in results[0]
|
||||||
|
assert "issue_count" in results[0]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_counts_db_error(self, db_session, test_project):
|
||||||
|
"""Test getting projects with counts when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await project.get_multi_with_counts(db_session)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectGetByOwner:
|
||||||
|
"""Tests for getting projects by owner."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_projects_by_owner_empty(self, db_session):
|
||||||
|
"""Test getting projects by owner when none exist."""
|
||||||
|
results = await project.get_projects_by_owner(db_session, owner_id=uuid.uuid4())
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_projects_by_owner_db_error(self, db_session):
|
||||||
|
"""Test getting projects by owner when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await project.get_projects_by_owner(db_session, owner_id=uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectArchive:
|
||||||
|
"""Tests for archiving projects."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_archive_project_not_found(self, db_session):
|
||||||
|
"""Test archiving non-existent project."""
|
||||||
|
result = await project.archive_project(db_session, project_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_archive_project_success(self, db_session, test_project):
|
||||||
|
"""Test successfully archiving project."""
|
||||||
|
result = await project.archive_project(db_session, project_id=test_project.id)
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == ProjectStatus.ARCHIVED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_archive_project_db_error(self, db_session, test_project):
|
||||||
|
"""Test archiving project when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await project.archive_project(db_session, project_id=test_project.id)
|
||||||
438
backend/tests/crud/syndarix/test_project_crud.py
Normal file
438
backend/tests/crud/syndarix/test_project_crud.py
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
# tests/crud/syndarix/test_project_crud.py
|
||||||
|
"""
|
||||||
|
Tests for Project CRUD operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.crud.syndarix import project as project_crud
|
||||||
|
from app.models.syndarix import AutonomyLevel, ProjectStatus
|
||||||
|
from app.schemas.syndarix import ProjectCreate, ProjectUpdate
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectCreate:
|
||||||
|
"""Tests for project creation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_project_success(self, async_test_db, test_owner_crud):
|
||||||
|
"""Test successfully creating a project."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name="New Project",
|
||||||
|
slug="new-project",
|
||||||
|
description="A brand new project",
|
||||||
|
autonomy_level=AutonomyLevel.MILESTONE,
|
||||||
|
status=ProjectStatus.ACTIVE,
|
||||||
|
settings={"key": "value"},
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
result = await project_crud.create(session, obj_in=project_data)
|
||||||
|
|
||||||
|
assert result.id is not None
|
||||||
|
assert result.name == "New Project"
|
||||||
|
assert result.slug == "new-project"
|
||||||
|
assert result.description == "A brand new project"
|
||||||
|
assert result.autonomy_level == AutonomyLevel.MILESTONE
|
||||||
|
assert result.status == ProjectStatus.ACTIVE
|
||||||
|
assert result.settings == {"key": "value"}
|
||||||
|
assert result.owner_id == test_owner_crud.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_project_duplicate_slug_fails(
|
||||||
|
self, async_test_db, test_project_crud
|
||||||
|
):
|
||||||
|
"""Test creating project with duplicate slug raises ValueError."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name="Duplicate Project",
|
||||||
|
slug=test_project_crud.slug, # Duplicate slug
|
||||||
|
description="This should fail",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
await project_crud.create(session, obj_in=project_data)
|
||||||
|
|
||||||
|
assert "already exists" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_project_minimal_fields(self, async_test_db):
|
||||||
|
"""Test creating project with minimal required fields."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name="Minimal Project",
|
||||||
|
slug="minimal-project",
|
||||||
|
)
|
||||||
|
result = await project_crud.create(session, obj_in=project_data)
|
||||||
|
|
||||||
|
assert result.name == "Minimal Project"
|
||||||
|
assert result.slug == "minimal-project"
|
||||||
|
assert result.autonomy_level == AutonomyLevel.MILESTONE # Default
|
||||||
|
assert result.status == ProjectStatus.ACTIVE # Default
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectRead:
|
||||||
|
"""Tests for project read operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_project_by_id(self, async_test_db, test_project_crud):
|
||||||
|
"""Test getting project by ID."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await project_crud.get(session, id=str(test_project_crud.id))
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == test_project_crud.id
|
||||||
|
assert result.name == test_project_crud.name
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_project_by_id_not_found(self, async_test_db):
|
||||||
|
"""Test getting non-existent project returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await project_crud.get(session, id=str(uuid.uuid4()))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_project_by_slug(self, async_test_db, test_project_crud):
|
||||||
|
"""Test getting project by slug."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await project_crud.get_by_slug(
|
||||||
|
session, slug=test_project_crud.slug
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.slug == test_project_crud.slug
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_project_by_slug_not_found(self, async_test_db):
|
||||||
|
"""Test getting non-existent slug returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await project_crud.get_by_slug(session, slug="non-existent-slug")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectUpdate:
|
||||||
|
"""Tests for project update operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_project_basic_fields(self, async_test_db, test_project_crud):
|
||||||
|
"""Test updating basic project fields."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
project = await project_crud.get(session, id=str(test_project_crud.id))
|
||||||
|
|
||||||
|
update_data = ProjectUpdate(
|
||||||
|
name="Updated Project Name",
|
||||||
|
description="Updated description",
|
||||||
|
)
|
||||||
|
result = await project_crud.update(
|
||||||
|
session, db_obj=project, obj_in=update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.name == "Updated Project Name"
|
||||||
|
assert result.description == "Updated description"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_project_status(self, async_test_db, test_project_crud):
|
||||||
|
"""Test updating project status."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
project = await project_crud.get(session, id=str(test_project_crud.id))
|
||||||
|
|
||||||
|
update_data = ProjectUpdate(status=ProjectStatus.PAUSED)
|
||||||
|
result = await project_crud.update(
|
||||||
|
session, db_obj=project, obj_in=update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == ProjectStatus.PAUSED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_project_autonomy_level(
|
||||||
|
self, async_test_db, test_project_crud
|
||||||
|
):
|
||||||
|
"""Test updating project autonomy level."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
project = await project_crud.get(session, id=str(test_project_crud.id))
|
||||||
|
|
||||||
|
update_data = ProjectUpdate(autonomy_level=AutonomyLevel.AUTONOMOUS)
|
||||||
|
result = await project_crud.update(
|
||||||
|
session, db_obj=project, obj_in=update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_project_settings(self, async_test_db, test_project_crud):
|
||||||
|
"""Test updating project settings."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
project = await project_crud.get(session, id=str(test_project_crud.id))
|
||||||
|
|
||||||
|
new_settings = {
|
||||||
|
"mcp_servers": ["gitea", "slack"],
|
||||||
|
"webhook_url": "https://example.com",
|
||||||
|
}
|
||||||
|
update_data = ProjectUpdate(settings=new_settings)
|
||||||
|
result = await project_crud.update(
|
||||||
|
session, db_obj=project, obj_in=update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.settings == new_settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectDelete:
|
||||||
|
"""Tests for project delete operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_project(self, async_test_db, test_owner_crud):
|
||||||
|
"""Test deleting a project."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create a project to delete
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name="Delete Me",
|
||||||
|
slug="delete-me-project",
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
created = await project_crud.create(session, obj_in=project_data)
|
||||||
|
project_id = created.id
|
||||||
|
|
||||||
|
# Delete the project
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await project_crud.remove(session, id=str(project_id))
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == project_id
|
||||||
|
|
||||||
|
# Verify deletion
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
deleted = await project_crud.get(session, id=str(project_id))
|
||||||
|
assert deleted is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_nonexistent_project(self, async_test_db):
|
||||||
|
"""Test deleting non-existent project returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await project_crud.remove(session, id=str(uuid.uuid4()))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectFilters:
|
||||||
|
"""Tests for project filtering and search."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_status(self, async_test_db, test_owner_crud):
|
||||||
|
"""Test filtering projects by status."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create multiple projects with different statuses
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
for i, status in enumerate(ProjectStatus):
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name=f"Project {status.value}",
|
||||||
|
slug=f"project-filter-{status.value}-{i}",
|
||||||
|
status=status,
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
await project_crud.create(session, obj_in=project_data)
|
||||||
|
|
||||||
|
# Filter by ACTIVE status
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
projects, _total = await project_crud.get_multi_with_filters(
|
||||||
|
session,
|
||||||
|
status=ProjectStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all(p.status == ProjectStatus.ACTIVE for p in projects)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_search(self, async_test_db, test_owner_crud):
|
||||||
|
"""Test searching projects by name/slug."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name="Searchable Project",
|
||||||
|
slug="searchable-unique-slug",
|
||||||
|
description="This project is searchable",
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
await project_crud.create(session, obj_in=project_data)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
projects, total = await project_crud.get_multi_with_filters(
|
||||||
|
session,
|
||||||
|
search="Searchable",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert total >= 1
|
||||||
|
assert any(p.name == "Searchable Project" for p in projects)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_owner(
|
||||||
|
self, async_test_db, test_owner_crud, test_project_crud
|
||||||
|
):
|
||||||
|
"""Test filtering projects by owner."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
projects, total = await project_crud.get_multi_with_filters(
|
||||||
|
session,
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert total >= 1
|
||||||
|
assert all(p.owner_id == test_owner_crud.id for p in projects)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_pagination(
|
||||||
|
self, async_test_db, test_owner_crud
|
||||||
|
):
|
||||||
|
"""Test pagination of project results."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create multiple projects
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
for i in range(5):
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name=f"Page Project {i}",
|
||||||
|
slug=f"page-project-{i}",
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
await project_crud.create(session, obj_in=project_data)
|
||||||
|
|
||||||
|
# Get first page
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
page1, total = await project_crud.get_multi_with_filters(
|
||||||
|
session,
|
||||||
|
skip=0,
|
||||||
|
limit=2,
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(page1) <= 2
|
||||||
|
assert total >= 5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_multi_with_filters_sorting(self, async_test_db, test_owner_crud):
|
||||||
|
"""Test sorting project results."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
for _i, name in enumerate(["Charlie", "Alice", "Bob"]):
|
||||||
|
project_data = ProjectCreate(
|
||||||
|
name=name,
|
||||||
|
slug=f"sort-project-{name.lower()}",
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
await project_crud.create(session, obj_in=project_data)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
projects, _ = await project_crud.get_multi_with_filters(
|
||||||
|
session,
|
||||||
|
sort_by="name",
|
||||||
|
sort_order="asc",
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
names = [p.name for p in projects if p.name in ["Alice", "Bob", "Charlie"]]
|
||||||
|
assert names == sorted(names)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectSpecialMethods:
|
||||||
|
"""Tests for special project CRUD methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_archive_project(self, async_test_db, test_project_crud):
|
||||||
|
"""Test archiving a project."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await project_crud.archive_project(
|
||||||
|
session, project_id=test_project_crud.id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == ProjectStatus.ARCHIVED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_archive_nonexistent_project(self, async_test_db):
|
||||||
|
"""Test archiving non-existent project returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await project_crud.archive_project(
|
||||||
|
session, project_id=uuid.uuid4()
|
||||||
|
)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_projects_by_owner(
|
||||||
|
self, async_test_db, test_owner_crud, test_project_crud
|
||||||
|
):
|
||||||
|
"""Test getting all projects by owner."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
projects = await project_crud.get_projects_by_owner(
|
||||||
|
session,
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(projects) >= 1
|
||||||
|
assert all(p.owner_id == test_owner_crud.id for p in projects)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_projects_by_owner_with_status(
|
||||||
|
self, async_test_db, test_owner_crud
|
||||||
|
):
|
||||||
|
"""Test getting projects by owner filtered by status."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create projects with different statuses
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
active_project = ProjectCreate(
|
||||||
|
name="Active Owner Project",
|
||||||
|
slug="active-owner-project",
|
||||||
|
status=ProjectStatus.ACTIVE,
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
await project_crud.create(session, obj_in=active_project)
|
||||||
|
|
||||||
|
paused_project = ProjectCreate(
|
||||||
|
name="Paused Owner Project",
|
||||||
|
slug="paused-owner-project",
|
||||||
|
status=ProjectStatus.PAUSED,
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
)
|
||||||
|
await project_crud.create(session, obj_in=paused_project)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
projects = await project_crud.get_projects_by_owner(
|
||||||
|
session,
|
||||||
|
owner_id=test_owner_crud.id,
|
||||||
|
status=ProjectStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all(p.status == ProjectStatus.ACTIVE for p in projects)
|
||||||
502
backend/tests/crud/syndarix/test_sprint.py
Normal file
502
backend/tests/crud/syndarix/test_sprint.py
Normal file
@@ -0,0 +1,502 @@
|
|||||||
|
# tests/crud/syndarix/test_sprint.py
|
||||||
|
"""Tests for Sprint CRUD operations."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import date, timedelta
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
|
||||||
|
from app.crud.syndarix.sprint import sprint
|
||||||
|
from app.models.syndarix import Issue, Project, Sprint
|
||||||
|
from app.models.syndarix.enums import (
|
||||||
|
IssueStatus,
|
||||||
|
ProjectStatus,
|
||||||
|
SprintStatus,
|
||||||
|
)
|
||||||
|
from app.schemas.syndarix import SprintCreate
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db_session(async_test_db):
|
||||||
|
"""Create a database session for tests."""
|
||||||
|
_, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_project(db_session):
|
||||||
|
"""Create a test project for sprints."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Project",
|
||||||
|
slug=f"test-project-{uuid.uuid4().hex[:8]}",
|
||||||
|
status=ProjectStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(project)
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_sprint(db_session, test_project):
|
||||||
|
"""Create a test sprint."""
|
||||||
|
sprint_obj = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Test Sprint",
|
||||||
|
number=1,
|
||||||
|
status=SprintStatus.PLANNED,
|
||||||
|
start_date=date.today(),
|
||||||
|
end_date=date.today() + timedelta(days=14),
|
||||||
|
)
|
||||||
|
db_session.add(sprint_obj)
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(sprint_obj)
|
||||||
|
return sprint_obj
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintCreate:
|
||||||
|
"""Tests for sprint creation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_sprint_success(self, db_session, test_project):
|
||||||
|
"""Test successful sprint creation."""
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="New Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=date.today(),
|
||||||
|
end_date=date.today() + timedelta(days=14),
|
||||||
|
)
|
||||||
|
created = await sprint.create(db_session, obj_in=sprint_data)
|
||||||
|
assert created.name == "New Sprint"
|
||||||
|
assert created.number == 1
|
||||||
|
assert created.status == SprintStatus.PLANNED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_sprint_with_all_fields(self, db_session, test_project):
|
||||||
|
"""Test sprint creation with all optional fields."""
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Full Sprint",
|
||||||
|
number=2,
|
||||||
|
goal="Deliver user authentication",
|
||||||
|
start_date=date.today(),
|
||||||
|
end_date=date.today() + timedelta(days=14),
|
||||||
|
status=SprintStatus.PLANNED,
|
||||||
|
planned_points=20,
|
||||||
|
velocity=15,
|
||||||
|
)
|
||||||
|
created = await sprint.create(db_session, obj_in=sprint_data)
|
||||||
|
assert created.goal == "Deliver user authentication"
|
||||||
|
assert created.planned_points == 20
|
||||||
|
assert created.velocity == 15
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_sprint_integrity_error(self, db_session, test_project):
|
||||||
|
"""Test sprint creation with integrity error."""
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Test Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=date.today(),
|
||||||
|
end_date=date.today() + timedelta(days=14),
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=IntegrityError("", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="Database integrity error"):
|
||||||
|
await sprint.create(db_session, obj_in=sprint_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_sprint_unexpected_error(self, db_session, test_project):
|
||||||
|
"""Test sprint creation with unexpected error."""
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Test Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=date.today(),
|
||||||
|
end_date=date.today() + timedelta(days=14),
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=RuntimeError("Unexpected error"),
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="Unexpected error"):
|
||||||
|
await sprint.create(db_session, obj_in=sprint_data)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintGetWithDetails:
|
||||||
|
"""Tests for getting sprint with details."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_details_not_found(self, db_session):
|
||||||
|
"""Test getting non-existent sprint with details."""
|
||||||
|
result = await sprint.get_with_details(db_session, sprint_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_details_success(self, db_session, test_sprint):
|
||||||
|
"""Test getting sprint with details."""
|
||||||
|
result = await sprint.get_with_details(db_session, sprint_id=test_sprint.id)
|
||||||
|
assert result is not None
|
||||||
|
assert result["sprint"].id == test_sprint.id
|
||||||
|
assert "project_name" in result
|
||||||
|
assert "issue_count" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_details_db_error(self, db_session, test_sprint):
|
||||||
|
"""Test getting sprint with details when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await sprint.get_with_details(db_session, sprint_id=test_sprint.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintGetByProject:
|
||||||
|
"""Tests for getting sprints by project."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_with_status_filter(
|
||||||
|
self, db_session, test_project, test_sprint
|
||||||
|
):
|
||||||
|
"""Test getting sprints with status filter."""
|
||||||
|
sprints, _total = await sprint.get_by_project(
|
||||||
|
db_session,
|
||||||
|
project_id=test_project.id,
|
||||||
|
status=SprintStatus.PLANNED,
|
||||||
|
)
|
||||||
|
assert len(sprints) == 1
|
||||||
|
assert sprints[0].status == SprintStatus.PLANNED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_db_error(self, db_session, test_project):
|
||||||
|
"""Test getting sprints when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await sprint.get_by_project(db_session, project_id=test_project.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintActiveOperations:
|
||||||
|
"""Tests for active sprint operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_active_sprint_none(self, db_session, test_project, test_sprint):
|
||||||
|
"""Test getting active sprint when none exists."""
|
||||||
|
result = await sprint.get_active_sprint(db_session, project_id=test_project.id)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_active_sprint_db_error(self, db_session, test_project):
|
||||||
|
"""Test getting active sprint when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await sprint.get_active_sprint(db_session, project_id=test_project.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintNumberOperations:
|
||||||
|
"""Tests for sprint number operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_next_sprint_number_empty(self, db_session, test_project):
|
||||||
|
"""Test getting next sprint number for project with no sprints."""
|
||||||
|
result = await sprint.get_next_sprint_number(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
|
assert result == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_next_sprint_number_with_existing(
|
||||||
|
self, db_session, test_project, test_sprint
|
||||||
|
):
|
||||||
|
"""Test getting next sprint number with existing sprints."""
|
||||||
|
result = await sprint.get_next_sprint_number(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
|
assert result == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_next_sprint_number_db_error(self, db_session, test_project):
|
||||||
|
"""Test getting next sprint number when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await sprint.get_next_sprint_number(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintLifecycle:
|
||||||
|
"""Tests for sprint lifecycle operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_sprint_not_found(self, db_session):
|
||||||
|
"""Test starting non-existent sprint."""
|
||||||
|
result = await sprint.start_sprint(db_session, sprint_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_sprint_invalid_status(self, db_session, test_project):
|
||||||
|
"""Test starting sprint with invalid status."""
|
||||||
|
# Create an active sprint
|
||||||
|
active_sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Active Sprint",
|
||||||
|
number=1,
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
start_date=date.today(),
|
||||||
|
end_date=date.today() + timedelta(days=14),
|
||||||
|
)
|
||||||
|
db_session.add(active_sprint)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Cannot start sprint with status"):
|
||||||
|
await sprint.start_sprint(db_session, sprint_id=active_sprint.id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_sprint_with_existing_active(self, db_session, test_project):
|
||||||
|
"""Test starting sprint when another is already active."""
|
||||||
|
# Create active sprint
|
||||||
|
active_sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Active Sprint",
|
||||||
|
number=1,
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
start_date=date.today(),
|
||||||
|
end_date=date.today() + timedelta(days=14),
|
||||||
|
)
|
||||||
|
db_session.add(active_sprint)
|
||||||
|
|
||||||
|
# Create planned sprint
|
||||||
|
planned_sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Planned Sprint",
|
||||||
|
number=2,
|
||||||
|
status=SprintStatus.PLANNED,
|
||||||
|
start_date=date.today() + timedelta(days=15),
|
||||||
|
end_date=date.today() + timedelta(days=29),
|
||||||
|
)
|
||||||
|
db_session.add(planned_sprint)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Project already has an active sprint"):
|
||||||
|
await sprint.start_sprint(db_session, sprint_id=planned_sprint.id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_sprint_db_error(self, db_session, test_sprint):
|
||||||
|
"""Test starting sprint when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await sprint.start_sprint(db_session, sprint_id=test_sprint.id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_sprint_not_found(self, db_session):
|
||||||
|
"""Test completing non-existent sprint."""
|
||||||
|
result = await sprint.complete_sprint(db_session, sprint_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_sprint_invalid_status(self, db_session, test_sprint):
|
||||||
|
"""Test completing sprint with invalid status (PLANNED)."""
|
||||||
|
with pytest.raises(ValueError, match="Cannot complete sprint with status"):
|
||||||
|
await sprint.complete_sprint(db_session, sprint_id=test_sprint.id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_sprint_db_error(self, db_session, test_project):
|
||||||
|
"""Test completing sprint when DB error occurs."""
|
||||||
|
# Create active sprint
|
||||||
|
active_sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Active Sprint",
|
||||||
|
number=1,
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
start_date=date.today(),
|
||||||
|
end_date=date.today() + timedelta(days=14),
|
||||||
|
)
|
||||||
|
db_session.add(active_sprint)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await sprint.complete_sprint(db_session, sprint_id=active_sprint.id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_sprint_not_found(self, db_session):
|
||||||
|
"""Test cancelling non-existent sprint."""
|
||||||
|
result = await sprint.cancel_sprint(db_session, sprint_id=uuid.uuid4())
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_sprint_invalid_status(self, db_session, test_project):
|
||||||
|
"""Test cancelling sprint with invalid status (COMPLETED)."""
|
||||||
|
completed_sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Completed Sprint",
|
||||||
|
number=1,
|
||||||
|
status=SprintStatus.COMPLETED,
|
||||||
|
start_date=date.today() - timedelta(days=14),
|
||||||
|
end_date=date.today(),
|
||||||
|
)
|
||||||
|
db_session.add(completed_sprint)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Cannot cancel sprint with status"):
|
||||||
|
await sprint.cancel_sprint(db_session, sprint_id=completed_sprint.id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_sprint_success(self, db_session, test_sprint):
|
||||||
|
"""Test successfully cancelling a planned sprint."""
|
||||||
|
result = await sprint.cancel_sprint(db_session, sprint_id=test_sprint.id)
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == SprintStatus.CANCELLED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_sprint_db_error(self, db_session, test_sprint):
|
||||||
|
"""Test cancelling sprint when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"commit",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await sprint.cancel_sprint(db_session, sprint_id=test_sprint.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintVelocity:
|
||||||
|
"""Tests for velocity operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_velocity_empty(self, db_session, test_project):
|
||||||
|
"""Test getting velocity with no completed sprints."""
|
||||||
|
result = await sprint.get_velocity(db_session, project_id=test_project.id)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_velocity_with_data(self, db_session, test_project):
|
||||||
|
"""Test getting velocity with completed sprints."""
|
||||||
|
completed_sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
name="Completed Sprint",
|
||||||
|
number=1,
|
||||||
|
status=SprintStatus.COMPLETED,
|
||||||
|
start_date=date.today() - timedelta(days=14),
|
||||||
|
end_date=date.today(),
|
||||||
|
planned_points=20,
|
||||||
|
velocity=18,
|
||||||
|
)
|
||||||
|
db_session.add(completed_sprint)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await sprint.get_velocity(db_session, project_id=test_project.id)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["sprint_number"] == 1
|
||||||
|
assert result[0]["velocity"] == 18
|
||||||
|
assert result[0]["velocity_ratio"] == 0.9
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_velocity_db_error(self, db_session, test_project):
|
||||||
|
"""Test getting velocity when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await sprint.get_velocity(db_session, project_id=test_project.id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintWithIssueCounts:
|
||||||
|
"""Tests for sprints with issue counts."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sprints_with_issue_counts_empty(self, db_session, test_project):
|
||||||
|
"""Test getting sprints with issue counts when no sprints exist."""
|
||||||
|
results, total = await sprint.get_sprints_with_issue_counts(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
|
assert results == []
|
||||||
|
assert total == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sprints_with_issue_counts_success(
|
||||||
|
self, db_session, test_project, test_sprint
|
||||||
|
):
|
||||||
|
"""Test getting sprints with issue counts."""
|
||||||
|
# Add some issues to the sprint
|
||||||
|
issue1 = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
sprint_id=test_sprint.id,
|
||||||
|
title="Issue 1",
|
||||||
|
status=IssueStatus.OPEN,
|
||||||
|
)
|
||||||
|
issue2 = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
sprint_id=test_sprint.id,
|
||||||
|
title="Issue 2",
|
||||||
|
status=IssueStatus.CLOSED,
|
||||||
|
)
|
||||||
|
db_session.add_all([issue1, issue2])
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
results, _total = await sprint.get_sprints_with_issue_counts(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["issue_count"] == 2
|
||||||
|
assert results[0]["open_issues"] == 1
|
||||||
|
assert results[0]["completed_issues"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sprints_with_issue_counts_db_error(
|
||||||
|
self, db_session, test_project, test_sprint
|
||||||
|
):
|
||||||
|
"""Test getting sprints with issue counts when DB error occurs."""
|
||||||
|
with patch.object(
|
||||||
|
db_session,
|
||||||
|
"execute",
|
||||||
|
side_effect=OperationalError("Connection lost", {}, Exception()),
|
||||||
|
):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await sprint.get_sprints_with_issue_counts(
|
||||||
|
db_session, project_id=test_project.id
|
||||||
|
)
|
||||||
540
backend/tests/crud/syndarix/test_sprint_crud.py
Normal file
540
backend/tests/crud/syndarix/test_sprint_crud.py
Normal file
@@ -0,0 +1,540 @@
|
|||||||
|
# tests/crud/syndarix/test_sprint_crud.py
|
||||||
|
"""
|
||||||
|
Tests for Sprint CRUD operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.crud.syndarix import sprint as sprint_crud
|
||||||
|
from app.models.syndarix import SprintStatus
|
||||||
|
from app.schemas.syndarix import SprintCreate, SprintUpdate
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintCreate:
|
||||||
|
"""Tests for sprint creation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_sprint_success(self, async_test_db, test_project_crud):
|
||||||
|
"""Test successfully creating a sprint."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="Sprint 1",
|
||||||
|
number=1,
|
||||||
|
goal="Complete initial setup",
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.PLANNED,
|
||||||
|
planned_points=21,
|
||||||
|
)
|
||||||
|
result = await sprint_crud.create(session, obj_in=sprint_data)
|
||||||
|
|
||||||
|
assert result.id is not None
|
||||||
|
assert result.name == "Sprint 1"
|
||||||
|
assert result.number == 1
|
||||||
|
assert result.goal == "Complete initial setup"
|
||||||
|
assert result.status == SprintStatus.PLANNED
|
||||||
|
assert result.planned_points == 21
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_sprint_minimal(self, async_test_db, test_project_crud):
|
||||||
|
"""Test creating sprint with minimal fields."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="Minimal Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
)
|
||||||
|
result = await sprint_crud.create(session, obj_in=sprint_data)
|
||||||
|
|
||||||
|
assert result.name == "Minimal Sprint"
|
||||||
|
assert result.status == SprintStatus.PLANNED # Default
|
||||||
|
assert result.goal is None
|
||||||
|
assert result.planned_points is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintRead:
|
||||||
|
"""Tests for sprint read operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sprint_by_id(self, async_test_db, test_sprint_crud):
|
||||||
|
"""Test getting sprint by ID."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await sprint_crud.get(session, id=str(test_sprint_crud.id))
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == test_sprint_crud.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sprint_by_id_not_found(self, async_test_db):
|
||||||
|
"""Test getting non-existent sprint returns None."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await sprint_crud.get(session, id=str(uuid.uuid4()))
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_with_details(self, async_test_db, test_sprint_crud):
|
||||||
|
"""Test getting sprint with related details."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await sprint_crud.get_with_details(
|
||||||
|
session,
|
||||||
|
sprint_id=test_sprint_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["sprint"].id == test_sprint_crud.id
|
||||||
|
assert result["project_name"] is not None
|
||||||
|
assert "issue_count" in result
|
||||||
|
assert "open_issues" in result
|
||||||
|
assert "completed_issues" in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintUpdate:
|
||||||
|
"""Tests for sprint update operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_sprint_basic_fields(self, async_test_db, test_sprint_crud):
|
||||||
|
"""Test updating basic sprint fields."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint = await sprint_crud.get(session, id=str(test_sprint_crud.id))
|
||||||
|
|
||||||
|
update_data = SprintUpdate(
|
||||||
|
name="Updated Sprint Name",
|
||||||
|
goal="Updated goal",
|
||||||
|
)
|
||||||
|
result = await sprint_crud.update(
|
||||||
|
session, db_obj=sprint, obj_in=update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.name == "Updated Sprint Name"
|
||||||
|
assert result.goal == "Updated goal"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_sprint_dates(self, async_test_db, test_sprint_crud):
|
||||||
|
"""Test updating sprint dates."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint = await sprint_crud.get(session, id=str(test_sprint_crud.id))
|
||||||
|
|
||||||
|
update_data = SprintUpdate(
|
||||||
|
start_date=today + timedelta(days=1),
|
||||||
|
end_date=today + timedelta(days=21),
|
||||||
|
)
|
||||||
|
result = await sprint_crud.update(
|
||||||
|
session, db_obj=sprint, obj_in=update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.start_date == today + timedelta(days=1)
|
||||||
|
assert result.end_date == today + timedelta(days=21)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintLifecycle:
|
||||||
|
"""Tests for sprint lifecycle operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_sprint(self, async_test_db, test_sprint_crud):
|
||||||
|
"""Test starting a planned sprint."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await sprint_crud.start_sprint(
|
||||||
|
session,
|
||||||
|
sprint_id=test_sprint_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == SprintStatus.ACTIVE
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_sprint_with_custom_date(
|
||||||
|
self, async_test_db, test_project_crud
|
||||||
|
):
|
||||||
|
"""Test starting sprint with custom start date."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
|
||||||
|
# Create a planned sprint
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="Start Date Sprint",
|
||||||
|
number=10,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.PLANNED,
|
||||||
|
)
|
||||||
|
created = await sprint_crud.create(session, obj_in=sprint_data)
|
||||||
|
sprint_id = created.id
|
||||||
|
|
||||||
|
# Start with custom date
|
||||||
|
new_start = today + timedelta(days=2)
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await sprint_crud.start_sprint(
|
||||||
|
session,
|
||||||
|
sprint_id=sprint_id,
|
||||||
|
start_date=new_start,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == SprintStatus.ACTIVE
|
||||||
|
assert result.start_date == new_start
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_sprint_already_active_fails(
|
||||||
|
self, async_test_db, test_project_crud
|
||||||
|
):
|
||||||
|
"""Test starting an already active sprint raises ValueError."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
|
||||||
|
# Create and start a sprint
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="Already Active Sprint",
|
||||||
|
number=20,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
created = await sprint_crud.create(session, obj_in=sprint_data)
|
||||||
|
sprint_id = created.id
|
||||||
|
|
||||||
|
# Try to start again
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
await sprint_crud.start_sprint(session, sprint_id=sprint_id)
|
||||||
|
|
||||||
|
assert "cannot start sprint" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_sprint(self, async_test_db, test_project_crud):
|
||||||
|
"""Test completing an active sprint."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
|
||||||
|
# Create an active sprint
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="Complete Me Sprint",
|
||||||
|
number=30,
|
||||||
|
start_date=today - timedelta(days=14),
|
||||||
|
end_date=today,
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
planned_points=21,
|
||||||
|
)
|
||||||
|
created = await sprint_crud.create(session, obj_in=sprint_data)
|
||||||
|
sprint_id = created.id
|
||||||
|
|
||||||
|
# Complete
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await sprint_crud.complete_sprint(session, sprint_id=sprint_id)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == SprintStatus.COMPLETED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_planned_sprint_fails(
|
||||||
|
self, async_test_db, test_project_crud
|
||||||
|
):
|
||||||
|
"""Test completing a planned sprint raises ValueError."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="Planned Sprint",
|
||||||
|
number=40,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.PLANNED,
|
||||||
|
)
|
||||||
|
created = await sprint_crud.create(session, obj_in=sprint_data)
|
||||||
|
sprint_id = created.id
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
await sprint_crud.complete_sprint(session, sprint_id=sprint_id)
|
||||||
|
|
||||||
|
assert "cannot complete sprint" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_sprint(self, async_test_db, test_project_crud):
|
||||||
|
"""Test cancelling a sprint."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="Cancel Me Sprint",
|
||||||
|
number=50,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
created = await sprint_crud.create(session, obj_in=sprint_data)
|
||||||
|
sprint_id = created.id
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await sprint_crud.cancel_sprint(session, sprint_id=sprint_id)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == SprintStatus.CANCELLED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_completed_sprint_fails(
|
||||||
|
self, async_test_db, test_project_crud
|
||||||
|
):
|
||||||
|
"""Test cancelling a completed sprint raises ValueError."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="Completed Sprint",
|
||||||
|
number=60,
|
||||||
|
start_date=today - timedelta(days=14),
|
||||||
|
end_date=today,
|
||||||
|
status=SprintStatus.COMPLETED,
|
||||||
|
)
|
||||||
|
created = await sprint_crud.create(session, obj_in=sprint_data)
|
||||||
|
sprint_id = created.id
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
await sprint_crud.cancel_sprint(session, sprint_id=sprint_id)
|
||||||
|
|
||||||
|
assert "cannot cancel sprint" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintByProject:
|
||||||
|
"""Tests for getting sprints by project."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project(
|
||||||
|
self, async_test_db, test_project_crud, test_sprint_crud
|
||||||
|
):
|
||||||
|
"""Test getting sprints by project."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprints, total = await sprint_crud.get_by_project(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert total >= 1
|
||||||
|
assert all(s.project_id == test_project_crud.id for s in sprints)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_project_with_status(self, async_test_db, test_project_crud):
|
||||||
|
"""Test filtering sprints by status."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
|
||||||
|
# Create sprints with different statuses
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
planned_sprint = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="Planned Filter Sprint",
|
||||||
|
number=70,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.PLANNED,
|
||||||
|
)
|
||||||
|
await sprint_crud.create(session, obj_in=planned_sprint)
|
||||||
|
|
||||||
|
active_sprint = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="Active Filter Sprint",
|
||||||
|
number=71,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
await sprint_crud.create(session, obj_in=active_sprint)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprints, _ = await sprint_crud.get_by_project(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all(s.status == SprintStatus.ACTIVE for s in sprints)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintActiveSprint:
|
||||||
|
"""Tests for active sprint operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_active_sprint(self, async_test_db, test_project_crud):
|
||||||
|
"""Test getting active sprint for a project."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
|
||||||
|
# Create an active sprint
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name="Active Sprint",
|
||||||
|
number=80,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
await sprint_crud.create(session, obj_in=sprint_data)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await sprint_crud.get_active_sprint(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.status == SprintStatus.ACTIVE
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_active_sprint_none(self, async_test_db, test_project_crud):
|
||||||
|
"""Test getting active sprint when none exists."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Note: test_sprint_crud has PLANNED status by default
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
result = await sprint_crud.get_active_sprint(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# May or may not be None depending on other tests
|
||||||
|
if result is not None:
|
||||||
|
assert result.status == SprintStatus.ACTIVE
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintNextNumber:
|
||||||
|
"""Tests for getting next sprint number."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_next_sprint_number(self, async_test_db, test_project_crud):
|
||||||
|
"""Test getting next sprint number."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
|
||||||
|
# Create sprints with numbers
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
for i in range(1, 4):
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name=f"Number Sprint {i}",
|
||||||
|
number=i,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
)
|
||||||
|
await sprint_crud.create(session, obj_in=sprint_data)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
next_number = await sprint_crud.get_next_sprint_number(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert next_number >= 4
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintVelocity:
|
||||||
|
"""Tests for sprint velocity operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_velocity(self, async_test_db, test_project_crud):
|
||||||
|
"""Test getting velocity data for completed sprints."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
|
||||||
|
# Create completed sprints with points
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
for i in range(1, 4):
|
||||||
|
sprint_data = SprintCreate(
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
name=f"Velocity Sprint {i}",
|
||||||
|
number=100 + i,
|
||||||
|
start_date=today - timedelta(days=14 * i),
|
||||||
|
end_date=today - timedelta(days=14 * (i - 1)),
|
||||||
|
status=SprintStatus.COMPLETED,
|
||||||
|
planned_points=20,
|
||||||
|
velocity=15 + i,
|
||||||
|
)
|
||||||
|
await sprint_crud.create(session, obj_in=sprint_data)
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
velocity_data = await sprint_crud.get_velocity(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
limit=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(velocity_data) >= 1
|
||||||
|
for data in velocity_data:
|
||||||
|
assert "sprint_number" in data
|
||||||
|
assert "sprint_name" in data
|
||||||
|
assert "planned_points" in data
|
||||||
|
assert "velocity" in data
|
||||||
|
assert "velocity_ratio" in data
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintWithIssueCounts:
|
||||||
|
"""Tests for getting sprints with issue counts."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sprints_with_issue_counts(
|
||||||
|
self, async_test_db, test_project_crud, test_sprint_crud
|
||||||
|
):
|
||||||
|
"""Test getting sprints with issue counts."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
results, total = await sprint_crud.get_sprints_with_issue_counts(
|
||||||
|
session,
|
||||||
|
project_id=test_project_crud.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert total >= 1
|
||||||
|
for result in results:
|
||||||
|
assert "sprint" in result
|
||||||
|
assert "issue_count" in result
|
||||||
|
assert "open_issues" in result
|
||||||
|
assert "completed_issues" in result
|
||||||
@@ -12,7 +12,7 @@ from sqlalchemy.exc import DataError, IntegrityError, OperationalError
|
|||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
from app.crud.user import user as user_crud
|
from app.crud.user import user as user_crud
|
||||||
from app.schemas.users import UserCreate, UserUpdate
|
from app.schemas.users import UserCreate
|
||||||
|
|
||||||
|
|
||||||
class TestCRUDBaseGet:
|
class TestCRUDBaseGet:
|
||||||
@@ -266,7 +266,8 @@ class TestCRUDBaseUpdate:
|
|||||||
"statement", {}, Exception("UNIQUE constraint failed")
|
"statement", {}, Exception("UNIQUE constraint failed")
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
update_data = UserUpdate(email=async_test_user.email)
|
# Use dict since UserUpdate doesn't allow email changes
|
||||||
|
update_data = {"email": async_test_user.email}
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="already exists"):
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
await user_crud.update(
|
await user_crud.update(
|
||||||
|
|||||||
2
backend/tests/models/syndarix/__init__.py
Normal file
2
backend/tests/models/syndarix/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# tests/models/syndarix/__init__.py
|
||||||
|
"""Syndarix model unit tests."""
|
||||||
191
backend/tests/models/syndarix/conftest.py
Normal file
191
backend/tests/models/syndarix/conftest.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
# tests/models/syndarix/conftest.py
|
||||||
|
"""
|
||||||
|
Shared fixtures for Syndarix model tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.models.syndarix import (
|
||||||
|
AgentInstance,
|
||||||
|
AgentStatus,
|
||||||
|
AgentType,
|
||||||
|
AutonomyLevel,
|
||||||
|
Issue,
|
||||||
|
IssuePriority,
|
||||||
|
IssueStatus,
|
||||||
|
Project,
|
||||||
|
ProjectStatus,
|
||||||
|
Sprint,
|
||||||
|
SprintStatus,
|
||||||
|
)
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_project_data():
|
||||||
|
"""Return sample project data for testing."""
|
||||||
|
return {
|
||||||
|
"name": "Test Project",
|
||||||
|
"slug": "test-project",
|
||||||
|
"description": "A test project for unit testing",
|
||||||
|
"autonomy_level": AutonomyLevel.MILESTONE,
|
||||||
|
"status": ProjectStatus.ACTIVE,
|
||||||
|
"settings": {"mcp_servers": ["gitea", "slack"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_agent_type_data():
|
||||||
|
"""Return sample agent type data for testing."""
|
||||||
|
return {
|
||||||
|
"name": "Backend Engineer",
|
||||||
|
"slug": "backend-engineer",
|
||||||
|
"description": "Specialized in backend development",
|
||||||
|
"expertise": ["python", "fastapi", "postgresql"],
|
||||||
|
"personality_prompt": "You are an expert backend engineer...",
|
||||||
|
"primary_model": "claude-opus-4-5-20251101",
|
||||||
|
"fallback_models": ["claude-sonnet-4-20250514"],
|
||||||
|
"model_params": {"temperature": 0.7, "max_tokens": 4096},
|
||||||
|
"mcp_servers": ["gitea", "file-system"],
|
||||||
|
"tool_permissions": {"allowed": ["*"], "denied": []},
|
||||||
|
"is_active": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_sprint_data():
|
||||||
|
"""Return sample sprint data for testing."""
|
||||||
|
today = date.today()
|
||||||
|
return {
|
||||||
|
"name": "Sprint 1",
|
||||||
|
"number": 1,
|
||||||
|
"goal": "Complete initial setup and core features",
|
||||||
|
"start_date": today,
|
||||||
|
"end_date": today + timedelta(days=14),
|
||||||
|
"status": SprintStatus.PLANNED,
|
||||||
|
"planned_points": 21,
|
||||||
|
"completed_points": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_issue_data():
|
||||||
|
"""Return sample issue data for testing."""
|
||||||
|
return {
|
||||||
|
"title": "Implement user authentication",
|
||||||
|
"body": "As a user, I want to log in securely...",
|
||||||
|
"status": IssueStatus.OPEN,
|
||||||
|
"priority": IssuePriority.HIGH,
|
||||||
|
"labels": ["backend", "security"],
|
||||||
|
"story_points": 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_owner(async_test_db):
|
||||||
|
"""Create a test user to be used as project owner."""
|
||||||
|
from app.core.auth import get_password_hash
|
||||||
|
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="owner@example.com",
|
||||||
|
password_hash=get_password_hash("TestPassword123!"),
|
||||||
|
first_name="Test",
|
||||||
|
last_name="Owner",
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_project(async_test_db, test_owner, sample_project_data):
|
||||||
|
"""Create a test project in the database."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
owner_id=test_owner.id,
|
||||||
|
**sample_project_data,
|
||||||
|
)
|
||||||
|
session.add(project)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(project)
|
||||||
|
return project
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_agent_type(async_test_db, sample_agent_type_data):
|
||||||
|
"""Create a test agent type in the database."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
**sample_agent_type_data,
|
||||||
|
)
|
||||||
|
session.add(agent_type)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(agent_type)
|
||||||
|
return agent_type
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_agent_instance(async_test_db, test_project, test_agent_type):
|
||||||
|
"""Create a test agent instance in the database."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
agent_instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=test_agent_type.id,
|
||||||
|
project_id=test_project.id,
|
||||||
|
status=AgentStatus.IDLE,
|
||||||
|
current_task=None,
|
||||||
|
short_term_memory={},
|
||||||
|
long_term_memory_ref=None,
|
||||||
|
session_id=None,
|
||||||
|
)
|
||||||
|
session.add(agent_instance)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(agent_instance)
|
||||||
|
return agent_instance
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_sprint(async_test_db, test_project, sample_sprint_data):
|
||||||
|
"""Create a test sprint in the database."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
**sample_sprint_data,
|
||||||
|
)
|
||||||
|
session.add(sprint)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(sprint)
|
||||||
|
return sprint
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_issue(async_test_db, test_project, sample_issue_data):
|
||||||
|
"""Create a test issue in the database."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=test_project.id,
|
||||||
|
**sample_issue_data,
|
||||||
|
)
|
||||||
|
session.add(issue)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(issue)
|
||||||
|
return issue
|
||||||
464
backend/tests/models/syndarix/test_agent_instance.py
Normal file
464
backend/tests/models/syndarix/test_agent_instance.py
Normal file
@@ -0,0 +1,464 @@
|
|||||||
|
# tests/models/syndarix/test_agent_instance.py
|
||||||
|
"""
|
||||||
|
Unit tests for the AgentInstance model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from app.models.syndarix import (
|
||||||
|
AgentInstance,
|
||||||
|
AgentStatus,
|
||||||
|
AgentType,
|
||||||
|
Project,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceModel:
|
||||||
|
"""Tests for AgentInstance model creation and fields."""
|
||||||
|
|
||||||
|
def test_create_agent_instance_with_required_fields(self, db_session):
|
||||||
|
"""Test creating an agent instance with only required fields."""
|
||||||
|
# First create dependencies
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Project",
|
||||||
|
slug="test-project-instance",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Agent",
|
||||||
|
slug="test-agent-instance",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Create agent instance
|
||||||
|
instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
project_id=project.id,
|
||||||
|
name="Alice",
|
||||||
|
)
|
||||||
|
db_session.add(instance)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(AgentInstance).filter_by(project_id=project.id).first()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.agent_type_id == agent_type.id
|
||||||
|
assert retrieved.project_id == project.id
|
||||||
|
assert retrieved.status == AgentStatus.IDLE # Default
|
||||||
|
assert retrieved.current_task is None
|
||||||
|
assert retrieved.short_term_memory == {}
|
||||||
|
assert retrieved.long_term_memory_ref is None
|
||||||
|
assert retrieved.session_id is None
|
||||||
|
assert retrieved.tasks_completed == 0
|
||||||
|
assert retrieved.tokens_used == 0
|
||||||
|
assert retrieved.cost_incurred == Decimal("0")
|
||||||
|
|
||||||
|
def test_create_agent_instance_with_all_fields(self, db_session):
|
||||||
|
"""Test creating an agent instance with all optional fields."""
|
||||||
|
# First create dependencies
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Full Project",
|
||||||
|
slug="full-project-instance",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Full Agent",
|
||||||
|
slug="full-agent-instance",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
instance_id = uuid.uuid4()
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
instance = AgentInstance(
|
||||||
|
id=instance_id,
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
project_id=project.id,
|
||||||
|
name="Bob",
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
current_task="Implementing user authentication",
|
||||||
|
short_term_memory={
|
||||||
|
"context": "Working on auth",
|
||||||
|
"recent_files": ["auth.py"],
|
||||||
|
},
|
||||||
|
long_term_memory_ref="project-123/agent-456",
|
||||||
|
session_id="session-abc-123",
|
||||||
|
last_activity_at=now,
|
||||||
|
tasks_completed=5,
|
||||||
|
tokens_used=10000,
|
||||||
|
cost_incurred=Decimal("0.5000"),
|
||||||
|
)
|
||||||
|
db_session.add(instance)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentInstance).filter_by(id=instance_id).first()
|
||||||
|
|
||||||
|
assert retrieved.status == AgentStatus.WORKING
|
||||||
|
assert retrieved.current_task == "Implementing user authentication"
|
||||||
|
assert retrieved.short_term_memory == {
|
||||||
|
"context": "Working on auth",
|
||||||
|
"recent_files": ["auth.py"],
|
||||||
|
}
|
||||||
|
assert retrieved.long_term_memory_ref == "project-123/agent-456"
|
||||||
|
assert retrieved.session_id == "session-abc-123"
|
||||||
|
assert retrieved.tasks_completed == 5
|
||||||
|
assert retrieved.tokens_used == 10000
|
||||||
|
assert retrieved.cost_incurred == Decimal("0.5000")
|
||||||
|
|
||||||
|
def test_agent_instance_timestamps(self, db_session):
|
||||||
|
"""Test that timestamps are automatically set."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Timestamp Project", slug="timestamp-project-ai"
|
||||||
|
)
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Timestamp Agent",
|
||||||
|
slug="timestamp-agent-ai",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
project_id=project.id,
|
||||||
|
name="Charlie",
|
||||||
|
)
|
||||||
|
db_session.add(instance)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
assert isinstance(instance.created_at, datetime)
|
||||||
|
assert isinstance(instance.updated_at, datetime)
|
||||||
|
|
||||||
|
def test_agent_instance_string_representation(self, db_session):
|
||||||
|
"""Test the string representation of an agent instance."""
|
||||||
|
project = Project(id=uuid.uuid4(), name="Repr Project", slug="repr-project-ai")
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Repr Agent",
|
||||||
|
slug="repr-agent-ai",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
instance_id = uuid.uuid4()
|
||||||
|
instance = AgentInstance(
|
||||||
|
id=instance_id,
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
project_id=project.id,
|
||||||
|
name="Dave",
|
||||||
|
status=AgentStatus.IDLE,
|
||||||
|
)
|
||||||
|
|
||||||
|
repr_str = repr(instance)
|
||||||
|
assert "Dave" in repr_str
|
||||||
|
assert str(instance_id) in repr_str
|
||||||
|
assert str(agent_type.id) in repr_str
|
||||||
|
assert str(project.id) in repr_str
|
||||||
|
assert "idle" in repr_str
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceStatus:
|
||||||
|
"""Tests for AgentInstance status transitions."""
|
||||||
|
|
||||||
|
def test_all_agent_statuses(self, db_session):
|
||||||
|
"""Test that all agent statuses can be stored."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Status Project", slug="status-project-ai"
|
||||||
|
)
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Status Agent",
|
||||||
|
slug="status-agent-ai",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
for idx, status in enumerate(AgentStatus):
|
||||||
|
instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
project_id=project.id,
|
||||||
|
name=f"Agent-{idx}",
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
db_session.add(instance)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||||
|
)
|
||||||
|
assert retrieved.status == status
|
||||||
|
|
||||||
|
def test_status_update(self, db_session):
|
||||||
|
"""Test updating agent instance status."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Update Status Project",
|
||||||
|
slug="update-status-project-ai",
|
||||||
|
)
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Update Status Agent",
|
||||||
|
slug="update-status-agent-ai",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
project_id=project.id,
|
||||||
|
name="Eve",
|
||||||
|
status=AgentStatus.IDLE,
|
||||||
|
)
|
||||||
|
db_session.add(instance)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Update to WORKING
|
||||||
|
instance.status = AgentStatus.WORKING
|
||||||
|
instance.current_task = "Processing feature request"
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||||
|
assert retrieved.status == AgentStatus.WORKING
|
||||||
|
assert retrieved.current_task == "Processing feature request"
|
||||||
|
|
||||||
|
def test_terminate_agent_instance(self, db_session):
|
||||||
|
"""Test terminating an agent instance."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Terminate Project", slug="terminate-project-ai"
|
||||||
|
)
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Terminate Agent",
|
||||||
|
slug="terminate-agent-ai",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
project_id=project.id,
|
||||||
|
name="Frank",
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
current_task="Working on something",
|
||||||
|
session_id="active-session",
|
||||||
|
)
|
||||||
|
db_session.add(instance)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Terminate
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
instance.status = AgentStatus.TERMINATED
|
||||||
|
instance.terminated_at = now
|
||||||
|
instance.current_task = None
|
||||||
|
instance.session_id = None
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||||
|
assert retrieved.status == AgentStatus.TERMINATED
|
||||||
|
assert retrieved.terminated_at is not None
|
||||||
|
assert retrieved.current_task is None
|
||||||
|
assert retrieved.session_id is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceMetrics:
|
||||||
|
"""Tests for AgentInstance usage metrics."""
|
||||||
|
|
||||||
|
def test_increment_metrics(self, db_session):
|
||||||
|
"""Test incrementing usage metrics."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Metrics Project", slug="metrics-project-ai"
|
||||||
|
)
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Metrics Agent",
|
||||||
|
slug="metrics-agent-ai",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
project_id=project.id,
|
||||||
|
name="Grace",
|
||||||
|
)
|
||||||
|
db_session.add(instance)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Record task completion
|
||||||
|
instance.tasks_completed += 1
|
||||||
|
instance.tokens_used += 1500
|
||||||
|
instance.cost_incurred += Decimal("0.0150")
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||||
|
assert retrieved.tasks_completed == 1
|
||||||
|
assert retrieved.tokens_used == 1500
|
||||||
|
assert retrieved.cost_incurred == Decimal("0.0150")
|
||||||
|
|
||||||
|
# Record another task
|
||||||
|
retrieved.tasks_completed += 1
|
||||||
|
retrieved.tokens_used += 2500
|
||||||
|
retrieved.cost_incurred += Decimal("0.0250")
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
updated = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||||
|
assert updated.tasks_completed == 2
|
||||||
|
assert updated.tokens_used == 4000
|
||||||
|
assert updated.cost_incurred == Decimal("0.0400")
|
||||||
|
|
||||||
|
def test_large_token_count(self, db_session):
|
||||||
|
"""Test handling large token counts."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Large Tokens Project", slug="large-tokens-project-ai"
|
||||||
|
)
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Large Tokens Agent",
|
||||||
|
slug="large-tokens-agent-ai",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
project_id=project.id,
|
||||||
|
name="Henry",
|
||||||
|
tokens_used=10_000_000_000, # 10 billion tokens
|
||||||
|
cost_incurred=Decimal("100000.0000"), # $100,000
|
||||||
|
)
|
||||||
|
db_session.add(instance)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||||
|
assert retrieved.tokens_used == 10_000_000_000
|
||||||
|
assert retrieved.cost_incurred == Decimal("100000.0000")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceShortTermMemory:
|
||||||
|
"""Tests for AgentInstance short-term memory JSON field."""
|
||||||
|
|
||||||
|
def test_store_complex_memory(self, db_session):
|
||||||
|
"""Test storing complex short-term memory."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Memory Project", slug="memory-project-ai"
|
||||||
|
)
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Memory Agent",
|
||||||
|
slug="memory-agent-ai",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
memory = {
|
||||||
|
"conversation_history": [
|
||||||
|
{"role": "user", "content": "Implement feature X"},
|
||||||
|
{"role": "assistant", "content": "I'll start by..."},
|
||||||
|
],
|
||||||
|
"recent_files": ["auth.py", "models.py", "test_auth.py"],
|
||||||
|
"decisions": {
|
||||||
|
"architecture": "Use repository pattern",
|
||||||
|
"testing": "TDD approach",
|
||||||
|
},
|
||||||
|
"blockers": [],
|
||||||
|
"context_tokens": 2048,
|
||||||
|
}
|
||||||
|
|
||||||
|
instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
project_id=project.id,
|
||||||
|
name="Ivy",
|
||||||
|
short_term_memory=memory,
|
||||||
|
)
|
||||||
|
db_session.add(instance)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||||
|
assert retrieved.short_term_memory == memory
|
||||||
|
assert len(retrieved.short_term_memory["conversation_history"]) == 2
|
||||||
|
assert "auth.py" in retrieved.short_term_memory["recent_files"]
|
||||||
|
|
||||||
|
def test_update_memory(self, db_session):
|
||||||
|
"""Test updating short-term memory."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Update Memory Project",
|
||||||
|
slug="update-memory-project-ai",
|
||||||
|
)
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Update Memory Agent",
|
||||||
|
slug="update-memory-agent-ai",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
project_id=project.id,
|
||||||
|
name="Jack",
|
||||||
|
short_term_memory={"initial": "state"},
|
||||||
|
)
|
||||||
|
db_session.add(instance)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Update memory
|
||||||
|
instance.short_term_memory = {"updated": "state", "new_key": "new_value"}
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||||
|
assert "initial" not in retrieved.short_term_memory
|
||||||
|
assert retrieved.short_term_memory["updated"] == "state"
|
||||||
|
assert retrieved.short_term_memory["new_key"] == "new_value"
|
||||||
324
backend/tests/models/syndarix/test_agent_type.py
Normal file
324
backend/tests/models/syndarix/test_agent_type.py
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
# tests/models/syndarix/test_agent_type.py
|
||||||
|
"""
|
||||||
|
Unit tests for the AgentType model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
|
from app.models.syndarix import AgentType
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeModel:
|
||||||
|
"""Tests for AgentType model creation and fields."""
|
||||||
|
|
||||||
|
def test_create_agent_type_with_required_fields(self, db_session):
|
||||||
|
"""Test creating an agent type with only required fields."""
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Agent",
|
||||||
|
slug="test-agent",
|
||||||
|
personality_prompt="You are a helpful assistant.",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentType).filter_by(slug="test-agent").first()
|
||||||
|
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.name == "Test Agent"
|
||||||
|
assert retrieved.slug == "test-agent"
|
||||||
|
assert retrieved.personality_prompt == "You are a helpful assistant."
|
||||||
|
assert retrieved.primary_model == "claude-opus-4-5-20251101"
|
||||||
|
assert retrieved.is_active is True # Default
|
||||||
|
assert retrieved.expertise == [] # Default empty list
|
||||||
|
assert retrieved.fallback_models == [] # Default empty list
|
||||||
|
assert retrieved.model_params == {} # Default empty dict
|
||||||
|
assert retrieved.mcp_servers == [] # Default empty list
|
||||||
|
assert retrieved.tool_permissions == {} # Default empty dict
|
||||||
|
|
||||||
|
def test_create_agent_type_with_all_fields(self, db_session):
|
||||||
|
"""Test creating an agent type with all optional fields."""
|
||||||
|
agent_type_id = uuid.uuid4()
|
||||||
|
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=agent_type_id,
|
||||||
|
name="Full Agent Type",
|
||||||
|
slug="full-agent-type",
|
||||||
|
description="A fully configured agent type",
|
||||||
|
expertise=["python", "fastapi", "testing"],
|
||||||
|
personality_prompt="You are an expert Python developer...",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
fallback_models=["claude-sonnet-4-20250514", "gpt-4o"],
|
||||||
|
model_params={"temperature": 0.7, "max_tokens": 4096},
|
||||||
|
mcp_servers=["gitea", "file-system", "slack"],
|
||||||
|
tool_permissions={"allowed": ["*"], "denied": ["dangerous_tool"]},
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentType).filter_by(id=agent_type_id).first()
|
||||||
|
|
||||||
|
assert retrieved.name == "Full Agent Type"
|
||||||
|
assert retrieved.description == "A fully configured agent type"
|
||||||
|
assert retrieved.expertise == ["python", "fastapi", "testing"]
|
||||||
|
assert retrieved.fallback_models == ["claude-sonnet-4-20250514", "gpt-4o"]
|
||||||
|
assert retrieved.model_params == {"temperature": 0.7, "max_tokens": 4096}
|
||||||
|
assert retrieved.mcp_servers == ["gitea", "file-system", "slack"]
|
||||||
|
assert retrieved.tool_permissions == {
|
||||||
|
"allowed": ["*"],
|
||||||
|
"denied": ["dangerous_tool"],
|
||||||
|
}
|
||||||
|
assert retrieved.is_active is True
|
||||||
|
|
||||||
|
def test_agent_type_unique_slug_constraint(self, db_session):
|
||||||
|
"""Test that agent types cannot have duplicate slugs."""
|
||||||
|
agent_type1 = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Agent One",
|
||||||
|
slug="duplicate-agent-slug",
|
||||||
|
personality_prompt="First agent",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(agent_type1)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
agent_type2 = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Agent Two",
|
||||||
|
slug="duplicate-agent-slug", # Same slug
|
||||||
|
personality_prompt="Second agent",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(agent_type2)
|
||||||
|
|
||||||
|
with pytest.raises(IntegrityError):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
db_session.rollback()
|
||||||
|
|
||||||
|
def test_agent_type_timestamps(self, db_session):
|
||||||
|
"""Test that timestamps are automatically set."""
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Timestamp Agent",
|
||||||
|
slug="timestamp-agent",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(AgentType).filter_by(slug="timestamp-agent").first()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(retrieved.created_at, datetime)
|
||||||
|
assert isinstance(retrieved.updated_at, datetime)
|
||||||
|
|
||||||
|
def test_agent_type_update(self, db_session):
|
||||||
|
"""Test updating agent type fields."""
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Original Agent",
|
||||||
|
slug="original-agent",
|
||||||
|
personality_prompt="Original prompt",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
original_created_at = agent_type.created_at
|
||||||
|
|
||||||
|
# Update fields
|
||||||
|
agent_type.name = "Updated Agent"
|
||||||
|
agent_type.is_active = False
|
||||||
|
agent_type.expertise = ["new", "skills"]
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentType).filter_by(slug="original-agent").first()
|
||||||
|
|
||||||
|
assert retrieved.name == "Updated Agent"
|
||||||
|
assert retrieved.is_active is False
|
||||||
|
assert retrieved.expertise == ["new", "skills"]
|
||||||
|
assert retrieved.created_at == original_created_at
|
||||||
|
assert retrieved.updated_at > original_created_at
|
||||||
|
|
||||||
|
def test_agent_type_delete(self, db_session):
|
||||||
|
"""Test deleting an agent type."""
|
||||||
|
agent_type_id = uuid.uuid4()
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=agent_type_id,
|
||||||
|
name="Delete Me",
|
||||||
|
slug="delete-me-agent",
|
||||||
|
personality_prompt="Delete test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
db_session.delete(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
deleted = db_session.query(AgentType).filter_by(id=agent_type_id).first()
|
||||||
|
assert deleted is None
|
||||||
|
|
||||||
|
def test_agent_type_string_representation(self, db_session):
|
||||||
|
"""Test the string representation of an agent type."""
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Repr Agent",
|
||||||
|
slug="repr-agent",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert str(agent_type) == "<AgentType Repr Agent (repr-agent) active=True>"
|
||||||
|
assert repr(agent_type) == "<AgentType Repr Agent (repr-agent) active=True>"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeJsonFields:
|
||||||
|
"""Tests for AgentType JSON fields."""
|
||||||
|
|
||||||
|
def test_complex_expertise_list(self, db_session):
|
||||||
|
"""Test storing a list of expertise areas."""
|
||||||
|
expertise = ["python", "fastapi", "sqlalchemy", "postgresql", "redis", "docker"]
|
||||||
|
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Expert Agent",
|
||||||
|
slug="expert-agent",
|
||||||
|
personality_prompt="Prompt",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
expertise=expertise,
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentType).filter_by(slug="expert-agent").first()
|
||||||
|
assert retrieved.expertise == expertise
|
||||||
|
assert "python" in retrieved.expertise
|
||||||
|
assert len(retrieved.expertise) == 6
|
||||||
|
|
||||||
|
def test_complex_model_params(self, db_session):
|
||||||
|
"""Test storing complex model parameters."""
|
||||||
|
model_params = {
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"top_p": 0.9,
|
||||||
|
"frequency_penalty": 0.1,
|
||||||
|
"presence_penalty": 0.1,
|
||||||
|
"stop_sequences": ["###", "END"],
|
||||||
|
}
|
||||||
|
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Params Agent",
|
||||||
|
slug="params-agent",
|
||||||
|
personality_prompt="Prompt",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
model_params=model_params,
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentType).filter_by(slug="params-agent").first()
|
||||||
|
assert retrieved.model_params == model_params
|
||||||
|
assert retrieved.model_params["temperature"] == 0.7
|
||||||
|
assert retrieved.model_params["stop_sequences"] == ["###", "END"]
|
||||||
|
|
||||||
|
def test_complex_tool_permissions(self, db_session):
|
||||||
|
"""Test storing complex tool permissions."""
|
||||||
|
tool_permissions = {
|
||||||
|
"allowed": ["file:read", "file:write", "git:commit"],
|
||||||
|
"denied": ["file:delete", "system:exec"],
|
||||||
|
"require_approval": ["git:push", "gitea:create_pr"],
|
||||||
|
"limits": {
|
||||||
|
"file:write": {"max_size_mb": 10},
|
||||||
|
"git:commit": {"require_message": True},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Permissions Agent",
|
||||||
|
slug="permissions-agent",
|
||||||
|
personality_prompt="Prompt",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
tool_permissions=tool_permissions,
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(AgentType).filter_by(slug="permissions-agent").first()
|
||||||
|
)
|
||||||
|
assert retrieved.tool_permissions == tool_permissions
|
||||||
|
assert "file:read" in retrieved.tool_permissions["allowed"]
|
||||||
|
assert retrieved.tool_permissions["limits"]["file:write"]["max_size_mb"] == 10
|
||||||
|
|
||||||
|
def test_empty_json_fields_default(self, db_session):
|
||||||
|
"""Test that JSON fields default to empty structures."""
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Empty JSON Agent",
|
||||||
|
slug="empty-json-agent",
|
||||||
|
personality_prompt="Prompt",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(AgentType).filter_by(slug="empty-json-agent").first()
|
||||||
|
)
|
||||||
|
assert retrieved.expertise == []
|
||||||
|
assert retrieved.fallback_models == []
|
||||||
|
assert retrieved.model_params == {}
|
||||||
|
assert retrieved.mcp_servers == []
|
||||||
|
assert retrieved.tool_permissions == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentTypeIsActive:
|
||||||
|
"""Tests for AgentType is_active field."""
|
||||||
|
|
||||||
|
def test_default_is_active(self, db_session):
|
||||||
|
"""Test that is_active defaults to True."""
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Default Active",
|
||||||
|
slug="default-active",
|
||||||
|
personality_prompt="Prompt",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentType).filter_by(slug="default-active").first()
|
||||||
|
assert retrieved.is_active is True
|
||||||
|
|
||||||
|
def test_deactivate_agent_type(self, db_session):
|
||||||
|
"""Test deactivating an agent type."""
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Deactivate Me",
|
||||||
|
slug="deactivate-me",
|
||||||
|
personality_prompt="Prompt",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
agent_type.is_active = False
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(AgentType).filter_by(slug="deactivate-me").first()
|
||||||
|
assert retrieved.is_active is False
|
||||||
503
backend/tests/models/syndarix/test_issue.py
Normal file
503
backend/tests/models/syndarix/test_issue.py
Normal file
@@ -0,0 +1,503 @@
|
|||||||
|
# tests/models/syndarix/test_issue.py
|
||||||
|
"""
|
||||||
|
Unit tests for the Issue model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
from app.models.syndarix import (
|
||||||
|
AgentInstance,
|
||||||
|
AgentType,
|
||||||
|
Issue,
|
||||||
|
IssuePriority,
|
||||||
|
IssueStatus,
|
||||||
|
IssueType,
|
||||||
|
Project,
|
||||||
|
Sprint,
|
||||||
|
SprintStatus,
|
||||||
|
SyncStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueModel:
|
||||||
|
"""Tests for Issue model creation and fields."""
|
||||||
|
|
||||||
|
def test_create_issue_with_required_fields(self, db_session):
|
||||||
|
"""Test creating an issue with only required fields."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Issue Project",
|
||||||
|
slug="issue-project",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title="Test Issue",
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Issue).filter_by(title="Test Issue").first()
|
||||||
|
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.title == "Test Issue"
|
||||||
|
assert retrieved.body == "" # Default empty string
|
||||||
|
assert retrieved.status == IssueStatus.OPEN # Default
|
||||||
|
assert retrieved.priority == IssuePriority.MEDIUM # Default
|
||||||
|
assert retrieved.labels == [] # Default empty list
|
||||||
|
assert retrieved.story_points is None
|
||||||
|
assert retrieved.assigned_agent_id is None
|
||||||
|
assert retrieved.human_assignee is None
|
||||||
|
assert retrieved.sprint_id is None
|
||||||
|
assert retrieved.sync_status == SyncStatus.SYNCED # Default
|
||||||
|
|
||||||
|
def test_create_issue_with_all_fields(self, db_session):
|
||||||
|
"""Test creating an issue with all optional fields."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Full Issue Project",
|
||||||
|
slug="full-issue-project",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
issue_id = uuid.uuid4()
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
id=issue_id,
|
||||||
|
project_id=project.id,
|
||||||
|
title="Full Issue",
|
||||||
|
body="A complete issue with all fields set",
|
||||||
|
type=IssueType.BUG,
|
||||||
|
status=IssueStatus.IN_PROGRESS,
|
||||||
|
priority=IssuePriority.CRITICAL,
|
||||||
|
labels=["bug", "security", "urgent"],
|
||||||
|
story_points=8,
|
||||||
|
human_assignee="john.doe@example.com",
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id="gitea-123",
|
||||||
|
remote_url="https://gitea.example.com/issues/123",
|
||||||
|
external_issue_number=123,
|
||||||
|
sync_status=SyncStatus.SYNCED,
|
||||||
|
last_synced_at=now,
|
||||||
|
external_updated_at=now,
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Issue).filter_by(id=issue_id).first()
|
||||||
|
|
||||||
|
assert retrieved.title == "Full Issue"
|
||||||
|
assert retrieved.body == "A complete issue with all fields set"
|
||||||
|
assert retrieved.type == IssueType.BUG
|
||||||
|
assert retrieved.status == IssueStatus.IN_PROGRESS
|
||||||
|
assert retrieved.priority == IssuePriority.CRITICAL
|
||||||
|
assert retrieved.labels == ["bug", "security", "urgent"]
|
||||||
|
assert retrieved.story_points == 8
|
||||||
|
assert retrieved.human_assignee == "john.doe@example.com"
|
||||||
|
assert retrieved.external_tracker_type == "gitea"
|
||||||
|
assert retrieved.external_issue_id == "gitea-123"
|
||||||
|
assert retrieved.external_issue_number == 123
|
||||||
|
assert retrieved.sync_status == SyncStatus.SYNCED
|
||||||
|
|
||||||
|
def test_issue_timestamps(self, db_session):
|
||||||
|
"""Test that timestamps are automatically set."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Timestamp Issue Project",
|
||||||
|
slug="timestamp-issue-project",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title="Timestamp Issue",
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
assert isinstance(issue.created_at, datetime)
|
||||||
|
assert isinstance(issue.updated_at, datetime)
|
||||||
|
|
||||||
|
def test_issue_string_representation(self, db_session):
|
||||||
|
"""Test the string representation of an issue."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Repr Issue Project", slug="repr-issue-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title="This is a very long issue title that should be truncated in repr",
|
||||||
|
status=IssueStatus.OPEN,
|
||||||
|
priority=IssuePriority.HIGH,
|
||||||
|
)
|
||||||
|
|
||||||
|
repr_str = repr(issue)
|
||||||
|
assert "This is a very long issue tit" in repr_str # First 30 chars
|
||||||
|
assert "open" in repr_str
|
||||||
|
assert "high" in repr_str
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueStatus:
|
||||||
|
"""Tests for Issue status field."""
|
||||||
|
|
||||||
|
def test_all_issue_statuses(self, db_session):
|
||||||
|
"""Test that all issue statuses can be stored."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Status Issue Project", slug="status-issue-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
for status in IssueStatus:
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title=f"Issue {status.value}",
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Issue).filter_by(id=issue.id).first()
|
||||||
|
assert retrieved.status == status
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssuePriority:
|
||||||
|
"""Tests for Issue priority field."""
|
||||||
|
|
||||||
|
def test_all_issue_priorities(self, db_session):
|
||||||
|
"""Test that all issue priorities can be stored."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Priority Issue Project",
|
||||||
|
slug="priority-issue-project",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
for priority in IssuePriority:
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title=f"Issue {priority.value}",
|
||||||
|
priority=priority,
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Issue).filter_by(id=issue.id).first()
|
||||||
|
assert retrieved.priority == priority
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueSyncStatus:
|
||||||
|
"""Tests for Issue sync status field."""
|
||||||
|
|
||||||
|
def test_all_sync_statuses(self, db_session):
|
||||||
|
"""Test that all sync statuses can be stored."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Sync Issue Project", slug="sync-issue-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
for sync_status in SyncStatus:
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title=f"Issue {sync_status.value}",
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id=f"ext-{sync_status.value}",
|
||||||
|
sync_status=sync_status,
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Issue).filter_by(id=issue.id).first()
|
||||||
|
assert retrieved.sync_status == sync_status
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueLabels:
|
||||||
|
"""Tests for Issue labels JSON field."""
|
||||||
|
|
||||||
|
def test_store_labels(self, db_session):
|
||||||
|
"""Test storing labels list."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Labels Issue Project", slug="labels-issue-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
labels = ["bug", "security", "high-priority", "needs-review"]
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title="Issue with Labels",
|
||||||
|
labels=labels,
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Issue).filter_by(title="Issue with Labels").first()
|
||||||
|
assert retrieved.labels == labels
|
||||||
|
assert "security" in retrieved.labels
|
||||||
|
|
||||||
|
def test_update_labels(self, db_session):
|
||||||
|
"""Test updating labels."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Update Labels Project", slug="update-labels-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title="Update Labels Issue",
|
||||||
|
labels=["initial"],
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
issue.labels = ["updated", "new-label"]
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Issue).filter_by(title="Update Labels Issue").first()
|
||||||
|
)
|
||||||
|
assert "initial" not in retrieved.labels
|
||||||
|
assert "updated" in retrieved.labels
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueAssignment:
|
||||||
|
"""Tests for Issue assignment fields."""
|
||||||
|
|
||||||
|
def test_assign_to_agent(self, db_session):
|
||||||
|
"""Test assigning an issue to an agent."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Agent Assign Project", slug="agent-assign-project"
|
||||||
|
)
|
||||||
|
agent_type = AgentType(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Agent Type",
|
||||||
|
slug="test-agent-type-assign",
|
||||||
|
personality_prompt="Test",
|
||||||
|
primary_model="claude-opus-4-5-20251101",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.add(agent_type)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
agent_instance = AgentInstance(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
project_id=project.id,
|
||||||
|
name="TaskBot",
|
||||||
|
)
|
||||||
|
db_session.add(agent_instance)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title="Agent Assignment Issue",
|
||||||
|
assigned_agent_id=agent_instance.id,
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Issue).filter_by(title="Agent Assignment Issue").first()
|
||||||
|
)
|
||||||
|
assert retrieved.assigned_agent_id == agent_instance.id
|
||||||
|
assert retrieved.human_assignee is None
|
||||||
|
|
||||||
|
def test_assign_to_human(self, db_session):
|
||||||
|
"""Test assigning an issue to a human."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Human Assign Project", slug="human-assign-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title="Human Assignment Issue",
|
||||||
|
human_assignee="developer@example.com",
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Issue).filter_by(title="Human Assignment Issue").first()
|
||||||
|
)
|
||||||
|
assert retrieved.human_assignee == "developer@example.com"
|
||||||
|
assert retrieved.assigned_agent_id is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueSprintAssociation:
|
||||||
|
"""Tests for Issue sprint association."""
|
||||||
|
|
||||||
|
def test_assign_issue_to_sprint(self, db_session):
|
||||||
|
"""Test assigning an issue to a sprint."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Sprint Assign Project", slug="sprint-assign-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Sprint 1",
|
||||||
|
number=1,
|
||||||
|
start_date=date.today(),
|
||||||
|
end_date=date.today() + timedelta(days=14),
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title="Sprint Issue",
|
||||||
|
sprint_id=sprint.id,
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Issue).filter_by(title="Sprint Issue").first()
|
||||||
|
assert retrieved.sprint_id == sprint.id
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueExternalTracker:
|
||||||
|
"""Tests for Issue external tracker integration."""
|
||||||
|
|
||||||
|
def test_gitea_integration(self, db_session):
|
||||||
|
"""Test Gitea external tracker fields."""
|
||||||
|
project = Project(id=uuid.uuid4(), name="Gitea Project", slug="gitea-project")
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title="Gitea Synced Issue",
|
||||||
|
external_tracker_type="gitea",
|
||||||
|
external_issue_id="abc123xyz",
|
||||||
|
remote_url="https://gitea.example.com/org/repo/issues/42",
|
||||||
|
external_issue_number=42,
|
||||||
|
sync_status=SyncStatus.SYNCED,
|
||||||
|
last_synced_at=now,
|
||||||
|
external_updated_at=now,
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Issue).filter_by(title="Gitea Synced Issue").first()
|
||||||
|
)
|
||||||
|
assert retrieved.external_tracker_type == "gitea"
|
||||||
|
assert retrieved.external_issue_id == "abc123xyz"
|
||||||
|
assert retrieved.external_issue_number == 42
|
||||||
|
assert "/issues/42" in retrieved.remote_url
|
||||||
|
|
||||||
|
def test_github_integration(self, db_session):
|
||||||
|
"""Test GitHub external tracker fields."""
|
||||||
|
project = Project(id=uuid.uuid4(), name="GitHub Project", slug="github-project")
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title="GitHub Synced Issue",
|
||||||
|
external_tracker_type="github",
|
||||||
|
external_issue_id="gh-12345",
|
||||||
|
remote_url="https://github.com/org/repo/issues/100",
|
||||||
|
external_issue_number=100,
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Issue).filter_by(title="GitHub Synced Issue").first()
|
||||||
|
)
|
||||||
|
assert retrieved.external_tracker_type == "github"
|
||||||
|
assert retrieved.external_issue_number == 100
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssueLifecycle:
|
||||||
|
"""Tests for Issue lifecycle operations."""
|
||||||
|
|
||||||
|
def test_close_issue(self, db_session):
|
||||||
|
"""Test closing an issue."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Close Issue Project", slug="close-issue-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title="Issue to Close",
|
||||||
|
status=IssueStatus.OPEN,
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Close the issue
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
issue.status = IssueStatus.CLOSED
|
||||||
|
issue.closed_at = now
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Issue).filter_by(title="Issue to Close").first()
|
||||||
|
assert retrieved.status == IssueStatus.CLOSED
|
||||||
|
assert retrieved.closed_at is not None
|
||||||
|
|
||||||
|
def test_reopen_issue(self, db_session):
|
||||||
|
"""Test reopening a closed issue."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Reopen Issue Project", slug="reopen-issue-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
issue = Issue(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
title="Issue to Reopen",
|
||||||
|
status=IssueStatus.CLOSED,
|
||||||
|
closed_at=now,
|
||||||
|
)
|
||||||
|
db_session.add(issue)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Reopen the issue
|
||||||
|
issue.status = IssueStatus.OPEN
|
||||||
|
issue.closed_at = None
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Issue).filter_by(title="Issue to Reopen").first()
|
||||||
|
assert retrieved.status == IssueStatus.OPEN
|
||||||
|
assert retrieved.closed_at is None
|
||||||
275
backend/tests/models/syndarix/test_project.py
Normal file
275
backend/tests/models/syndarix/test_project.py
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
# tests/models/syndarix/test_project.py
|
||||||
|
"""
|
||||||
|
Unit tests for the Project model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
|
from app.models.syndarix import (
|
||||||
|
AutonomyLevel,
|
||||||
|
Project,
|
||||||
|
ProjectStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectModel:
|
||||||
|
"""Tests for Project model creation and fields."""
|
||||||
|
|
||||||
|
def test_create_project_with_required_fields(self, db_session):
|
||||||
|
"""Test creating a project with only required fields."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Project",
|
||||||
|
slug="test-project",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Project).filter_by(slug="test-project").first()
|
||||||
|
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.name == "Test Project"
|
||||||
|
assert retrieved.slug == "test-project"
|
||||||
|
assert retrieved.autonomy_level == AutonomyLevel.MILESTONE # Default
|
||||||
|
assert retrieved.status == ProjectStatus.ACTIVE # Default
|
||||||
|
assert retrieved.settings == {} # Default empty dict
|
||||||
|
assert retrieved.description is None
|
||||||
|
assert retrieved.owner_id is None
|
||||||
|
|
||||||
|
def test_create_project_with_all_fields(self, db_session):
|
||||||
|
"""Test creating a project with all optional fields."""
|
||||||
|
project_id = uuid.uuid4()
|
||||||
|
owner_id = uuid.uuid4()
|
||||||
|
|
||||||
|
project = Project(
|
||||||
|
id=project_id,
|
||||||
|
name="Full Project",
|
||||||
|
slug="full-project",
|
||||||
|
description="A complete project with all fields",
|
||||||
|
autonomy_level=AutonomyLevel.AUTONOMOUS,
|
||||||
|
status=ProjectStatus.PAUSED,
|
||||||
|
settings={"webhook_url": "https://example.com/webhook"},
|
||||||
|
owner_id=owner_id,
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Project).filter_by(id=project_id).first()
|
||||||
|
|
||||||
|
assert retrieved.name == "Full Project"
|
||||||
|
assert retrieved.slug == "full-project"
|
||||||
|
assert retrieved.description == "A complete project with all fields"
|
||||||
|
assert retrieved.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||||
|
assert retrieved.status == ProjectStatus.PAUSED
|
||||||
|
assert retrieved.settings == {"webhook_url": "https://example.com/webhook"}
|
||||||
|
assert retrieved.owner_id == owner_id
|
||||||
|
|
||||||
|
def test_project_unique_slug_constraint(self, db_session):
|
||||||
|
"""Test that projects cannot have duplicate slugs."""
|
||||||
|
project1 = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Project One",
|
||||||
|
slug="duplicate-slug",
|
||||||
|
)
|
||||||
|
db_session.add(project1)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
project2 = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Project Two",
|
||||||
|
slug="duplicate-slug", # Same slug
|
||||||
|
)
|
||||||
|
db_session.add(project2)
|
||||||
|
|
||||||
|
with pytest.raises(IntegrityError):
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
db_session.rollback()
|
||||||
|
|
||||||
|
def test_project_timestamps(self, db_session):
|
||||||
|
"""Test that timestamps are automatically set."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Timestamp Project",
|
||||||
|
slug="timestamp-project",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Project).filter_by(slug="timestamp-project").first()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(retrieved.created_at, datetime)
|
||||||
|
assert isinstance(retrieved.updated_at, datetime)
|
||||||
|
|
||||||
|
def test_project_update(self, db_session):
|
||||||
|
"""Test updating project fields."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Original Name",
|
||||||
|
slug="original-slug",
|
||||||
|
status=ProjectStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
original_created_at = project.created_at
|
||||||
|
|
||||||
|
# Update fields
|
||||||
|
project.name = "Updated Name"
|
||||||
|
project.status = ProjectStatus.COMPLETED
|
||||||
|
project.settings = {"new_setting": "value"}
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Project).filter_by(slug="original-slug").first()
|
||||||
|
|
||||||
|
assert retrieved.name == "Updated Name"
|
||||||
|
assert retrieved.status == ProjectStatus.COMPLETED
|
||||||
|
assert retrieved.settings == {"new_setting": "value"}
|
||||||
|
assert retrieved.created_at == original_created_at
|
||||||
|
assert retrieved.updated_at > original_created_at
|
||||||
|
|
||||||
|
def test_project_delete(self, db_session):
|
||||||
|
"""Test deleting a project."""
|
||||||
|
project_id = uuid.uuid4()
|
||||||
|
project = Project(
|
||||||
|
id=project_id,
|
||||||
|
name="Delete Me",
|
||||||
|
slug="delete-me",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
db_session.delete(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
deleted = db_session.query(Project).filter_by(id=project_id).first()
|
||||||
|
assert deleted is None
|
||||||
|
|
||||||
|
def test_project_string_representation(self, db_session):
|
||||||
|
"""Test the string representation of a project."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Repr Project",
|
||||||
|
slug="repr-project",
|
||||||
|
status=ProjectStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert str(project) == "<Project Repr Project (repr-project) status=active>"
|
||||||
|
assert repr(project) == "<Project Repr Project (repr-project) status=active>"
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectEnums:
|
||||||
|
"""Tests for Project enum fields."""
|
||||||
|
|
||||||
|
def test_all_autonomy_levels(self, db_session):
|
||||||
|
"""Test that all autonomy levels can be stored."""
|
||||||
|
for level in AutonomyLevel:
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name=f"Project {level.value}",
|
||||||
|
slug=f"project-{level.value}",
|
||||||
|
autonomy_level=level,
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Project)
|
||||||
|
.filter_by(slug=f"project-{level.value}")
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
assert retrieved.autonomy_level == level
|
||||||
|
|
||||||
|
def test_all_project_statuses(self, db_session):
|
||||||
|
"""Test that all project statuses can be stored."""
|
||||||
|
for status in ProjectStatus:
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name=f"Project {status.value}",
|
||||||
|
slug=f"project-status-{status.value}",
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Project)
|
||||||
|
.filter_by(slug=f"project-status-{status.value}")
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
assert retrieved.status == status
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectSettings:
|
||||||
|
"""Tests for Project JSON settings field."""
|
||||||
|
|
||||||
|
def test_complex_json_settings(self, db_session):
|
||||||
|
"""Test storing complex JSON in settings."""
|
||||||
|
complex_settings = {
|
||||||
|
"mcp_servers": ["gitea", "slack", "file-system"],
|
||||||
|
"webhook_urls": {
|
||||||
|
"on_issue_created": "https://example.com/issue",
|
||||||
|
"on_sprint_completed": "https://example.com/sprint",
|
||||||
|
},
|
||||||
|
"notification_settings": {
|
||||||
|
"email": True,
|
||||||
|
"slack_channel": "#syndarix-updates",
|
||||||
|
},
|
||||||
|
"tags": ["important", "client-a"],
|
||||||
|
}
|
||||||
|
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Complex Settings Project",
|
||||||
|
slug="complex-settings",
|
||||||
|
settings=complex_settings,
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Project).filter_by(slug="complex-settings").first()
|
||||||
|
|
||||||
|
assert retrieved.settings == complex_settings
|
||||||
|
assert retrieved.settings["mcp_servers"] == ["gitea", "slack", "file-system"]
|
||||||
|
assert (
|
||||||
|
retrieved.settings["webhook_urls"]["on_issue_created"]
|
||||||
|
== "https://example.com/issue"
|
||||||
|
)
|
||||||
|
assert "important" in retrieved.settings["tags"]
|
||||||
|
|
||||||
|
def test_empty_settings(self, db_session):
|
||||||
|
"""Test that empty settings defaults correctly."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Empty Settings",
|
||||||
|
slug="empty-settings",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Project).filter_by(slug="empty-settings").first()
|
||||||
|
assert retrieved.settings == {}
|
||||||
|
|
||||||
|
def test_update_settings(self, db_session):
|
||||||
|
"""Test updating settings field."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Update Settings",
|
||||||
|
slug="update-settings",
|
||||||
|
settings={"initial": "value"},
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Update settings
|
||||||
|
project.settings = {"updated": "new_value", "additional": "data"}
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Project).filter_by(slug="update-settings").first()
|
||||||
|
assert retrieved.settings == {"updated": "new_value", "additional": "data"}
|
||||||
558
backend/tests/models/syndarix/test_sprint.py
Normal file
558
backend/tests/models/syndarix/test_sprint.py
Normal file
@@ -0,0 +1,558 @@
|
|||||||
|
# tests/models/syndarix/test_sprint.py
|
||||||
|
"""
|
||||||
|
Unit tests for the Sprint model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import date, datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.syndarix import (
|
||||||
|
Project,
|
||||||
|
Sprint,
|
||||||
|
SprintStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintModel:
|
||||||
|
"""Tests for Sprint model creation and fields."""
|
||||||
|
|
||||||
|
def test_create_sprint_with_required_fields(self, db_session):
|
||||||
|
"""Test creating a sprint with only required fields."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Sprint Project",
|
||||||
|
slug="sprint-project",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Sprint 1",
|
||||||
|
number=1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Sprint).filter_by(name="Sprint 1").first()
|
||||||
|
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.name == "Sprint 1"
|
||||||
|
assert retrieved.number == 1
|
||||||
|
assert retrieved.start_date == today
|
||||||
|
assert retrieved.end_date == today + timedelta(days=14)
|
||||||
|
assert retrieved.status == SprintStatus.PLANNED # Default
|
||||||
|
assert retrieved.goal is None
|
||||||
|
assert retrieved.planned_points is None
|
||||||
|
assert retrieved.velocity is None
|
||||||
|
|
||||||
|
def test_create_sprint_with_all_fields(self, db_session):
|
||||||
|
"""Test creating a sprint with all optional fields."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Full Sprint Project",
|
||||||
|
slug="full-sprint-project",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint_id = uuid.uuid4()
|
||||||
|
|
||||||
|
sprint = Sprint(
|
||||||
|
id=sprint_id,
|
||||||
|
project_id=project.id,
|
||||||
|
name="Full Sprint",
|
||||||
|
number=5,
|
||||||
|
goal="Complete all authentication features",
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
planned_points=34,
|
||||||
|
velocity=21,
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Sprint).filter_by(id=sprint_id).first()
|
||||||
|
|
||||||
|
assert retrieved.name == "Full Sprint"
|
||||||
|
assert retrieved.number == 5
|
||||||
|
assert retrieved.goal == "Complete all authentication features"
|
||||||
|
assert retrieved.status == SprintStatus.ACTIVE
|
||||||
|
assert retrieved.planned_points == 34
|
||||||
|
assert retrieved.velocity == 21
|
||||||
|
|
||||||
|
def test_sprint_timestamps(self, db_session):
|
||||||
|
"""Test that timestamps are automatically set."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Timestamp Sprint Project",
|
||||||
|
slug="timestamp-sprint-project",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Timestamp Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
assert isinstance(sprint.created_at, datetime)
|
||||||
|
assert isinstance(sprint.updated_at, datetime)
|
||||||
|
|
||||||
|
def test_sprint_string_representation(self, db_session):
|
||||||
|
"""Test the string representation of a sprint."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Repr Sprint Project", slug="repr-sprint-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Sprint Alpha",
|
||||||
|
number=3,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
|
||||||
|
repr_str = repr(sprint)
|
||||||
|
assert "Sprint Alpha" in repr_str
|
||||||
|
assert "#3" in repr_str
|
||||||
|
assert str(project.id) in repr_str
|
||||||
|
assert "active" in repr_str
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintStatus:
|
||||||
|
"""Tests for Sprint status field."""
|
||||||
|
|
||||||
|
def test_all_sprint_statuses(self, db_session):
|
||||||
|
"""Test that all sprint statuses can be stored."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Status Sprint Project", slug="status-sprint-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
for idx, status in enumerate(SprintStatus):
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name=f"Sprint {status.value}",
|
||||||
|
number=idx + 1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Sprint).filter_by(id=sprint.id).first()
|
||||||
|
assert retrieved.status == status
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintLifecycle:
|
||||||
|
"""Tests for Sprint lifecycle operations."""
|
||||||
|
|
||||||
|
def test_start_sprint(self, db_session):
|
||||||
|
"""Test starting a planned sprint."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Start Sprint Project", slug="start-sprint-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Sprint to Start",
|
||||||
|
number=1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.PLANNED,
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Start the sprint
|
||||||
|
sprint.status = SprintStatus.ACTIVE
|
||||||
|
sprint.planned_points = 21
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Sprint).filter_by(name="Sprint to Start").first()
|
||||||
|
assert retrieved.status == SprintStatus.ACTIVE
|
||||||
|
assert retrieved.planned_points == 21
|
||||||
|
|
||||||
|
def test_complete_sprint(self, db_session):
|
||||||
|
"""Test completing an active sprint."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Complete Sprint Project",
|
||||||
|
slug="complete-sprint-project",
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Sprint to Complete",
|
||||||
|
number=1,
|
||||||
|
start_date=today - timedelta(days=14),
|
||||||
|
end_date=today,
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
planned_points=21,
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Complete the sprint
|
||||||
|
sprint.status = SprintStatus.COMPLETED
|
||||||
|
sprint.velocity = 18
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Sprint).filter_by(name="Sprint to Complete").first()
|
||||||
|
)
|
||||||
|
assert retrieved.status == SprintStatus.COMPLETED
|
||||||
|
assert retrieved.velocity == 18
|
||||||
|
|
||||||
|
def test_cancel_sprint(self, db_session):
|
||||||
|
"""Test cancelling a sprint."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Cancel Sprint Project", slug="cancel-sprint-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Sprint to Cancel",
|
||||||
|
number=1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.ACTIVE,
|
||||||
|
planned_points=21,
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Cancel the sprint
|
||||||
|
sprint.status = SprintStatus.CANCELLED
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Sprint).filter_by(name="Sprint to Cancel").first()
|
||||||
|
assert retrieved.status == SprintStatus.CANCELLED
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintDates:
|
||||||
|
"""Tests for Sprint date fields."""
|
||||||
|
|
||||||
|
def test_sprint_date_range(self, db_session):
|
||||||
|
"""Test storing sprint date range."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Date Range Project", slug="date-range-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
start = date(2024, 1, 1)
|
||||||
|
end = date(2024, 1, 14)
|
||||||
|
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Date Range Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=start,
|
||||||
|
end_date=end,
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Sprint).filter_by(name="Date Range Sprint").first()
|
||||||
|
assert retrieved.start_date == start
|
||||||
|
assert retrieved.end_date == end
|
||||||
|
|
||||||
|
def test_one_day_sprint(self, db_session):
|
||||||
|
"""Test creating a one-day sprint."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="One Day Project", slug="one-day-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="One Day Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today, # Same day
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Sprint).filter_by(name="One Day Sprint").first()
|
||||||
|
assert retrieved.start_date == retrieved.end_date
|
||||||
|
|
||||||
|
def test_long_sprint(self, db_session):
|
||||||
|
"""Test creating a long sprint (e.g., 4 weeks)."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Long Sprint Project", slug="long-sprint-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Long Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=28), # 4 weeks
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Sprint).filter_by(name="Long Sprint").first()
|
||||||
|
delta = retrieved.end_date - retrieved.start_date
|
||||||
|
assert delta.days == 28
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintPoints:
|
||||||
|
"""Tests for Sprint story points fields."""
|
||||||
|
|
||||||
|
def test_sprint_with_zero_points(self, db_session):
|
||||||
|
"""Test sprint with zero planned points."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Zero Points Project", slug="zero-points-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Zero Points Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
planned_points=0,
|
||||||
|
velocity=0,
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Sprint).filter_by(name="Zero Points Sprint").first()
|
||||||
|
)
|
||||||
|
assert retrieved.planned_points == 0
|
||||||
|
assert retrieved.velocity == 0
|
||||||
|
|
||||||
|
def test_sprint_velocity_calculation(self, db_session):
|
||||||
|
"""Test that we can calculate velocity from points."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Velocity Project", slug="velocity-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Velocity Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.COMPLETED,
|
||||||
|
planned_points=21,
|
||||||
|
velocity=18,
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Sprint).filter_by(name="Velocity Sprint").first()
|
||||||
|
|
||||||
|
# Calculate completion ratio from velocity
|
||||||
|
completion_ratio = retrieved.velocity / retrieved.planned_points
|
||||||
|
assert completion_ratio == pytest.approx(18 / 21, rel=0.01)
|
||||||
|
|
||||||
|
def test_sprint_overdelivery(self, db_session):
|
||||||
|
"""Test sprint where completed > planned (stretch goals)."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Overdelivery Project", slug="overdelivery-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Overdelivery Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
status=SprintStatus.COMPLETED,
|
||||||
|
planned_points=20,
|
||||||
|
velocity=25, # Completed more than planned
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Sprint).filter_by(name="Overdelivery Sprint").first()
|
||||||
|
)
|
||||||
|
assert retrieved.velocity > retrieved.planned_points
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintNumber:
|
||||||
|
"""Tests for Sprint number field."""
|
||||||
|
|
||||||
|
def test_sequential_sprint_numbers(self, db_session):
|
||||||
|
"""Test creating sprints with sequential numbers."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Sequential Project", slug="sequential-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
for i in range(1, 6):
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name=f"Sprint {i}",
|
||||||
|
number=i,
|
||||||
|
start_date=today + timedelta(days=(i - 1) * 14),
|
||||||
|
end_date=today + timedelta(days=i * 14 - 1),
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
sprints = (
|
||||||
|
db_session.query(Sprint)
|
||||||
|
.filter_by(project_id=project.id)
|
||||||
|
.order_by(Sprint.number)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
assert len(sprints) == 5
|
||||||
|
for i, sprint in enumerate(sprints, 1):
|
||||||
|
assert sprint.number == i
|
||||||
|
|
||||||
|
def test_large_sprint_number(self, db_session):
|
||||||
|
"""Test sprint with large number (e.g., long-running project)."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Large Number Project", slug="large-number-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Sprint 100",
|
||||||
|
number=100,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = db_session.query(Sprint).filter_by(name="Sprint 100").first()
|
||||||
|
assert retrieved.number == 100
|
||||||
|
|
||||||
|
|
||||||
|
class TestSprintUpdate:
|
||||||
|
"""Tests for Sprint update operations."""
|
||||||
|
|
||||||
|
def test_update_sprint_goal(self, db_session):
|
||||||
|
"""Test updating sprint goal."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Update Goal Project", slug="update-goal-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Update Goal Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
goal="Original goal",
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
original_created_at = sprint.created_at
|
||||||
|
|
||||||
|
sprint.goal = "Updated goal with more detail"
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Sprint).filter_by(name="Update Goal Sprint").first()
|
||||||
|
)
|
||||||
|
assert retrieved.goal == "Updated goal with more detail"
|
||||||
|
assert retrieved.created_at == original_created_at
|
||||||
|
assert retrieved.updated_at > original_created_at
|
||||||
|
|
||||||
|
def test_update_sprint_dates(self, db_session):
|
||||||
|
"""Test updating sprint dates."""
|
||||||
|
project = Project(
|
||||||
|
id=uuid.uuid4(), name="Update Dates Project", slug="update-dates-project"
|
||||||
|
)
|
||||||
|
db_session.add(project)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
today = date.today()
|
||||||
|
sprint = Sprint(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
project_id=project.id,
|
||||||
|
name="Update Dates Sprint",
|
||||||
|
number=1,
|
||||||
|
start_date=today,
|
||||||
|
end_date=today + timedelta(days=14),
|
||||||
|
)
|
||||||
|
db_session.add(sprint)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Extend sprint by a week
|
||||||
|
sprint.end_date = today + timedelta(days=21)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
retrieved = (
|
||||||
|
db_session.query(Sprint).filter_by(name="Update Dates Sprint").first()
|
||||||
|
)
|
||||||
|
delta = retrieved.end_date - retrieved.start_date
|
||||||
|
assert delta.days == 21
|
||||||
2
backend/tests/schemas/syndarix/__init__.py
Normal file
2
backend/tests/schemas/syndarix/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# tests/schemas/syndarix/__init__.py
|
||||||
|
"""Syndarix schema validation tests."""
|
||||||
69
backend/tests/schemas/syndarix/conftest.py
Normal file
69
backend/tests/schemas/syndarix/conftest.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
# tests/schemas/syndarix/conftest.py
|
||||||
|
"""
|
||||||
|
Shared fixtures for Syndarix schema tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_uuid():
|
||||||
|
"""Return a valid UUID for testing."""
|
||||||
|
return uuid.uuid4()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_project_data():
|
||||||
|
"""Return valid project data for schema testing."""
|
||||||
|
return {
|
||||||
|
"name": "Test Project",
|
||||||
|
"slug": "test-project",
|
||||||
|
"description": "A test project",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_agent_type_data():
|
||||||
|
"""Return valid agent type data for schema testing."""
|
||||||
|
return {
|
||||||
|
"name": "Backend Engineer",
|
||||||
|
"slug": "backend-engineer",
|
||||||
|
"personality_prompt": "You are an expert backend engineer.",
|
||||||
|
"primary_model": "claude-opus-4-5-20251101",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_sprint_data(valid_uuid):
|
||||||
|
"""Return valid sprint data for schema testing."""
|
||||||
|
today = date.today()
|
||||||
|
return {
|
||||||
|
"project_id": valid_uuid,
|
||||||
|
"name": "Sprint 1",
|
||||||
|
"number": 1,
|
||||||
|
"start_date": today,
|
||||||
|
"end_date": today + timedelta(days=14),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_issue_data(valid_uuid):
|
||||||
|
"""Return valid issue data for schema testing."""
|
||||||
|
return {
|
||||||
|
"project_id": valid_uuid,
|
||||||
|
"title": "Test Issue",
|
||||||
|
"body": "Issue description",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_agent_instance_data(valid_uuid):
|
||||||
|
"""Return valid agent instance data for schema testing."""
|
||||||
|
return {
|
||||||
|
"agent_type_id": valid_uuid,
|
||||||
|
"project_id": valid_uuid,
|
||||||
|
"name": "TestAgent",
|
||||||
|
}
|
||||||
265
backend/tests/schemas/syndarix/test_agent_instance_schemas.py
Normal file
265
backend/tests/schemas/syndarix/test_agent_instance_schemas.py
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
# tests/schemas/syndarix/test_agent_instance_schemas.py
|
||||||
|
"""
|
||||||
|
Tests for AgentInstance schema validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from app.schemas.syndarix import (
|
||||||
|
AgentInstanceCreate,
|
||||||
|
AgentInstanceUpdate,
|
||||||
|
AgentStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceCreateValidation:
|
||||||
|
"""Tests for AgentInstanceCreate schema validation."""
|
||||||
|
|
||||||
|
def test_valid_agent_instance_create(self, valid_agent_instance_data):
|
||||||
|
"""Test creating agent instance with valid data."""
|
||||||
|
instance = AgentInstanceCreate(**valid_agent_instance_data)
|
||||||
|
|
||||||
|
assert instance.agent_type_id is not None
|
||||||
|
assert instance.project_id is not None
|
||||||
|
|
||||||
|
def test_agent_instance_create_defaults(self, valid_agent_instance_data):
|
||||||
|
"""Test that defaults are applied correctly."""
|
||||||
|
instance = AgentInstanceCreate(**valid_agent_instance_data)
|
||||||
|
|
||||||
|
assert instance.status == AgentStatus.IDLE
|
||||||
|
assert instance.current_task is None
|
||||||
|
assert instance.short_term_memory == {}
|
||||||
|
assert instance.long_term_memory_ref is None
|
||||||
|
assert instance.session_id is None
|
||||||
|
|
||||||
|
def test_agent_instance_create_with_all_fields(self, valid_uuid):
|
||||||
|
"""Test creating agent instance with all optional fields."""
|
||||||
|
instance = AgentInstanceCreate(
|
||||||
|
agent_type_id=valid_uuid,
|
||||||
|
project_id=valid_uuid,
|
||||||
|
name="WorkingAgent",
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
current_task="Processing feature request",
|
||||||
|
short_term_memory={"context": "working"},
|
||||||
|
long_term_memory_ref="project-123/agent-456",
|
||||||
|
session_id="session-abc",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert instance.status == AgentStatus.WORKING
|
||||||
|
assert instance.current_task == "Processing feature request"
|
||||||
|
assert instance.short_term_memory == {"context": "working"}
|
||||||
|
assert instance.long_term_memory_ref == "project-123/agent-456"
|
||||||
|
assert instance.session_id == "session-abc"
|
||||||
|
|
||||||
|
def test_agent_instance_create_agent_type_id_required(self, valid_uuid):
|
||||||
|
"""Test that agent_type_id is required."""
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
AgentInstanceCreate(
|
||||||
|
project_id=valid_uuid,
|
||||||
|
name="TestAgent",
|
||||||
|
)
|
||||||
|
|
||||||
|
errors = exc_info.value.errors()
|
||||||
|
assert any("agent_type_id" in str(e).lower() for e in errors)
|
||||||
|
|
||||||
|
def test_agent_instance_create_project_id_required(self, valid_uuid):
|
||||||
|
"""Test that project_id is required."""
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
AgentInstanceCreate(
|
||||||
|
agent_type_id=valid_uuid,
|
||||||
|
name="TestAgent",
|
||||||
|
)
|
||||||
|
|
||||||
|
errors = exc_info.value.errors()
|
||||||
|
assert any("project_id" in str(e).lower() for e in errors)
|
||||||
|
|
||||||
|
def test_agent_instance_create_name_required(self, valid_uuid):
|
||||||
|
"""Test that name is required."""
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
AgentInstanceCreate(
|
||||||
|
agent_type_id=valid_uuid,
|
||||||
|
project_id=valid_uuid,
|
||||||
|
)
|
||||||
|
|
||||||
|
errors = exc_info.value.errors()
|
||||||
|
assert any("name" in str(e).lower() for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceUpdateValidation:
|
||||||
|
"""Tests for AgentInstanceUpdate schema validation."""
|
||||||
|
|
||||||
|
def test_agent_instance_update_partial(self):
|
||||||
|
"""Test updating only some fields."""
|
||||||
|
update = AgentInstanceUpdate(
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert update.status == AgentStatus.WORKING
|
||||||
|
assert update.current_task is None
|
||||||
|
assert update.short_term_memory is None
|
||||||
|
|
||||||
|
def test_agent_instance_update_all_fields(self):
|
||||||
|
"""Test updating all fields."""
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
update = AgentInstanceUpdate(
|
||||||
|
status=AgentStatus.WORKING,
|
||||||
|
current_task="New task",
|
||||||
|
short_term_memory={"new": "context"},
|
||||||
|
long_term_memory_ref="new-ref",
|
||||||
|
session_id="new-session",
|
||||||
|
last_activity_at=now,
|
||||||
|
tasks_completed=5,
|
||||||
|
tokens_used=10000,
|
||||||
|
cost_incurred=Decimal("1.5000"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert update.status == AgentStatus.WORKING
|
||||||
|
assert update.current_task == "New task"
|
||||||
|
assert update.tasks_completed == 5
|
||||||
|
assert update.tokens_used == 10000
|
||||||
|
assert update.cost_incurred == Decimal("1.5000")
|
||||||
|
|
||||||
|
def test_agent_instance_update_tasks_completed_negative_fails(self):
|
||||||
|
"""Test that negative tasks_completed raises ValidationError."""
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
AgentInstanceUpdate(tasks_completed=-1)
|
||||||
|
|
||||||
|
errors = exc_info.value.errors()
|
||||||
|
assert any("tasks_completed" in str(e).lower() for e in errors)
|
||||||
|
|
||||||
|
def test_agent_instance_update_tokens_used_negative_fails(self):
|
||||||
|
"""Test that negative tokens_used raises ValidationError."""
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
AgentInstanceUpdate(tokens_used=-1)
|
||||||
|
|
||||||
|
errors = exc_info.value.errors()
|
||||||
|
assert any("tokens_used" in str(e).lower() for e in errors)
|
||||||
|
|
||||||
|
def test_agent_instance_update_cost_incurred_negative_fails(self):
|
||||||
|
"""Test that negative cost_incurred raises ValidationError."""
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
AgentInstanceUpdate(cost_incurred=Decimal("-0.01"))
|
||||||
|
|
||||||
|
errors = exc_info.value.errors()
|
||||||
|
assert any("cost_incurred" in str(e).lower() for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentStatusEnum:
|
||||||
|
"""Tests for AgentStatus enum validation."""
|
||||||
|
|
||||||
|
def test_valid_agent_statuses(self, valid_uuid):
|
||||||
|
"""Test all valid agent statuses."""
|
||||||
|
for status in AgentStatus:
|
||||||
|
instance = AgentInstanceCreate(
|
||||||
|
agent_type_id=valid_uuid,
|
||||||
|
project_id=valid_uuid,
|
||||||
|
name=f"Agent{status.value}",
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
assert instance.status == status
|
||||||
|
|
||||||
|
def test_invalid_agent_status(self, valid_uuid):
|
||||||
|
"""Test that invalid agent status raises ValidationError."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AgentInstanceCreate(
|
||||||
|
agent_type_id=valid_uuid,
|
||||||
|
project_id=valid_uuid,
|
||||||
|
name="TestAgent",
|
||||||
|
status="invalid", # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceShortTermMemory:
|
||||||
|
"""Tests for AgentInstance short_term_memory validation."""
|
||||||
|
|
||||||
|
def test_short_term_memory_empty_dict(self, valid_uuid):
|
||||||
|
"""Test that empty short_term_memory is valid."""
|
||||||
|
instance = AgentInstanceCreate(
|
||||||
|
agent_type_id=valid_uuid,
|
||||||
|
project_id=valid_uuid,
|
||||||
|
name="TestAgent",
|
||||||
|
short_term_memory={},
|
||||||
|
)
|
||||||
|
assert instance.short_term_memory == {}
|
||||||
|
|
||||||
|
def test_short_term_memory_complex(self, valid_uuid):
|
||||||
|
"""Test complex short_term_memory structure."""
|
||||||
|
memory = {
|
||||||
|
"conversation_history": [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi there"},
|
||||||
|
],
|
||||||
|
"recent_files": ["file1.py", "file2.py"],
|
||||||
|
"decisions": {"key": "value"},
|
||||||
|
"context_tokens": 1024,
|
||||||
|
}
|
||||||
|
instance = AgentInstanceCreate(
|
||||||
|
agent_type_id=valid_uuid,
|
||||||
|
project_id=valid_uuid,
|
||||||
|
name="MemoryAgent",
|
||||||
|
short_term_memory=memory,
|
||||||
|
)
|
||||||
|
assert instance.short_term_memory == memory
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentInstanceStringFields:
|
||||||
|
"""Tests for AgentInstance string field validation."""
|
||||||
|
|
||||||
|
def test_long_term_memory_ref_max_length(self, valid_uuid):
|
||||||
|
"""Test long_term_memory_ref max length."""
|
||||||
|
long_ref = "a" * 500 # Max length is 500
|
||||||
|
|
||||||
|
instance = AgentInstanceCreate(
|
||||||
|
agent_type_id=valid_uuid,
|
||||||
|
project_id=valid_uuid,
|
||||||
|
name="TestAgent",
|
||||||
|
long_term_memory_ref=long_ref,
|
||||||
|
)
|
||||||
|
assert instance.long_term_memory_ref == long_ref
|
||||||
|
|
||||||
|
def test_long_term_memory_ref_too_long(self, valid_uuid):
|
||||||
|
"""Test that too long long_term_memory_ref raises ValidationError."""
|
||||||
|
too_long = "a" * 501
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
AgentInstanceCreate(
|
||||||
|
agent_type_id=valid_uuid,
|
||||||
|
project_id=valid_uuid,
|
||||||
|
name="TestAgent",
|
||||||
|
long_term_memory_ref=too_long,
|
||||||
|
)
|
||||||
|
|
||||||
|
errors = exc_info.value.errors()
|
||||||
|
assert any("long_term_memory_ref" in str(e).lower() for e in errors)
|
||||||
|
|
||||||
|
def test_session_id_max_length(self, valid_uuid):
|
||||||
|
"""Test session_id max length."""
|
||||||
|
long_session = "a" * 255 # Max length is 255
|
||||||
|
|
||||||
|
instance = AgentInstanceCreate(
|
||||||
|
agent_type_id=valid_uuid,
|
||||||
|
project_id=valid_uuid,
|
||||||
|
name="TestAgent",
|
||||||
|
session_id=long_session,
|
||||||
|
)
|
||||||
|
assert instance.session_id == long_session
|
||||||
|
|
||||||
|
def test_session_id_too_long(self, valid_uuid):
|
||||||
|
"""Test that too long session_id raises ValidationError."""
|
||||||
|
too_long = "a" * 256
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
AgentInstanceCreate(
|
||||||
|
agent_type_id=valid_uuid,
|
||||||
|
project_id=valid_uuid,
|
||||||
|
name="TestAgent",
|
||||||
|
session_id=too_long,
|
||||||
|
)
|
||||||
|
|
||||||
|
errors = exc_info.value.errors()
|
||||||
|
assert any("session_id" in str(e).lower() for e in errors)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user