23 Commits
main ... main

Author SHA1 Message Date
Felipe Cardoso
a94e29d99c chore(frontend): remove unnecessary newline in overrides field of package.json 2026-03-01 19:40:11 +01:00
Felipe Cardoso
81e48c73ca fix(tests): handle missing schemathesis gracefully in API contract tests
- Replaced `pytest.mark.skipif` with `pytest.skip` to better manage scenarios where `schemathesis` is not installed.
- Added a fallback test function to ensure explicit handling for missing dependencies.
2026-03-01 19:32:49 +01:00
Felipe Cardoso
a3f78dc801 refactor(tests): replace crud references with repo across repository test files
- Updated import statements and test logic to align with `repositories` naming changes.
- Adjusted documentation and test names for consistency with the updated naming convention.
- Improved test descriptions to reflect the repository-based structure.
2026-03-01 19:22:16 +01:00
Felipe Cardoso
07309013d7 chore(frontend): update scripts and docs to use bun run test for consistency
- Replaced `bun test` with `bun run test` in all documentation and scripts for uniformity.
- Removed outdated `glob` override in package configurations.
2026-03-01 18:44:48 +01:00
Felipe Cardoso
846fc31190 feat(api): enhance KeyMap and FieldsConfig handling for improved flexibility
- Added support for unmapped fields in `KeyMap` definitions and parsing.
- Updated `buildKeyMap` to allow aliasing keys without transport layer mappings.
- Improved parameter assignment logic to handle optional `in` mappings.
- Enhanced handling of `allowExtra` fields for more concise and robust configurations.
2026-03-01 18:01:34 +01:00
Felipe Cardoso
ff7a67cb58 chore(frontend): migrate from npm to Bun for dependency management and scripts
- Updated README to replace npm commands with Bun equivalents.
- Added `bun.lock` file to track Bun-managed dependencies.
2026-03-01 18:00:43 +01:00
Felipe Cardoso
0760a8284d feat(tests): add comprehensive benchmarks for auth and performance-critical endpoints
- Introduced benchmarks for password hashing, verification, and JWT token operations.
- Added latency tests for `/register`, `/refresh`, `/sessions`, and `/users/me` endpoints.
- Updated `BENCHMARKS.md` with new tests, thresholds, and execution details.
2026-03-01 17:01:44 +01:00
Felipe Cardoso
ce4d0c7b0d feat(backend): enhance performance benchmarking with baseline detection and documentation
- Updated `make benchmark-check` in Makefile to detect and handle missing baselines, creating them if not found.
- Added `.benchmarks` directory to `.gitignore` for local baseline exclusions.
- Linked benchmarking documentation in `ARCHITECTURE.md` and added comprehensive `BENCHMARKS.md` guide.
2026-03-01 16:30:06 +01:00
Felipe Cardoso
4ceb8ad98c feat(backend): add performance benchmarks and API security tests
- Introduced `benchmark`, `benchmark-save`, and `benchmark-check` Makefile targets for performance testing.
- Added API security fuzzing through the `test-api-security` Makefile target, leveraging Schemathesis.
- Updated Dockerfiles to use Alpine for security and CVE mitigation.
- Enhanced security with `scan-image` and `scan-images` targets for Docker image vulnerability scanning via Trivy.
- Integrated `pytest-benchmark` for performance regression detection, with tests for key API endpoints.
- Extended `uv.lock` and `pyproject.toml` to include performance benchmarking dependencies.
2026-03-01 16:16:18 +01:00
Felipe Cardoso
f8aafb250d fix(backend): suppress license-check output in Makefile for cleaner logs
- Redirect pip-licenses output to `/dev/null` to reduce noise during license checks.
- Retain success and compliance messages for clear feedback.
2026-03-01 14:24:22 +01:00
Felipe Cardoso
4385d20ca6 fix(tests): simplify invalid token test logic in test_auth_security.py
- Removed unnecessary try-except block for JWT encoding failures.
- Adjusted test to directly verify `TokenInvalidError` during decoding.
- Clarified comment on HMAC algorithm compatibility (`HS384` vs. `HS256`).
2026-03-01 14:24:17 +01:00
Felipe Cardoso
1a36907f10 refactor(backend): replace python-jose and passlib with PyJWT and bcrypt for security and simplicity
- Migrated JWT token handling from `python-jose` to `PyJWT`, reducing dependencies and improving error clarity.
- Replaced `passlib` bcrypt integration with direct `bcrypt` usage for password hashing.
- Updated `Makefile`, removing unused CVE ignore based on the replaced dependencies.
- Reflected changes in `ARCHITECTURE.md` and adjusted function headers in `auth.py`.
- Cleaned up `uv.lock` and `pyproject.toml` to remove unused dependencies (`ecdsa`, `rsa`, etc.) and add `PyJWT`.
- Refactored tests and services to align with the updated libraries (`PyJWT` error handling, decoding, and validation).
2026-03-01 14:02:04 +01:00
Felipe Cardoso
0553a1fc53 refactor(logging): switch to parameterized logging for improved performance and clarity
- Replaced f-strings with parameterized logging calls across routes, services, and repositories to optimize log message evaluation.
- Improved exception handling by using `logger.exception` where appropriate for automatic traceback logging.
2026-03-01 13:38:15 +01:00
Felipe Cardoso
57e969ed67 chore(backend): extend Makefile with audit, validation, and security targets
- Added `dep-audit`, `license-check`, `audit`, `validate-all`, and `check` targets for security and quality checks.
- Updated `.PHONY` to include new targets.
- Enhanced `help` command documentation with descriptions of the new commands.
- Updated `ARCHITECTURE.md`, `CLAUDE.md`, and `uv.lock` to reflect related changes. Upgraded dependencies where necessary.
2026-03-01 12:03:34 +01:00
Felipe Cardoso
68275b1dd3 refactor(docs): update architecture to reflect repository migration
- Rename CRUD layer to Repository layer throughout architecture documentation.
- Update dependency injection examples to use repository classes.
- Add async SQLAlchemy pattern for Repository methods (`select()` and transactions).
- Replace CRUD references in FEATURE_EXAMPLE.md with Repository-focused implementation details.
- Highlight repository class responsibilities and remove outdated CRUD patterns.
2026-03-01 11:13:51 +01:00
Felipe Cardoso
80d2dc0cb2 fix(backend): clear VIRTUAL_ENV before invoking pyright
Prevents a spurious warning when the shell's VIRTUAL_ENV points to a
different project's venv. Pyright detects the mismatch and warns; clearing
the variable inline forces pyright to resolve the venv from pyrightconfig.json.
2026-02-28 19:48:33 +01:00
Felipe Cardoso
a8aa416ecb refactor(backend): migrate type checking from mypy to pyright
Replace mypy>=1.8.0 with pyright>=1.1.390. Remove all [tool.mypy] and
[tool.pydantic-mypy] sections from pyproject.toml and add
pyrightconfig.json (standard mode, SQLAlchemy false-positive rules
suppressed globally).

Fixes surfaced by pyright:
- Remove unreachable except AuthError clauses in login/login_oauth (same class as AuthenticationError)
- Fix Pydantic v2 list Field: min_items/max_items → min_length/max_length
- Split OAuthProviderConfig TypedDict into required + optional(email_url) inheritance
- Move JWTError/ExpiredSignatureError from lazy try-block imports to module level
- Add timezone-aware guard to UserSession.is_expired to match sibling models
- Fix is_active: bool → bool | None in three organization repo signatures
- Initialize search_filter = None before conditional block (possibly unbound fix)
- Add bool() casts to model is_expired and repo is_active/is_superuser returns
- Restructure except (JWTError, Exception) into separate except clauses
2026-02-28 19:12:40 +01:00
Felipe Cardoso
4c6bf55bcc Refactor(backend): improve formatting in services, repositories & tests
- Consistently format multi-line function headers, exception handling, and repository method calls for readability.
- Reorganize misplaced imports across modules (e.g., services & tests) into proper sorted order.
- Adjust indentation, line breaks, and spacing inconsistencies in tests and migration files.
- Cleanup unnecessary trailing newlines and reorganize `__all__` declarations for consistency.
2026-02-28 18:37:56 +01:00
Felipe Cardoso
98b455fdc3 refactor(backend): enforce route→service→repo layered architecture
- introduce custom repository exception hierarchy (DuplicateEntryError,
  IntegrityConstraintError, InvalidInputError) replacing raw ValueError
- eliminate all direct repository imports and raw SQL from route layer
- add UserService, SessionService, OrganizationService to service layer
- add get_stats/get_org_distribution service methods replacing admin inline SQL
- fix timing side-channel in authenticate_user via dummy bcrypt check
- replace SHA-256 client secret fallback with explicit InvalidClientError
- replace assert with InvalidGrantError in authorization code exchange
- replace N+1 token revocation loops with bulk UPDATE statements
- rename oauth account token fields (drop misleading 'encrypted' suffix)
- add Alembic migration 0003 for token field column rename
- add 45 new service/repository tests; 975 passing, 94% coverage
2026-02-27 09:32:57 +01:00
Felipe Cardoso
0646c96b19 Add semicolons to mockServiceWorker.js for consistent style compliance
- Updated `mockServiceWorker.js` by adding missing semicolons across the file for improved code consistency and adherence to style guidelines.
- Refactored multi-line logical expressions into single-line where applicable, maintaining readability.
2026-01-01 13:21:31 +01:00
Felipe Cardoso
62afb328fe Upgrade dependencies in package-lock.json
- Upgraded various dependencies across `@esbuild`, `@eslint`, `@hey-api`, and `@img` packages to their latest versions.
- Removed unused `json5` dependency under `@babel/core`.
- Ensured integrity hashes are updated to reflect changes.
2026-01-01 13:21:23 +01:00
Felipe Cardoso
b9a746bc16 Refactor component props formatting for consistency in extends usage across UI and documentation files 2026-01-01 13:19:36 +01:00
Felipe Cardoso
de8e18e97d Update GitHub repository URLs across components and tests
- Replaced all occurrences of the previous repository URL (`your-org/fast-next-template`) with the updated repository URL (`cardosofelipe/pragma-stack.git`) in both frontend components and test files.
- Adjusted related test assertions and documentation links accordingly.
2026-01-01 13:15:08 +01:00
654 changed files with 10491 additions and 174296 deletions

View File

@@ -1,22 +1,15 @@
# Common settings # Common settings
PROJECT_NAME=Syndarix PROJECT_NAME=App
VERSION=1.0.0 VERSION=1.0.0
# Database settings # Database settings
POSTGRES_USER=postgres POSTGRES_USER=postgres
POSTGRES_PASSWORD=postgres POSTGRES_PASSWORD=postgres
POSTGRES_DB=syndarix POSTGRES_DB=app
POSTGRES_HOST=db POSTGRES_HOST=db
POSTGRES_PORT=5432 POSTGRES_PORT=5432
DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB} DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
# Redis settings (cache, pub/sub, Celery broker)
REDIS_URL=redis://redis:6379/0
# Celery settings (optional - defaults to REDIS_URL if not set)
# CELERY_BROKER_URL=redis://redis:6379/0
# CELERY_RESULT_BACKEND=redis://redis:6379/0
# Backend settings # Backend settings
BACKEND_PORT=8000 BACKEND_PORT=8000
# CRITICAL: Generate a secure SECRET_KEY for production! # CRITICAL: Generate a secure SECRET_KEY for production!

View File

@@ -1,460 +0,0 @@
# Syndarix CI/CD Pipeline
# Gitea Actions workflow for continuous integration and deployment
#
# Pipeline Structure:
# - lint: Fast feedback (linting and type checking)
# - test: Run test suites (depends on lint)
# - build: Build Docker images (depends on test)
# - deploy: Deploy to production (depends on build, only on main)
name: CI/CD Pipeline
on:
push:
branches:
- main
- dev
- 'feature/**'
pull_request:
branches:
- main
- dev
env:
PYTHON_VERSION: "3.12"
NODE_VERSION: "20"
UV_VERSION: "0.4.x"
jobs:
# ===========================================================================
# LINT JOB - Fast feedback first
# ===========================================================================
lint:
name: Lint & Type Check
runs-on: ubuntu-latest
strategy:
matrix:
component: [backend, frontend]
steps:
- name: Checkout code
uses: actions/checkout@v4
# ----- Backend Linting -----
- name: Set up Python
if: matrix.component == 'backend'
uses: actions/setup-python@v5
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install uv
if: matrix.component == 'backend'
uses: astral-sh/setup-uv@v4
with:
version: ${{ env.UV_VERSION }}
- name: Cache uv dependencies
if: matrix.component == 'backend'
uses: actions/cache@v4
with:
path: |
~/.cache/uv
backend/.venv
key: uv-${{ runner.os }}-${{ hashFiles('backend/uv.lock') }}
restore-keys: |
uv-${{ runner.os }}-
- name: Install backend dependencies
if: matrix.component == 'backend'
working-directory: backend
run: uv sync --extra dev --frozen
- name: Run ruff linting
if: matrix.component == 'backend'
working-directory: backend
run: uv run ruff check app
- name: Run ruff format check
if: matrix.component == 'backend'
working-directory: backend
run: uv run ruff format --check app
- name: Run mypy type checking
if: matrix.component == 'backend'
working-directory: backend
run: uv run mypy app
# ----- Frontend Linting -----
- name: Set up Node.js
if: matrix.component == 'frontend'
uses: actions/setup-node@v4
with:
node-version: ${{ env.NODE_VERSION }}
- name: Cache npm dependencies
if: matrix.component == 'frontend'
uses: actions/cache@v4
with:
path: |
~/.npm
frontend/node_modules
key: npm-${{ runner.os }}-${{ hashFiles('frontend/package-lock.json') }}
restore-keys: |
npm-${{ runner.os }}-
- name: Install frontend dependencies
if: matrix.component == 'frontend'
working-directory: frontend
run: npm ci
- name: Run ESLint
if: matrix.component == 'frontend'
working-directory: frontend
run: npm run lint
- name: Run TypeScript type check
if: matrix.component == 'frontend'
working-directory: frontend
run: npm run type-check
- name: Run Prettier format check
if: matrix.component == 'frontend'
working-directory: frontend
run: npm run format:check
# ===========================================================================
# TEST JOB - Run test suites
# ===========================================================================
test:
name: Test
runs-on: ubuntu-latest
needs: lint
strategy:
matrix:
component: [backend, frontend]
steps:
- name: Checkout code
uses: actions/checkout@v4
# ----- Backend Tests -----
- name: Set up Python
if: matrix.component == 'backend'
uses: actions/setup-python@v5
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install uv
if: matrix.component == 'backend'
uses: astral-sh/setup-uv@v4
with:
version: ${{ env.UV_VERSION }}
- name: Cache uv dependencies
if: matrix.component == 'backend'
uses: actions/cache@v4
with:
path: |
~/.cache/uv
backend/.venv
key: uv-${{ runner.os }}-${{ hashFiles('backend/uv.lock') }}
restore-keys: |
uv-${{ runner.os }}-
- name: Install backend dependencies
if: matrix.component == 'backend'
working-directory: backend
run: uv sync --extra dev --frozen
- name: Run pytest with coverage
if: matrix.component == 'backend'
working-directory: backend
env:
IS_TEST: "True"
run: |
uv run pytest --cov=app --cov-report=xml --cov-report=term-missing --cov-fail-under=90
- name: Upload backend coverage report
if: matrix.component == 'backend'
uses: actions/upload-artifact@v4
with:
name: backend-coverage
path: backend/coverage.xml
retention-days: 7
# ----- Frontend Tests -----
- name: Set up Node.js
if: matrix.component == 'frontend'
uses: actions/setup-node@v4
with:
node-version: ${{ env.NODE_VERSION }}
- name: Cache npm dependencies
if: matrix.component == 'frontend'
uses: actions/cache@v4
with:
path: |
~/.npm
frontend/node_modules
key: npm-${{ runner.os }}-${{ hashFiles('frontend/package-lock.json') }}
restore-keys: |
npm-${{ runner.os }}-
- name: Install frontend dependencies
if: matrix.component == 'frontend'
working-directory: frontend
run: npm ci
- name: Run Jest unit tests
if: matrix.component == 'frontend'
working-directory: frontend
run: npm test -- --coverage --passWithNoTests
- name: Upload frontend coverage report
if: matrix.component == 'frontend'
uses: actions/upload-artifact@v4
with:
name: frontend-coverage
path: frontend/coverage/
retention-days: 7
# ===========================================================================
# BUILD JOB - Build Docker images
# ===========================================================================
build:
name: Build
runs-on: ubuntu-latest
needs: test
strategy:
matrix:
component: [backend, frontend]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Cache Docker layers
uses: actions/cache@v4
with:
path: /tmp/.buildx-cache
key: docker-${{ matrix.component }}-${{ github.sha }}
restore-keys: |
docker-${{ matrix.component }}-
- name: Build backend Docker image
if: matrix.component == 'backend'
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
target: production
push: false
tags: syndarix-backend:${{ github.sha }}
cache-from: type=local,src=/tmp/.buildx-cache
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
- name: Build frontend Docker image
if: matrix.component == 'frontend'
uses: docker/build-push-action@v5
with:
context: ./frontend
file: ./frontend/Dockerfile
target: runner
push: false
tags: syndarix-frontend:${{ github.sha }}
build-args: |
NEXT_PUBLIC_API_URL=http://localhost:8000
cache-from: type=local,src=/tmp/.buildx-cache
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
# Prevent cache from growing indefinitely
- name: Move cache
run: |
rm -rf /tmp/.buildx-cache
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
# ===========================================================================
# DEPLOY JOB - Deploy to production (only on main branch)
# ===========================================================================
deploy:
name: Deploy
runs-on: ubuntu-latest
needs: build
if: github.ref == 'refs/heads/main' && github.event_name == 'push'
environment: production
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Deploy notification
run: |
echo "Deployment to production would happen here"
echo "Branch: ${{ github.ref }}"
echo "Commit: ${{ github.sha }}"
echo "Actor: ${{ github.actor }}"
# TODO: Add actual deployment steps when infrastructure is ready
# Options:
# - SSH to production server and run docker-compose pull && docker-compose up -d
# - Use Kubernetes deployment
# - Use cloud provider deployment (AWS ECS, GCP Cloud Run, etc.)
# - Trigger webhook to deployment orchestrator
# ===========================================================================
# SECURITY SCAN JOB - Run on main and dev branches
# ===========================================================================
security:
name: Security Scan
runs-on: ubuntu-latest
needs: lint
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev'
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: ${{ env.UV_VERSION }}
- name: Install backend dependencies
working-directory: backend
run: uv sync --extra dev --frozen
- name: Run Bandit security scan (via ruff)
working-directory: backend
run: |
# Ruff includes flake8-bandit (S rules) for security scanning
# Run with explicit security rules only
uv run ruff check app --select=S --ignore=S101,S104,S105,S106,S603,S607
- name: Run pip-audit for dependency vulnerabilities
working-directory: backend
run: |
# pip-audit checks for known vulnerabilities in Python dependencies
uv run pip-audit --require-hashes --disable-pip -r <(uv pip compile pyproject.toml) || true
# Note: Using || true temporarily while setting up proper remediation
- name: Check for secrets in code
run: |
# Basic check for common secret patterns
# In production, use tools like gitleaks or trufflehog
echo "Checking for potential hardcoded secrets..."
! grep -rn --include="*.py" --include="*.ts" --include="*.tsx" --include="*.js" \
-E "(api_key|apikey|secret_key|secretkey|password|passwd|token)\s*=\s*['\"][^'\"]{8,}['\"]" \
backend/app frontend/src || echo "No obvious secrets found"
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: ${{ env.NODE_VERSION }}
- name: Install frontend dependencies
working-directory: frontend
run: npm ci
- name: Run npm audit
working-directory: frontend
run: |
npm audit --audit-level=high || true
# Note: Using || true to not fail on moderate vulnerabilities
# In production, consider stricter settings
# ===========================================================================
# E2E TEST JOB - Run end-to-end tests with Playwright
# ===========================================================================
e2e-tests:
name: E2E Tests
runs-on: ubuntu-latest
needs: [lint, test]
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev' || github.event_name == 'pull_request'
services:
postgres:
image: pgvector/pgvector:pg17
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: syndarix_test
ports:
- 5432:5432
options: >-
--health-cmd "pg_isready -U postgres"
--health-interval 10s
--health-timeout 5s
--health-retries 5
redis:
image: redis:7-alpine
ports:
- 6379:6379
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: ${{ env.UV_VERSION }}
- name: Install backend dependencies
working-directory: backend
run: uv sync --extra dev --frozen
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: ${{ env.NODE_VERSION }}
- name: Install frontend dependencies
working-directory: frontend
run: npm ci
- name: Install Playwright browsers
working-directory: frontend
run: npx playwright install --with-deps chromium
- name: Start backend server
working-directory: backend
env:
DATABASE_URL: postgresql://postgres:postgres@localhost:5432/syndarix_test
REDIS_URL: redis://localhost:6379/0
SECRET_KEY: test-secret-key-for-e2e-tests-only
ENVIRONMENT: test
IS_TEST: "True"
run: |
# Run migrations
uv run python -c "from app.database import create_tables; import asyncio; asyncio.run(create_tables())" || true
# Start backend in background
uv run uvicorn app.main:app --host 0.0.0.0 --port 8000 &
# Wait for backend to be ready
sleep 10
- name: Run Playwright E2E tests
working-directory: frontend
env:
NEXT_PUBLIC_API_URL: http://localhost:8000
run: |
npm run build
npm run test:e2e -- --project=chromium
- name: Upload Playwright report
uses: actions/upload-artifact@v4
if: always()
with:
name: playwright-report
path: frontend/playwright-report/
retention-days: 7

View File

@@ -1,61 +0,0 @@
#!/bin/bash
# Pre-commit hook to enforce validation before commits on protected branches
# Install: git config core.hooksPath .githooks
set -e
# Get the current branch name
BRANCH=$(git rev-parse --abbrev-ref HEAD)
# Protected branches that require validation
PROTECTED_BRANCHES="main dev"
# Check if we're on a protected branch
is_protected() {
for branch in $PROTECTED_BRANCHES; do
if [ "$BRANCH" = "$branch" ]; then
return 0
fi
done
return 1
}
if is_protected; then
echo "🔒 Committing to protected branch '$BRANCH' - running validation..."
# Check if we have backend changes
if git diff --cached --name-only | grep -q "^backend/"; then
echo "📦 Backend changes detected - running make validate..."
cd backend
if ! make validate; then
echo ""
echo "❌ Backend validation failed!"
echo " Please fix the issues and try again."
echo " Run 'cd backend && make validate' to see errors."
exit 1
fi
cd ..
echo "✅ Backend validation passed!"
fi
# Check if we have frontend changes
if git diff --cached --name-only | grep -q "^frontend/"; then
echo "🎨 Frontend changes detected - running npm run validate..."
cd frontend
if ! npm run validate 2>/dev/null; then
echo ""
echo "❌ Frontend validation failed!"
echo " Please fix the issues and try again."
echo " Run 'cd frontend && npm run validate' to see errors."
exit 1
fi
cd ..
echo "✅ Frontend validation passed!"
fi
echo "🎉 All validations passed! Proceeding with commit..."
else
echo "📝 Committing to feature branch '$BRANCH' - skipping validation (run manually if needed)"
fi
exit 0

View File

@@ -41,7 +41,7 @@ To enable CI/CD workflows:
- Runs on: Push to main/develop, PRs affecting frontend code - Runs on: Push to main/develop, PRs affecting frontend code
- Tests: Frontend unit tests (Jest) - Tests: Frontend unit tests (Jest)
- Coverage: Uploads to Codecov - Coverage: Uploads to Codecov
- Fast: Uses npm cache - Fast: Uses bun cache
### `e2e-tests.yml` ### `e2e-tests.yml`
- Runs on: All pushes and PRs - Runs on: All pushes and PRs

2
.gitignore vendored
View File

@@ -187,7 +187,7 @@ coverage.xml
.hypothesis/ .hypothesis/
.pytest_cache/ .pytest_cache/
cover/ cover/
backend/.benchmarks
# Translations # Translations
*.mo *.mo
*.pot *.pot

View File

@@ -13,10 +13,10 @@ uv run uvicorn app.main:app --reload # Start dev server
# Frontend (Node.js) # Frontend (Node.js)
cd frontend cd frontend
npm install # Install dependencies bun install # Install dependencies
npm run dev # Start dev server bun run dev # Start dev server
npm run generate:api # Generate API client from OpenAPI bun run generate:api # Generate API client from OpenAPI
npm run test:e2e # Run E2E tests bun run test:e2e # Run E2E tests
``` ```
**Access points:** **Access points:**
@@ -37,7 +37,7 @@ Default superuser (change in production):
│ ├── app/ │ ├── app/
│ │ ├── api/ # API routes (auth, users, organizations, admin) │ │ ├── api/ # API routes (auth, users, organizations, admin)
│ │ ├── core/ # Core functionality (auth, config, database) │ │ ├── core/ # Core functionality (auth, config, database)
│ │ ├── crud/ # Database CRUD operations │ │ ├── repositories/ # Repository pattern (database operations)
│ │ ├── models/ # SQLAlchemy ORM models │ │ ├── models/ # SQLAlchemy ORM models
│ │ ├── schemas/ # Pydantic request/response schemas │ │ ├── schemas/ # Pydantic request/response schemas
│ │ ├── services/ # Business logic layer │ │ ├── services/ # Business logic layer
@@ -113,7 +113,7 @@ OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in
### Database Pattern ### Database Pattern
- **Async SQLAlchemy 2.0** with PostgreSQL - **Async SQLAlchemy 2.0** with PostgreSQL
- **Connection pooling**: 20 base connections, 50 max overflow - **Connection pooling**: 20 base connections, 50 max overflow
- **CRUD base class**: `crud/base.py` with common operations - **Repository base class**: `repositories/base.py` with common operations
- **Migrations**: Alembic with helper script `migrate.py` - **Migrations**: Alembic with helper script `migrate.py`
- `python migrate.py auto "message"` - Generate and apply - `python migrate.py auto "message"` - Generate and apply
- `python migrate.py list` - View history - `python migrate.py list` - View history
@@ -121,7 +121,7 @@ OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in
### Frontend State Management ### Frontend State Management
- **Zustand stores**: Lightweight state management - **Zustand stores**: Lightweight state management
- **TanStack Query**: API data fetching/caching - **TanStack Query**: API data fetching/caching
- **Auto-generated client**: From OpenAPI spec via `npm run generate:api` - **Auto-generated client**: From OpenAPI spec via `bun run generate:api`
- **Dependency Injection**: ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly - **Dependency Injection**: ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly
### Internationalization (i18n) ### Internationalization (i18n)
@@ -165,21 +165,25 @@ Permission dependencies in `api/dependencies/permissions.py`:
**Frontend Unit Tests (Jest):** **Frontend Unit Tests (Jest):**
- 97% coverage - 97% coverage
- Component, hook, and utility testing - Component, hook, and utility testing
- Run: `npm test` - Run: `bun run test`
- Coverage: `npm run test:coverage` - Coverage: `bun run test:coverage`
**Frontend E2E Tests (Playwright):** **Frontend E2E Tests (Playwright):**
- 56 passing, 1 skipped (zero flaky tests) - 56 passing, 1 skipped (zero flaky tests)
- Complete user flows (auth, navigation, settings) - Complete user flows (auth, navigation, settings)
- Run: `npm run test:e2e` - Run: `bun run test:e2e`
- UI mode: `npm run test:e2e:ui` - UI mode: `bun run test:e2e:ui`
### Development Tooling ### Development Tooling
**Backend:** **Backend:**
- **uv**: Modern Python package manager (10-100x faster than pip) - **uv**: Modern Python package manager (10-100x faster than pip)
- **Ruff**: All-in-one linting/formatting (replaces Black, Flake8, isort) - **Ruff**: All-in-one linting/formatting (replaces Black, Flake8, isort)
- **mypy**: Type checking with Pydantic plugin - **Pyright**: Static type checking (strict mode)
- **pip-audit**: Dependency vulnerability scanning (OSV database)
- **detect-secrets**: Hardcoded secrets detection
- **pip-licenses**: License compliance checking
- **pre-commit**: Git hook framework (Ruff, detect-secrets, standard checks)
- **Makefile**: `make help` for all commands - **Makefile**: `make help` for all commands
**Frontend:** **Frontend:**
@@ -218,11 +222,11 @@ NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
### Adding a New API Endpoint ### Adding a New API Endpoint
1. **Define schema** in `backend/app/schemas/` 1. **Define schema** in `backend/app/schemas/`
2. **Create CRUD operations** in `backend/app/crud/` 2. **Create repository** in `backend/app/repositories/`
3. **Implement route** in `backend/app/api/routes/` 3. **Implement route** in `backend/app/api/routes/`
4. **Register router** in `backend/app/api/main.py` 4. **Register router** in `backend/app/api/main.py`
5. **Write tests** in `backend/tests/api/` 5. **Write tests** in `backend/tests/api/`
6. **Generate frontend client**: `npm run generate:api` 6. **Generate frontend client**: `bun run generate:api`
### Database Migrations ### Database Migrations
@@ -239,7 +243,7 @@ python migrate.py auto "description" # Generate + apply
2. **Follow design system** (see `frontend/docs/design-system/`) 2. **Follow design system** (see `frontend/docs/design-system/`)
3. **Use dependency injection** for auth (`useAuth()` not `useAuthStore`) 3. **Use dependency injection** for auth (`useAuth()` not `useAuthStore`)
4. **Write tests** in `frontend/tests/` or `__tests__/` 4. **Write tests** in `frontend/tests/` or `__tests__/`
5. **Run type check**: `npm run type-check` 5. **Run type check**: `bun run type-check`
## Security Features ## Security Features
@@ -249,6 +253,10 @@ python migrate.py auto "description" # Generate + apply
- **CSRF protection**: Built into FastAPI - **CSRF protection**: Built into FastAPI
- **Session revocation**: Database-backed session tracking - **Session revocation**: Database-backed session tracking
- **Comprehensive security tests**: JWT algorithm attacks, session hijacking, privilege escalation - **Comprehensive security tests**: JWT algorithm attacks, session hijacking, privilege escalation
- **Dependency vulnerability scanning**: `make dep-audit` (pip-audit against OSV database)
- **License compliance**: `make license-check` (blocks GPL-3.0/AGPL)
- **Secrets detection**: Pre-commit hook blocks hardcoded secrets
- **Unified security pipeline**: `make audit` (all security checks), `make check` (quality + security + tests)
## Docker Deployment ## Docker Deployment
@@ -281,7 +289,7 @@ docker-compose exec backend python -c "from app.init_db import init_db; import a
- Authentication system (JWT with refresh tokens, OAuth/social login) - Authentication system (JWT with refresh tokens, OAuth/social login)
- **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server - **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server
- Session management (device tracking, revocation) - Session management (device tracking, revocation)
- User management (CRUD, password change) - User management (full lifecycle, password change)
- Organization system (multi-tenant with RBAC) - Organization system (multi-tenant with RBAC)
- Admin panel (user/org management, bulk operations) - Admin panel (user/org management, bulk operations)
- **Internationalization (i18n)** with English and Italian - **Internationalization (i18n)** with English and Italian

357
CLAUDE.md
View File

@@ -1,204 +1,253 @@
# CLAUDE.md # CLAUDE.md
Claude Code context for **Syndarix** - AI-Powered Software Consulting Agency. Claude Code context for FastAPI + Next.js Full-Stack Template.
**Built on PragmaStack.** See [AGENTS.md](./AGENTS.md) for base template context. **See [AGENTS.md](./AGENTS.md) for project context, architecture, and development commands.**
---
## Syndarix Project Context
### Vision
Syndarix is an autonomous platform that orchestrates specialized AI agents to deliver complete software solutions with minimal human intervention. It acts as a virtual consulting agency with AI agents playing roles like Product Owner, Architect, Engineers, QA, etc.
### Repository
- **URL:** https://gitea.pragmazest.com/cardosofelipe/syndarix
- **Issue Tracker:** Gitea Issues (primary)
- **CI/CD:** Gitea Actions
### Core Concepts
**Agent Types & Instances:**
- Agent Type = Template (base model, failover, expertise, personality)
- Agent Instance = Spawned from type, assigned to project
- Multiple instances of same type can work together
**Project Workflow:**
1. Requirements discovery with Product Owner agent
2. Architecture spike (PO + BA + Architect brainstorm)
3. Implementation planning and backlog creation
4. Autonomous sprint execution with checkpoints
5. Demo and client feedback
**Autonomy Levels:**
- `FULL_CONTROL`: Approve every action
- `MILESTONE`: Approve sprint boundaries
- `AUTONOMOUS`: Only major decisions
**MCP-First Architecture:**
All integrations via Model Context Protocol servers with explicit scoping:
```python
# All tools take project_id for scoping
search_knowledge(project_id="proj-123", query="auth flow")
create_issue(project_id="proj-123", title="Add login")
```
### Directory Structure
```
docs/
├── development/ # Workflow and coding standards
├── requirements/ # Requirements documents
├── architecture/ # Architecture documentation
├── adrs/ # Architecture Decision Records
└── spikes/ # Spike research documents
```
### Current Phase
**Backlog Population** - Creating detailed issues for Phase 0-1 implementation.
---
## Development Standards
**CRITICAL: These rules are mandatory. See linked docs for full details.**
### Quick Reference
| Topic | Documentation |
|-------|---------------|
| **Workflow & Branching** | [docs/development/WORKFLOW.md](./docs/development/WORKFLOW.md) |
| **Coding Standards** | [docs/development/CODING_STANDARDS.md](./docs/development/CODING_STANDARDS.md) |
| **Design System** | [frontend/docs/design-system/](./frontend/docs/design-system/) |
| **Backend E2E Testing** | [backend/docs/E2E_TESTING.md](./backend/docs/E2E_TESTING.md) |
| **Demo Mode** | [frontend/docs/DEMO_MODE.md](./frontend/docs/DEMO_MODE.md) |
### Essential Rules Summary
1. **Issue-Driven Development**: Every piece of work MUST have an issue first
2. **Branch per Feature**: `feature/<issue-number>-<description>`, single branch for design+implementation
3. **Testing Required**: All code must be tested, aim for >90% coverage
4. **Code Review**: Must pass multi-agent review before merge
5. **No Direct Commits**: Never commit directly to `main` or `dev`
6. **Stack Verification**: ALWAYS run the full stack before considering work done (see below)
### CRITICAL: Stack Verification Before Merge
**This is NON-NEGOTIABLE. A feature with 100% test coverage that crashes on startup is WORTHLESS.**
Before considering ANY issue complete:
```bash
# 1. Start the dev stack
make dev
# 2. Wait for backend to be healthy, check logs
docker compose -f docker-compose.dev.yml logs backend --tail=100
# 3. Start frontend
cd frontend && npm run dev
# 4. Verify both are running without errors
```
**The issue is NOT done if:**
- Backend crashes on startup (import errors, missing dependencies)
- Frontend fails to compile or render
- Health checks fail
- Any error appears in logs
**Why this matters:**
- Tests run in isolation and may pass despite broken imports
- Docker builds cache layers and may hide dependency issues
- A single `ModuleNotFoundError` renders all test coverage meaningless
### Common Commands
```bash
# Backend
IS_TEST=True uv run pytest # Run tests
uv run ruff check src/ # Lint
uv run mypy src/ # Type check
python migrate.py auto "message" # Database migration
# Frontend
npm test # Unit tests
npm run lint # Lint
npm run type-check # Type check
npm run generate:api # Regenerate API client
```
---
## Claude Code-Specific Guidance ## Claude Code-Specific Guidance
### Critical User Preferences ### Critical User Preferences
**File Operations:** #### File Operations - NEVER Use Heredoc/Cat Append
- ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF` **ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF` commands.**
- Never use heredoc - it triggers manual approval dialogs
**Work Style:** This triggers manual approval dialogs and disrupts workflow.
```bash
# WRONG ❌
cat >> file.txt << EOF
content
EOF
# CORRECT ✅ - Use Read, then Write tools
```
#### Work Style
- User prefers autonomous operation without frequent interruptions - User prefers autonomous operation without frequent interruptions
- Ask for batch permissions upfront for long work sessions - Ask for batch permissions upfront for long work sessions
- Work independently, document decisions clearly - Work independently, document decisions clearly
- Only use emojis if the user explicitly requests it - Only use emojis if the user explicitly requests it
### Critical Pattern: Auth Store DI ### When Working with This Stack
**Dependency Management:**
- Backend uses **uv** (modern Python package manager), not pip
- Always use `uv run` prefix: `IS_TEST=True uv run pytest`
- Or use Makefile commands: `make test`, `make install-dev`
- Add dependencies: `uv add <package>` or `uv add --dev <package>`
**Database Migrations:**
- Use the `migrate.py` helper script, not Alembic directly
- Generate + apply: `python migrate.py auto "message"`
- Never commit migrations without testing them first
- Check current state: `python migrate.py current`
**Frontend API Client Generation:**
- Run `bun run generate:api` after backend schema changes
- Client is auto-generated from OpenAPI spec
- Located in `frontend/src/lib/api/generated/`
- NEVER manually edit generated files
**Testing Commands:**
- Backend unit/integration: `IS_TEST=True uv run pytest` (always prefix with `IS_TEST=True`)
- Backend E2E (requires Docker): `make test-e2e`
- Frontend unit: `bun run test`
- Frontend E2E: `bun run test:e2e`
- Use `make test` or `make test-cov` in backend for convenience
**Security & Quality Commands (Backend):**
- `make validate` — lint + format + type checks
- `make audit` — dependency vulnerabilities + license compliance
- `make validate-all` — quality + security checks
- `make check`**full pipeline**: quality + security + tests
**Backend E2E Testing (requires Docker):**
- Install deps: `make install-e2e`
- Run all E2E tests: `make test-e2e`
- Run schema tests only: `make test-e2e-schema`
- Run all tests: `make test-all` (unit + E2E)
- Uses Testcontainers (real PostgreSQL) + Schemathesis (OpenAPI contract testing)
- Markers: `@pytest.mark.e2e`, `@pytest.mark.postgres`, `@pytest.mark.schemathesis`
- See: `backend/docs/E2E_TESTING.md` for complete guide
### 🔴 CRITICAL: Auth Store Dependency Injection Pattern
**ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly!** **ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly!**
```typescript ```typescript
// ❌ WRONG // ❌ WRONG - Bypasses dependency injection
import { useAuthStore } from '@/lib/stores/authStore'; import { useAuthStore } from '@/lib/stores/authStore';
const { user, isAuthenticated } = useAuthStore();
// ✅ CORRECT // ✅ CORRECT - Uses dependency injection
import { useAuth } from '@/lib/auth/AuthContext'; import { useAuth } from '@/lib/auth/AuthContext';
const { user, isAuthenticated } = useAuth();
``` ```
See [CODING_STANDARDS.md](./docs/development/CODING_STANDARDS.md#auth-store-dependency-injection) for details. **Why This Matters:**
- E2E tests inject mock stores via `window.__TEST_AUTH_STORE__`
- Unit tests inject via `<AuthProvider store={mockStore}>`
- Direct `useAuthStore` imports bypass this injection → **tests fail**
- ESLint will catch violations (added Nov 2025)
**Exceptions:**
1. `AuthContext.tsx` - DI boundary, legitimately needs real store
2. `client.ts` - Non-React context, uses dynamic import + `__TEST_AUTH_STORE__` check
### E2E Test Best Practices
When writing or fixing Playwright tests:
**Navigation Pattern:**
```typescript
// ✅ CORRECT - Use Promise.all for Next.js Link clicks
await Promise.all([
page.waitForURL('/target', { timeout: 10000 }),
link.click()
]);
```
**Selectors:**
- Use ID-based selectors for validation errors: `#email-error`
- Error IDs use dashes not underscores: `#new-password-error`
- Target `.border-destructive[role="alert"]` to avoid Next.js route announcer conflicts
- Avoid generic `[role="alert"]` which matches multiple elements
**URL Assertions:**
```typescript
// ✅ Use regex to handle query params
await expect(page).toHaveURL(/\/auth\/login/);
// ❌ Don't use exact strings (fails with query params)
await expect(page).toHaveURL('/auth/login');
```
**Configuration:**
- Uses 12 workers in non-CI mode (`playwright.config.ts`)
- Reduces to 2 workers in CI for stability
- Tests are designed to be non-flaky with proper waits
### Important Implementation Details
**Authentication Testing:**
- Backend fixtures in `tests/conftest.py`:
- `async_test_db`: Fresh SQLite per test
- `async_test_user` / `async_test_superuser`: Pre-created users
- `user_token` / `superuser_token`: Access tokens for API calls
- Always use `@pytest.mark.asyncio` for async tests
- Use `@pytest_asyncio.fixture` for async fixtures
**Database Testing:**
```python
# Mock database exceptions correctly
from unittest.mock import patch, AsyncMock
async def mock_commit():
raise OperationalError("Connection lost", {}, Exception())
with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with pytest.raises(OperationalError):
await repo_method(session, obj_in=data)
mock_rollback.assert_called_once()
```
**Frontend Component Development:**
- Follow design system docs in `frontend/docs/design-system/`
- Read `08-ai-guidelines.md` for AI code generation rules
- Use parent-controlled spacing (see `04-spacing-philosophy.md`)
- WCAG AA compliance required (see `07-accessibility.md`)
**Security Considerations:**
- Backend has comprehensive security tests (JWT attacks, session hijacking)
- Never skip security headers in production
- Rate limiting is configured in route decorators: `@limiter.limit("10/minute")`
- Session revocation is database-backed, not just JWT expiry
- Run `make audit` to check for dependency vulnerabilities and license compliance
- Run `make check` for the full pipeline: quality + security + tests
- Pre-commit hooks enforce Ruff lint/format and detect-secrets on every commit
- Setup hooks: `cd backend && uv run pre-commit install`
### Common Workflows Guidance
**When Adding a New Feature:**
1. Start with backend schema and repository
2. Implement API route with proper authorization
3. Write backend tests (aim for >90% coverage)
4. Generate frontend API client: `bun run generate:api`
5. Implement frontend components
6. Write frontend unit tests
7. Add E2E tests for critical flows
8. Update relevant documentation
**When Fixing Tests:**
- Backend: Check test database isolation and async fixture usage
- Frontend unit: Verify mocking of `useAuth()` not `useAuthStore`
- E2E: Use `Promise.all()` pattern and regex URL assertions
**When Debugging:**
- Backend: Check `IS_TEST=True` environment variable is set
- Frontend: Run `bun run type-check` first
- E2E: Use `bun run test:e2e:debug` for step-by-step debugging
- Check logs: Backend has detailed error logging
**Demo Mode (Frontend-Only Showcase):**
- Enable: `echo "NEXT_PUBLIC_DEMO_MODE=true" > frontend/.env.local`
- Uses MSW (Mock Service Worker) to intercept API calls in browser
- Zero backend required - perfect for Vercel deployments
- **Fully Automated**: MSW handlers auto-generated from OpenAPI spec
- Run `bun run generate:api` → updates both API client AND MSW handlers
- No manual synchronization needed!
- Demo credentials (any password ≥8 chars works):
- User: `demo@example.com` / `DemoPass123`
- Admin: `admin@example.com` / `AdminPass123`
- **Safe**: MSW never runs during tests (Jest or Playwright)
- **Coverage**: Mock files excluded from linting and coverage
- **Documentation**: `frontend/docs/DEMO_MODE.md` for complete guide
### Tool Usage Preferences ### Tool Usage Preferences
**Prefer specialized tools over bash:** **Prefer specialized tools over bash:**
- Use Read/Write/Edit tools for file operations - Use Read/Write/Edit tools for file operations
- Never use `cat`, `echo >`, or heredoc for file manipulation
- Use Task tool with `subagent_type=Explore` for codebase exploration - Use Task tool with `subagent_type=Explore` for codebase exploration
- Use Grep tool for code search, not bash `grep` - Use Grep tool for code search, not bash `grep`
**Parallel tool calls for:** **When to use parallel tool calls:**
- Independent git commands - Independent git commands: `git status`, `git diff`, `git log`
- Reading multiple unrelated files - Reading multiple unrelated files
- Running multiple test suites - Running multiple test suites simultaneously
- Independent validation steps - Independent validation steps
--- ## Custom Skills
## Key Extensions (from PragmaStack base) No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill.
- Celery + Redis for agent job queue **Potential skill ideas for this project:**
- WebSocket/SSE for real-time updates - API endpoint generator workflow (schema → repository → route → tests → frontend client)
- pgvector for RAG knowledge base - Component generator with design system compliance
- MCP server integration layer - Database migration troubleshooting helper
- Test coverage analyzer and improvement suggester
--- - E2E test generator for new features
## Additional Resources ## Additional Resources
**Documentation:** **Comprehensive Documentation:**
- [AGENTS.md](./AGENTS.md) - Framework-agnostic AI assistant context - [AGENTS.md](./AGENTS.md) - Framework-agnostic AI assistant context
- [README.md](./README.md) - User-facing project overview - [README.md](./README.md) - User-facing project overview
- [docs/development/](./docs/development/) - Development workflow and standards - `backend/docs/` - Backend architecture, coding standards, common pitfalls
- [backend/docs/](./backend/docs/) - Backend architecture and guides - `frontend/docs/design-system/` - Complete design system guide
- [frontend/docs/design-system/](./frontend/docs/design-system/) - Complete design system
**API Documentation (when running):** **API Documentation (when running):**
- Swagger UI: http://localhost:8000/docs - Swagger UI: http://localhost:8000/docs
- ReDoc: http://localhost:8000/redoc - ReDoc: http://localhost:8000/redoc
- OpenAPI JSON: http://localhost:8000/api/v1/openapi.json - OpenAPI JSON: http://localhost:8000/api/v1/openapi.json
**Testing Documentation:**
- Backend tests: `backend/tests/` (97% coverage)
- Frontend E2E: `frontend/e2e/README.md`
- Design system: `frontend/docs/design-system/08-ai-guidelines.md`
--- ---
**For project architecture, development commands, and general context, see [AGENTS.md](./AGENTS.md).** **For project architecture, development commands, and general context, see [AGENTS.md](./AGENTS.md).**

View File

@@ -91,7 +91,10 @@ Ready to write some code? Awesome!
cd backend cd backend
# Install dependencies (uv manages virtual environment automatically) # Install dependencies (uv manages virtual environment automatically)
uv sync make install-dev
# Setup pre-commit hooks
uv run pre-commit install
# Setup environment # Setup environment
cp .env.example .env cp .env.example .env
@@ -100,8 +103,14 @@ cp .env.example .env
# Run migrations # Run migrations
python migrate.py apply python migrate.py apply
# Run quality + security checks
make validate-all
# Run tests # Run tests
IS_TEST=True uv run pytest make test
# Run full pipeline (quality + security + tests)
make check
# Start dev server # Start dev server
uvicorn app.main:app --reload uvicorn app.main:app --reload
@@ -113,20 +122,20 @@ uvicorn app.main:app --reload
cd frontend cd frontend
# Install dependencies # Install dependencies
npm install bun install
# Setup environment # Setup environment
cp .env.local.example .env.local cp .env.local.example .env.local
# Generate API client # Generate API client
npm run generate:api bun run generate:api
# Run tests # Run tests
npm test bun run test
npm run test:e2e:ui bun run test:e2e:ui
# Start dev server # Start dev server
npm run dev bun run dev
``` ```
--- ---
@@ -195,7 +204,7 @@ export function UserProfile({ userId }: UserProfileProps) {
### Key Patterns ### Key Patterns
- **Backend**: Use CRUD pattern, keep routes thin, business logic in services - **Backend**: Use repository pattern, keep routes thin, business logic in services
- **Frontend**: Use React Query for server state, Zustand for client state - **Frontend**: Use React Query for server state, Zustand for client state
- **Both**: Handle errors gracefully, log appropriately, write tests - **Both**: Handle errors gracefully, log appropriately, write tests
@@ -316,7 +325,7 @@ Fixed stuff
### Before Submitting ### Before Submitting
- [ ] Code follows project style guidelines - [ ] Code follows project style guidelines
- [ ] All tests pass locally - [ ] `make check` passes (quality + security + tests) in backend
- [ ] New tests added for new features - [ ] New tests added for new features
- [ ] Documentation updated if needed - [ ] Documentation updated if needed
- [ ] No merge conflicts with `main` - [ ] No merge conflicts with `main`

117
Makefile
View File

@@ -1,31 +1,18 @@
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy .PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy scan-images
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all
VERSION ?= latest VERSION ?= latest
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
# Default target # Default target
help: help:
@echo "Syndarix - AI-Powered Software Consulting Agency" @echo "FastAPI + Next.js Full-Stack Template"
@echo "" @echo ""
@echo "Development:" @echo "Development:"
@echo " make dev - Start backend + db + MCP servers (frontend runs separately)" @echo " make dev - Start backend + db (frontend runs separately)"
@echo " make dev-full - Start all services including frontend" @echo " make dev-full - Start all services including frontend"
@echo " make down - Stop all services" @echo " make down - Stop all services"
@echo " make logs-dev - Follow dev container logs" @echo " make logs-dev - Follow dev container logs"
@echo "" @echo ""
@echo "Testing:"
@echo " make test - Run all tests (backend + MCP servers)"
@echo " make test-backend - Run backend tests only"
@echo " make test-mcp - Run MCP server tests only"
@echo " make test-frontend - Run frontend tests only"
@echo " make test-cov - Run all tests with coverage reports"
@echo " make test-integration - Run MCP integration tests (requires running stack)"
@echo ""
@echo "Validation:"
@echo " make validate - Validate backend + MCP servers (lint, type-check, test)"
@echo " make validate-all - Validate everything including frontend"
@echo ""
@echo "Database:" @echo "Database:"
@echo " make drop-db - Drop and recreate empty database" @echo " make drop-db - Drop and recreate empty database"
@echo " make reset-db - Drop database and apply all migrations" @echo " make reset-db - Drop database and apply all migrations"
@@ -34,6 +21,7 @@ help:
@echo " make prod - Start production stack" @echo " make prod - Start production stack"
@echo " make deploy - Pull and deploy latest images" @echo " make deploy - Pull and deploy latest images"
@echo " make push-images - Build and push images to registry" @echo " make push-images - Build and push images to registry"
@echo " make scan-images - Scan production images for CVEs (requires trivy)"
@echo " make logs - Follow production container logs" @echo " make logs - Follow production container logs"
@echo "" @echo ""
@echo "Cleanup:" @echo "Cleanup:"
@@ -41,10 +29,8 @@ help:
@echo " make clean-slate - Stop containers AND delete volumes (DATA LOSS!)" @echo " make clean-slate - Stop containers AND delete volumes (DATA LOSS!)"
@echo "" @echo ""
@echo "Subdirectory commands:" @echo "Subdirectory commands:"
@echo " cd backend && make help - Backend-specific commands" @echo " cd backend && make help - Backend-specific commands"
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands" @echo " cd frontend && npm run - Frontend-specific commands"
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
@echo " cd frontend && npm run - Frontend-specific commands"
# ============================================================================ # ============================================================================
# Development # Development
@@ -104,6 +90,28 @@ push-images:
docker push $(REGISTRY)/backend:$(VERSION) docker push $(REGISTRY)/backend:$(VERSION)
docker push $(REGISTRY)/frontend:$(VERSION) docker push $(REGISTRY)/frontend:$(VERSION)
scan-images:
@docker info > /dev/null 2>&1 || (echo "❌ Docker is not running!"; exit 1)
@echo "🐳 Building and scanning production images for CVEs..."
docker build -t $(REGISTRY)/backend:scan --target production ./backend
docker build -t $(REGISTRY)/frontend:scan --target runner ./frontend
@echo ""
@echo "=== Backend Image Scan ==="
@if command -v trivy > /dev/null 2>&1; then \
trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
else \
echo " Trivy not found locally, using Docker to run Trivy..."; \
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
fi
@echo ""
@echo "=== Frontend Image Scan ==="
@if command -v trivy > /dev/null 2>&1; then \
trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
else \
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
fi
@echo "✅ No HIGH/CRITICAL CVEs found in production images!"
# ============================================================================ # ============================================================================
# Cleanup # Cleanup
# ============================================================================ # ============================================================================
@@ -114,72 +122,3 @@ clean:
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL # WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
clean-slate: clean-slate:
docker compose -f docker-compose.dev.yml down -v --remove-orphans docker compose -f docker-compose.dev.yml down -v --remove-orphans
# ============================================================================
# Testing
# ============================================================================
test: test-backend test-mcp
@echo ""
@echo "All tests passed!"
test-backend:
@echo "Running backend tests..."
@cd backend && IS_TEST=True uv run pytest tests/ -v
test-mcp:
@echo "Running MCP server tests..."
@echo ""
@echo "=== LLM Gateway ==="
@cd mcp-servers/llm-gateway && uv run pytest tests/ -v
@echo ""
@echo "=== Knowledge Base ==="
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v
test-frontend:
@echo "Running frontend tests..."
@cd frontend && npm test
test-all: test test-frontend
@echo ""
@echo "All tests (backend + MCP + frontend) passed!"
test-cov:
@echo "Running all tests with coverage..."
@echo ""
@echo "=== Backend Coverage ==="
@cd backend && IS_TEST=True uv run pytest tests/ -v --cov=app --cov-report=term-missing
@echo ""
@echo "=== LLM Gateway Coverage ==="
@cd mcp-servers/llm-gateway && uv run pytest tests/ -v --cov=. --cov-report=term-missing
@echo ""
@echo "=== Knowledge Base Coverage ==="
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
test-integration:
@echo "Running MCP integration tests..."
@echo "Note: Requires running stack (make dev first)"
@cd backend && RUN_INTEGRATION_TESTS=true IS_TEST=True uv run pytest tests/integration/ -v
# ============================================================================
# Validation (lint + type-check + test)
# ============================================================================
validate:
@echo "Validating backend..."
@cd backend && make validate
@echo ""
@echo "Validating LLM Gateway..."
@cd mcp-servers/llm-gateway && make validate
@echo ""
@echo "Validating Knowledge Base..."
@cd mcp-servers/knowledge-base && make validate
@echo ""
@echo "All validations passed!"
validate-all: validate
@echo ""
@echo "Validating frontend..."
@cd frontend && npm run validate
@echo ""
@echo "Full validation passed!"

724
README.md
View File

@@ -1,175 +1,659 @@
# Syndarix # <img src="frontend/public/logo.svg" alt="PragmaStack" width="32" height="32" style="vertical-align: middle" /> PragmaStack
> **Your AI-Powered Software Consulting Agency** > **The Pragmatic Full-Stack Template. Production-ready, security-first, and opinionated.**
>
> An autonomous platform that orchestrates specialized AI agents to deliver complete software solutions with minimal human intervention.
[![Built on PragmaStack](https://img.shields.io/badge/Built_on-PragmaStack-blue)](https://gitea.pragmazest.com/cardosofelipe/fast-next-template) [![Backend Coverage](https://img.shields.io/badge/backend_coverage-97%25-brightgreen)](./backend/tests)
[![Frontend Coverage](https://img.shields.io/badge/frontend_coverage-97%25-brightgreen)](./frontend/tests)
[![E2E Tests](https://img.shields.io/badge/e2e_tests-passing-success)](./frontend/e2e)
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE) [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE)
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](./CONTRIBUTING.md)
![Landing Page](docs/images/landing.png)
--- ---
## Vision ## Why PragmaStack?
Syndarix transforms the software development lifecycle by providing a **virtual consulting team** of AI agents that collaboratively plan, design, implement, test, and deliver complete software solutions. Building a modern full-stack application often leads to "analysis paralysis" or "boilerplate fatigue". You spend weeks setting up authentication, testing, and linting before writing a single line of business logic.
**The Problem:** Even with AI coding assistants, developers spend as much time managing AI as doing the work themselves. Context switching, babysitting, and knowledge fragmentation limit productivity. **PragmaStack cuts through the noise.**
**The Solution:** A structured, autonomous agency where specialized AI agents handle different roles (Product Owner, Architect, Engineers, QA, etc.) with proper workflows, reviews, and quality gates. We provide a **pragmatic**, opinionated foundation that prioritizes:
- **Speed**: Ship features, not config files.
- **Robustness**: Security and testing are not optional.
- **Clarity**: Code that is easy to read and maintain.
Whether you're building a SaaS, an internal tool, or a side project, PragmaStack gives you a solid starting point without the bloat.
--- ---
## Key Features ## Features
### Multi-Agent Orchestration ### 🔐 **Authentication & Security**
- Configurable agent **types** with base model, failover, expertise, and personality - JWT-based authentication with access + refresh tokens
- Spawn multiple **instances** from the same type (e.g., Dave, Ellis, Kate as Software Developers) - **OAuth/Social Login** (Google, GitHub) with PKCE support
- Agent-to-agent communication and collaboration - **OAuth 2.0 Authorization Server** (MCP-ready) for third-party integrations
- Per-instance customization with domain-specific knowledge - Session management with device tracking and revocation
- Password reset flow (email integration ready)
- Secure password hashing (bcrypt)
- CSRF protection, rate limiting, and security headers
- Comprehensive security tests (JWT algorithm attacks, session hijacking, privilege escalation)
### Complete SDLC Support ### 🔌 **OAuth Provider Mode (MCP Integration)**
- **Requirements Discovery** → **Architecture Spike****Implementation Planning** Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-party clients:
- **Sprint Management** with automated ceremonies - **RFC 7636**: Authorization Code Flow with PKCE (S256 only)
- **Issue Tracking** with Epic/Story/Task hierarchy - **RFC 8414**: Server metadata discovery at `/.well-known/oauth-authorization-server`
- **Git Integration** with proper branch/PR workflows - **RFC 7662**: Token introspection endpoint
- **CI/CD Pipelines** with automated testing - **RFC 7009**: Token revocation endpoint
- **JWT access tokens**: Self-contained, configurable lifetime
- **Opaque refresh tokens**: Secure rotation, database-backed revocation
- **Consent management**: Users can review and revoke app permissions
- **Client management**: Admin endpoints for registering OAuth clients
- **Scopes**: `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
### Configurable Autonomy ### 👥 **Multi-Tenancy & Organizations**
- From `FULL_CONTROL` (approve everything) to `AUTONOMOUS` (only major milestones) - Full organization system with role-based access control (Owner, Admin, Member)
- Client can intervene at any point - Invite/remove members, manage permissions
- Transparent progress visibility - Organization-scoped data access
- User can belong to multiple organizations
### MCP-First Architecture ### 🛠️ **Admin Panel**
- All integrations via **Model Context Protocol (MCP)** servers - Complete user management (full lifecycle, activate/deactivate, bulk operations)
- Unified Knowledge Base with project/agent scoping - Organization management (create, edit, delete, member management)
- Git providers (Gitea, GitHub, GitLab) via MCP - Session monitoring across all users
- Extensible through custom MCP tools - Real-time statistics dashboard
- Admin-only routes with proper authorization
### Project Complexity Wizard ### 🎨 **Modern Frontend**
- **Script** → Minimal process, no repo needed - Next.js 16 with App Router and React 19
- **Simple** → Single sprint, basic backlog - **PragmaStack Design System** built on shadcn/ui + TailwindCSS
- **Medium/Complex** → Full AGILE workflow with multiple sprints - Pre-configured theme with dark mode support (coming soon)
- Responsive, accessible components (WCAG AA compliant)
- Rich marketing landing page with animated components
- Live component showcase and documentation at `/dev`
### 🌍 **Internationalization (i18n)**
- Built-in multi-language support with next-intl v4
- Locale-based routing (`/en/*`, `/it/*`)
- Seamless language switching with LocaleSwitcher component
- SEO-friendly URLs and metadata per locale
- Translation files for English and Italian (easily extensible)
- Type-safe translations throughout the app
### 🎯 **Content & UX Features**
- **Toast notifications** with Sonner for elegant user feedback
- **Smooth animations** powered by Framer Motion
- **Markdown rendering** with syntax highlighting (GitHub Flavored Markdown)
- **Charts and visualizations** ready with Recharts
- **SEO optimization** with dynamic sitemap and robots.txt generation
- **Session tracking UI** with device information and revocation controls
### 🧪 **Comprehensive Testing**
- **Backend Testing**: ~97% unit test coverage
- Unit, integration, and security tests
- Async database testing with SQLAlchemy
- API endpoint testing with fixtures
- Security vulnerability tests (JWT attacks, session hijacking, privilege escalation)
- **Frontend Unit Tests**: ~97% coverage with Jest
- Component testing
- Hook testing
- Utility function testing
- **End-to-End Tests**: Playwright with zero flaky tests
- Complete user flows (auth, navigation, settings)
- Parallel execution for speed
- Visual regression testing ready
### 📚 **Developer Experience**
- Auto-generated TypeScript API client from OpenAPI spec
- Interactive API documentation (Swagger + ReDoc)
- Database migrations with Alembic helper script
- Hot reload in development for both frontend and backend
- Comprehensive code documentation and design system docs
- Live component playground at `/dev` with code examples
- Docker support for easy deployment
- VSCode workspace settings included
### 📊 **Ready for Production**
- Docker + docker-compose setup
- Environment-based configuration
- Database connection pooling
- Error handling and logging
- Health check endpoints
- Production security headers
- Rate limiting on sensitive endpoints
- SEO optimization with dynamic sitemaps and robots.txt
- Multi-language SEO with locale-specific metadata
- Performance monitoring and bundle analysis
--- ---
## Technology Stack ## 📸 Screenshots
Built on [PragmaStack](https://gitea.pragmazest.com/cardosofelipe/fast-next-template): <details>
<summary>Click to view screenshots</summary>
| Component | Technology | ### Landing Page
|-----------|------------| ![Landing Page](docs/images/landing.png)
| Backend | FastAPI 0.115+ (Python 3.11+) |
| Frontend | Next.js 16 (React 19) |
| Database | PostgreSQL 15+ with pgvector |
| ORM | SQLAlchemy 2.0 |
| State Management | Zustand + TanStack Query |
| UI | shadcn/ui + Tailwind 4 |
| Auth | JWT dual-token + OAuth 2.0 |
| Testing | pytest + Jest + Playwright |
### Syndarix Extensions
| Component | Technology |
|-----------|------------| ### Authentication
| Task Queue | Celery + Redis | ![Login Page](docs/images/login.png)
| Real-time | FastAPI WebSocket / SSE |
| Vector DB | pgvector (PostgreSQL extension) |
| MCP SDK | Anthropic MCP SDK |
### Admin Dashboard
![Admin Dashboard](docs/images/admin-dashboard.png)
### Design System
![Components](docs/images/design-system.png)
</details>
--- ---
## Project Status ## 🎭 Demo Mode
**Phase:** Architecture & Planning **Try the frontend without a backend!** Perfect for:
- **Free deployment** on Vercel (no backend costs)
See [docs/requirements/](./docs/requirements/) for the comprehensive requirements document. - **Portfolio showcasing** with live demos
- **Client presentations** without infrastructure setup
### Current Milestones
- [x] Fork PragmaStack as foundation
- [x] Create requirements document
- [ ] Execute architecture spikes
- [ ] Create ADRs for key decisions
- [ ] Begin MVP implementation
---
## Documentation
- [Requirements Document](./docs/requirements/SYNDARIX_REQUIREMENTS.md)
- [Architecture Decisions](./docs/adrs/) (coming soon)
- [Spike Research](./docs/spikes/) (coming soon)
- [Architecture Overview](./docs/architecture/) (coming soon)
---
## Getting Started
### Prerequisites
- Docker & Docker Compose
- Node.js 20+
- Python 3.11+
- PostgreSQL 15+ (or use Docker)
### Quick Start ### Quick Start
```bash
cd frontend
echo "NEXT_PUBLIC_DEMO_MODE=true" > .env.local
bun run dev
```
**Demo Credentials:**
- Regular user: `demo@example.com` / `DemoPass123`
- Admin user: `admin@example.com` / `AdminPass123`
Demo mode uses [Mock Service Worker (MSW)](https://mswjs.io/) to intercept API calls in the browser. Your code remains unchanged - the same components work with both real and mocked backends.
**Key Features:**
- ✅ Zero backend required
- ✅ All features functional (auth, admin, stats)
- ✅ Realistic network delays and errors
- ✅ Does NOT interfere with tests (97%+ coverage maintained)
- ✅ One-line toggle: `NEXT_PUBLIC_DEMO_MODE=true`
📖 **[Complete Demo Mode Documentation](./frontend/docs/DEMO_MODE.md)**
---
## 🚀 Tech Stack
### Backend
- **[FastAPI](https://fastapi.tiangolo.com/)** - Modern async Python web framework
- **[SQLAlchemy 2.0](https://www.sqlalchemy.org/)** - Powerful ORM with async support
- **[PostgreSQL](https://www.postgresql.org/)** - Robust relational database
- **[Alembic](https://alembic.sqlalchemy.org/)** - Database migrations
- **[Pydantic v2](https://docs.pydantic.dev/)** - Data validation with type hints
- **[pytest](https://pytest.org/)** - Testing framework with async support
### Frontend
- **[Next.js 16](https://nextjs.org/)** - React framework with App Router
- **[React 19](https://react.dev/)** - UI library
- **[TypeScript](https://www.typescriptlang.org/)** - Type-safe JavaScript
- **[TailwindCSS](https://tailwindcss.com/)** - Utility-first CSS framework
- **[shadcn/ui](https://ui.shadcn.com/)** - Beautiful, accessible component library
- **[next-intl](https://next-intl.dev/)** - Internationalization (i18n) with type safety
- **[TanStack Query](https://tanstack.com/query)** - Powerful data fetching/caching
- **[Zustand](https://zustand-demo.pmnd.rs/)** - Lightweight state management
- **[Framer Motion](https://www.framer.com/motion/)** - Production-ready animation library
- **[Sonner](https://sonner.emilkowal.ski/)** - Beautiful toast notifications
- **[Recharts](https://recharts.org/)** - Composable charting library
- **[React Markdown](https://github.com/remarkjs/react-markdown)** - Markdown rendering with GFM support
- **[Playwright](https://playwright.dev/)** - End-to-end testing
### DevOps
- **[Docker](https://www.docker.com/)** - Containerization
- **[docker-compose](https://docs.docker.com/compose/)** - Multi-container orchestration
- **GitHub Actions** (coming soon) - CI/CD pipelines
---
## 📋 Prerequisites
- **Docker & Docker Compose** (recommended) - [Install Docker](https://docs.docker.com/get-docker/)
- **OR manually:**
- Python 3.12+
- Node.js 18+ (Node 20+ recommended)
- PostgreSQL 15+
---
## 🏃 Quick Start (Docker)
The fastest way to get started is with Docker:
```bash ```bash
# Clone the repository # Clone the repository
git clone https://gitea.pragmazest.com/cardosofelipe/syndarix.git git clone https://github.com/cardosofelipe/pragma-stack.git
cd syndarix cd fast-next-template
# Copy environment template # Copy environment file
cp .env.template .env cp .env.template .env
# Start development environment # Start all services (backend, frontend, database)
docker-compose -f docker-compose.dev.yml up -d docker-compose up
# Run database migrations # In another terminal, run database migrations
make migrate docker-compose exec backend alembic upgrade head
# Start the development servers # Create first superuser (optional)
make dev docker-compose exec backend python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
```
**That's it! 🎉**
- Frontend: http://localhost:3000
- Backend API: http://localhost:8000
- API Docs: http://localhost:8000/docs
Default superuser credentials:
- Email: `admin@example.com`
- Password: `admin123`
**⚠️ Change these immediately in production!**
---
## 🛠️ Manual Setup (Development)
### Backend Setup
```bash
cd backend
# Create virtual environment
python -m venv .venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
# Install dependencies
pip install -r requirements.txt
# Setup environment
cp .env.example .env
# Edit .env with your database credentials
# Run migrations
alembic upgrade head
# Initialize database with first superuser
python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
# Start development server
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```
### Frontend Setup
```bash
cd frontend
# Install dependencies
bun install
# Setup environment
cp .env.local.example .env.local
# Edit .env.local with your backend URL
# Generate API client
bun run generate:api
# Start development server
bun run dev
```
Visit http://localhost:3000 to see your app!
---
## 📂 Project Structure
```
├── backend/ # FastAPI backend
│ ├── app/
│ │ ├── api/ # API routes and dependencies
│ │ ├── core/ # Core functionality (auth, config, database)
│ │ ├── repositories/ # Repository pattern (database operations)
│ │ ├── models/ # SQLAlchemy models
│ │ ├── schemas/ # Pydantic schemas
│ │ ├── services/ # Business logic
│ │ └── utils/ # Utilities
│ ├── tests/ # Backend tests (97% coverage)
│ ├── alembic/ # Database migrations
│ └── docs/ # Backend documentation
├── frontend/ # Next.js frontend
│ ├── src/
│ │ ├── app/ # Next.js App Router pages
│ │ ├── components/ # React components
│ │ ├── lib/ # Libraries and utilities
│ │ │ ├── api/ # API client (auto-generated)
│ │ │ └── stores/ # Zustand stores
│ │ └── hooks/ # Custom React hooks
│ ├── e2e/ # Playwright E2E tests
│ ├── tests/ # Unit tests (Jest)
│ └── docs/ # Frontend documentation
│ └── design-system/ # Comprehensive design system docs
├── docker-compose.yml # Docker orchestration
├── docker-compose.dev.yml # Development with hot reload
└── README.md # You are here!
``` ```
--- ---
## Architecture Overview ## 🧪 Testing
This template takes testing seriously with comprehensive coverage across all layers:
### Backend Unit & Integration Tests
**High coverage (~97%)** across all critical paths including security-focused tests.
```bash
cd backend
# Run all tests
IS_TEST=True pytest
# Run with coverage report
IS_TEST=True pytest --cov=app --cov-report=term-missing
# Run specific test file
IS_TEST=True pytest tests/api/test_auth.py -v
# Generate HTML coverage report
IS_TEST=True pytest --cov=app --cov-report=html
open htmlcov/index.html
``` ```
+====================================================================+
| SYNDARIX CORE | **Test types:**
+====================================================================+ - **Unit tests**: Repository operations, utilities, business logic
| +------------------+ +------------------+ +------------------+ | - **Integration tests**: API endpoints with database
| | Agent Orchestrator| | Project Manager | | Workflow Engine | | - **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation
| +------------------+ +------------------+ +------------------+ | - **Error handling tests**: Database failures, validation errors
+====================================================================+
| ### Frontend Unit Tests
v
+====================================================================+ **High coverage (~97%)** with Jest and React Testing Library.
| MCP ORCHESTRATION LAYER |
| All integrations via unified MCP servers with project scoping | ```bash
+====================================================================+ cd frontend
|
+------------------------+------------------------+ # Run unit tests
| | | bun run test
+----v----+ +----v----+ +----v----+ +----v----+ +----v----+
| LLM | | Git | |Knowledge| | File | | Code | # Run with coverage
| Providers| | MCP | |Base MCP | |Sys. MCP | |Analysis | bun run test:coverage
+---------+ +---------+ +---------+ +---------+ +---------+
# Watch mode
bun run test:watch
```
**Test types:**
- Component rendering and interactions
- Custom hooks behavior
- State management
- Utility functions
- API integration mocks
### End-to-End Tests
**Zero flaky tests** with Playwright covering complete user journeys.
```bash
cd frontend
# Run E2E tests
bun run test:e2e
# Run E2E tests in UI mode (recommended for development)
bun run test:e2e:ui
# Run specific test file
npx playwright test auth-login.spec.ts
# Generate test report
npx playwright show-report
```
**Test coverage:**
- Complete authentication flows
- Navigation and routing
- Form submissions and validation
- Settings and profile management
- Session management
- Admin panel workflows (in progress)
---
## 🤖 AI-Friendly Documentation
This project includes comprehensive documentation designed for AI coding assistants:
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI assistant context for PragmaStack
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance
These files provide AI assistants with the **PragmaStack** architecture, patterns, and best practices.
---
## 🗄️ Database Migrations
The template uses Alembic for database migrations:
```bash
cd backend
# Generate migration from model changes
python migrate.py generate "description of changes"
# Apply migrations
python migrate.py apply
# Or do both in one command
python migrate.py auto "description"
# View migration history
python migrate.py list
# Check current revision
python migrate.py current
``` ```
--- ---
## Contributing ## 📖 Documentation
See [CONTRIBUTING.md](./CONTRIBUTING.md) for guidelines. ### AI Assistant Documentation
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI coding assistant context
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance and preferences
### Backend Documentation
- **[ARCHITECTURE.md](./backend/docs/ARCHITECTURE.md)** - System architecture and design patterns
- **[CODING_STANDARDS.md](./backend/docs/CODING_STANDARDS.md)** - Code quality standards
- **[COMMON_PITFALLS.md](./backend/docs/COMMON_PITFALLS.md)** - Common mistakes to avoid
- **[FEATURE_EXAMPLE.md](./backend/docs/FEATURE_EXAMPLE.md)** - Step-by-step feature guide
### Frontend Documentation
- **[PragmaStack Design System](./frontend/docs/design-system/)** - Complete design system guide
- Quick start, foundations (colors, typography, spacing)
- Component library guide
- Layout patterns, spacing philosophy
- Forms, accessibility, AI guidelines
- **[E2E Testing Guide](./frontend/e2e/README.md)** - E2E testing setup and best practices
### API Documentation
When the backend is running:
- **Swagger UI**: http://localhost:8000/docs
- **ReDoc**: http://localhost:8000/redoc
- **OpenAPI JSON**: http://localhost:8000/api/v1/openapi.json
--- ---
## License ## 🚢 Deployment
MIT License - see [LICENSE](./LICENSE) for details. ### Docker Production Deployment
```bash
# Build and start all services
docker-compose up -d
# Run migrations
docker-compose exec backend alembic upgrade head
# View logs
docker-compose logs -f
# Stop services
docker-compose down
```
### Production Checklist
- [ ] Change default superuser credentials
- [ ] Set strong `SECRET_KEY` in backend `.env`
- [ ] Configure production database (PostgreSQL)
- [ ] Set `ENVIRONMENT=production` in backend
- [ ] Configure CORS origins for your domain
- [ ] Setup SSL/TLS certificates
- [ ] Configure email service for password resets
- [ ] Setup monitoring and logging
- [ ] Configure backup strategy
- [ ] Review and adjust rate limits
- [ ] Test security headers
--- ---
## Acknowledgments ## 🛣️ Roadmap & Status
- Built on [PragmaStack](https://gitea.pragmazest.com/cardosofelipe/fast-next-template) ### ✅ Completed
- Powered by Claude and the Anthropic API - [x] Authentication system (JWT, refresh tokens, session management, OAuth)
- [x] User management (full lifecycle, profile, password change)
- [x] Organization system with RBAC (Owner, Admin, Member)
- [x] Admin panel (users, organizations, sessions, statistics)
- [x] **Internationalization (i18n)** with next-intl (English + Italian)
- [x] Backend testing infrastructure (~97% coverage)
- [x] Frontend unit testing infrastructure (~97% coverage)
- [x] Frontend E2E testing (Playwright, zero flaky tests)
- [x] Design system documentation
- [x] **Marketing landing page** with animated components
- [x] **`/dev` documentation portal** with live component examples
- [x] **Toast notifications** system (Sonner)
- [x] **Charts and visualizations** (Recharts)
- [x] **Animation system** (Framer Motion)
- [x] **Markdown rendering** with syntax highlighting
- [x] **SEO optimization** (sitemap, robots.txt, locale-aware metadata)
- [x] Database migrations with helper script
- [x] Docker deployment
- [x] API documentation (OpenAPI/Swagger)
### 🚧 In Progress
- [ ] Email integration (templates ready, SMTP pending)
### 🔮 Planned
- [ ] GitHub Actions CI/CD pipelines
- [ ] Dynamic test coverage badges from CI
- [ ] E2E test coverage reporting
- [ ] OAuth token encryption at rest (security hardening)
- [ ] Additional languages (Spanish, French, German, etc.)
- [ ] SSO/SAML authentication
- [ ] Real-time notifications with WebSockets
- [ ] Webhook system
- [ ] File upload/storage (S3-compatible)
- [ ] Audit logging system
- [ ] API versioning example
---
## 🤝 Contributing
Contributions are welcome! Whether you're fixing bugs, improving documentation, or proposing new features, we'd love your help.
### How to Contribute
1. **Fork the repository**
2. **Create a feature branch** (`git checkout -b feature/amazing-feature`)
3. **Make your changes**
- Follow existing code style
- Add tests for new features
- Update documentation as needed
4. **Run tests** to ensure everything works
5. **Commit your changes** (`git commit -m 'Add amazing feature'`)
6. **Push to your branch** (`git push origin feature/amazing-feature`)
7. **Open a Pull Request**
### Development Guidelines
- Write tests for new features (aim for >90% coverage)
- Follow the existing architecture patterns
- Update documentation when adding features
- Keep commits atomic and well-described
- Be respectful and constructive in discussions
### Reporting Issues
Found a bug? Have a suggestion? [Open an issue](https://github.com/cardosofelipe/pragma-stack/issues)!
Please include:
- Clear description of the issue/suggestion
- Steps to reproduce (for bugs)
- Expected vs. actual behavior
- Environment details (OS, Python/Node version, etc.)
---
## 📄 License
This project is licensed under the **MIT License** - see the [LICENSE](./LICENSE) file for details.
**TL;DR**: You can use this template for any purpose, commercial or non-commercial. Attribution is appreciated but not required!
---
## 🙏 Acknowledgments
This template is built on the shoulders of giants:
- [FastAPI](https://fastapi.tiangolo.com/) by Sebastián Ramírez
- [Next.js](https://nextjs.org/) by Vercel
- [shadcn/ui](https://ui.shadcn.com/) by shadcn
- [TanStack Query](https://tanstack.com/query) by Tanner Linsley
- [Playwright](https://playwright.dev/) by Microsoft
- And countless other open-source projects that make modern development possible
---
## 💬 Questions?
- **Documentation**: Check the `/docs` folders in backend and frontend
- **Issues**: [GitHub Issues](https://github.com/cardosofelipe/pragma-stack/issues)
- **Discussions**: [GitHub Discussions](https://github.com/cardosofelipe/pragma-stack/discussions)
---
## ⭐ Star This Repo
If this template saves you time, consider giving it a star! It helps others discover the project and motivates continued development.
**Happy coding! 🚀**
---
<div align="center">
Made with ❤️ by a developer who got tired of rebuilding the same boilerplate
</div>

View File

@@ -11,7 +11,7 @@ omit =
app/utils/auth_test_utils.py app/utils/auth_test_utils.py
# Async implementations not yet in use # Async implementations not yet in use
app/crud/base_async.py app/repositories/base_async.py
app/core/database_async.py app/core/database_async.py
# CLI scripts - run manually, not tested # CLI scripts - run manually, not tested
@@ -23,7 +23,7 @@ omit =
app/api/routes/__init__.py app/api/routes/__init__.py
app/api/dependencies/__init__.py app/api/dependencies/__init__.py
app/core/__init__.py app/core/__init__.py
app/crud/__init__.py app/repositories/__init__.py
app/models/__init__.py app/models/__init__.py
app/schemas/__init__.py app/schemas/__init__.py
app/services/__init__.py app/services/__init__.py

View File

@@ -0,0 +1,44 @@
# Pre-commit hooks for backend quality and security checks.
#
# Install:
# cd backend && uv run pre-commit install
#
# Run manually on all files:
# cd backend && uv run pre-commit run --all-files
#
# Skip hooks temporarily:
# git commit --no-verify
#
repos:
# ── Code Quality ──────────────────────────────────────────────────────────
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
# ── General File Hygiene ──────────────────────────────────────────────────
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
- id: check-merge-conflict
- id: check-added-large-files
args: [--maxkb=500]
- id: debug-statements
# ── Security ──────────────────────────────────────────────────────────────
- repo: https://github.com/Yelp/detect-secrets
rev: v1.5.0
hooks:
- id: detect-secrets
args: ['--baseline', '.secrets.baseline']
exclude: |
(?x)^(
.*\.lock$|
.*\.svg$
)$

1073
backend/.secrets.baseline Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -7,10 +7,7 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONPATH=/app \ PYTHONPATH=/app \
UV_COMPILE_BYTECODE=1 \ UV_COMPILE_BYTECODE=1 \
UV_LINK_MODE=copy \ UV_LINK_MODE=copy \
UV_NO_CACHE=1 \ UV_NO_CACHE=1
UV_PROJECT_ENVIRONMENT=/opt/venv \
VIRTUAL_ENV=/opt/venv \
PATH="/opt/venv/bin:$PATH"
# Install system dependencies and uv # Install system dependencies and uv
RUN apt-get update && \ RUN apt-get update && \
@@ -23,7 +20,7 @@ RUN apt-get update && \
# Copy dependency files # Copy dependency files
COPY pyproject.toml uv.lock ./ COPY pyproject.toml uv.lock ./
# Install dependencies using uv into /opt/venv (outside /app to survive bind mounts) # Install dependencies using uv (development mode with dev dependencies)
RUN uv sync --extra dev --frozen RUN uv sync --extra dev --frozen
# Copy application code # Copy application code
@@ -36,11 +33,11 @@ RUN chmod +x /usr/local/bin/entrypoint.sh
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
# Production stage # Production stage — Alpine eliminates glibc CVEs (e.g. CVE-2026-0861)
FROM python:3.12-slim AS production FROM python:3.12-alpine AS production
# Create non-root user # Create non-root user
RUN groupadd -r appuser && useradd -r -g appuser appuser RUN addgroup -S appuser && adduser -S -G appuser appuser
WORKDIR /app WORKDIR /app
ENV PYTHONDONTWRITEBYTECODE=1 \ ENV PYTHONDONTWRITEBYTECODE=1 \
@@ -48,24 +45,21 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONPATH=/app \ PYTHONPATH=/app \
UV_COMPILE_BYTECODE=1 \ UV_COMPILE_BYTECODE=1 \
UV_LINK_MODE=copy \ UV_LINK_MODE=copy \
UV_NO_CACHE=1 \ UV_NO_CACHE=1
UV_PROJECT_ENVIRONMENT=/opt/venv \
VIRTUAL_ENV=/opt/venv \
PATH="/opt/venv/bin:$PATH"
# Install system dependencies and uv # Install system dependencies and uv
RUN apt-get update && \ RUN apk add --no-cache postgresql-client curl ca-certificates && \
apt-get install -y --no-install-recommends postgresql-client curl ca-certificates && \
curl -LsSf https://astral.sh/uv/install.sh | sh && \ curl -LsSf https://astral.sh/uv/install.sh | sh && \
mv /root/.local/bin/uv* /usr/local/bin/ && \ mv /root/.local/bin/uv* /usr/local/bin/
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Copy dependency files # Copy dependency files
COPY pyproject.toml uv.lock ./ COPY pyproject.toml uv.lock ./
# Install only production dependencies using uv into /opt/venv # Install build dependencies, compile Python packages, then remove build deps
RUN uv sync --frozen --no-dev RUN apk add --no-cache --virtual .build-deps \
gcc g++ musl-dev python3-dev linux-headers libffi-dev openssl-dev && \
uv sync --frozen --no-dev && \
apk del .build-deps
# Copy application code # Copy application code
COPY . . COPY . .
@@ -73,7 +67,7 @@ COPY entrypoint.sh /usr/local/bin/
RUN chmod +x /usr/local/bin/entrypoint.sh RUN chmod +x /usr/local/bin/entrypoint.sh
# Set ownership to non-root user # Set ownership to non-root user
RUN chown -R appuser:appuser /app /opt/venv RUN chown -R appuser:appuser /app
# Switch to non-root user # Switch to non-root user
USER appuser USER appuser
@@ -83,4 +77,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1 CMD curl -f http://localhost:8000/health || exit 1
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] CMD ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -1,4 +1,7 @@
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all test-integration .PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all dep-audit license-check audit validate-all check benchmark benchmark-check benchmark-save scan-image test-api-security
# Prevent a stale VIRTUAL_ENV in the caller's shell from confusing uv
unexport VIRTUAL_ENV
# Default target # Default target
help: help:
@@ -14,17 +17,30 @@ help:
@echo " make lint-fix - Run Ruff linter with auto-fix" @echo " make lint-fix - Run Ruff linter with auto-fix"
@echo " make format - Format code with Ruff" @echo " make format - Format code with Ruff"
@echo " make format-check - Check if code is formatted" @echo " make format-check - Check if code is formatted"
@echo " make type-check - Run mypy type checking" @echo " make type-check - Run pyright type checking"
@echo " make validate - Run all checks (lint + format + types)" @echo " make validate - Run all checks (lint + format + types + schema fuzz)"
@echo ""
@echo "Performance:"
@echo " make benchmark - Run performance benchmarks"
@echo " make benchmark-save - Run benchmarks and save as baseline"
@echo " make benchmark-check - Run benchmarks and compare against baseline"
@echo ""
@echo "Security & Audit:"
@echo " make dep-audit - Scan dependencies for known vulnerabilities"
@echo " make license-check - Check dependency license compliance"
@echo " make audit - Run all security audits (deps + licenses)"
@echo " make scan-image - Scan Docker image for CVEs (requires trivy)"
@echo " make validate-all - Run all quality + security checks"
@echo " make check - Full pipeline: quality + security + tests"
@echo "" @echo ""
@echo "Testing:" @echo "Testing:"
@echo " make test - Run pytest (unit/integration, SQLite)" @echo " make test - Run pytest (unit/integration, SQLite)"
@echo " make test-cov - Run pytest with coverage report" @echo " make test-cov - Run pytest with coverage report"
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)" @echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
@echo " make test-e2e-schema - Run Schemathesis API schema tests" @echo " make test-e2e-schema - Run Schemathesis API schema tests"
@echo " make test-integration - Run MCP integration tests (requires running stack)"
@echo " make test-all - Run all tests (unit + E2E)" @echo " make test-all - Run all tests (unit + E2E)"
@echo " make check-docker - Check if Docker is available" @echo " make check-docker - Check if Docker is available"
@echo " make check - Full pipeline: quality + security + tests"
@echo "" @echo ""
@echo "Cleanup:" @echo "Cleanup:"
@echo " make clean - Remove cache and build artifacts" @echo " make clean - Remove cache and build artifacts"
@@ -64,12 +80,52 @@ format-check:
@uv run ruff format --check app/ tests/ @uv run ruff format --check app/ tests/
type-check: type-check:
@echo "🔎 Running mypy type checking..." @echo "🔎 Running pyright type checking..."
@uv run mypy app/ @uv run pyright app/
validate: lint format-check type-check validate: lint format-check type-check test-api-security
@echo "✅ All quality checks passed!" @echo "✅ All quality checks passed!"
# API Security Testing (Schemathesis property-based fuzzing)
test-api-security: check-docker
@echo "🔐 Running Schemathesis API security fuzzing..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
@echo "✅ API schema security tests passed!"
# ============================================================================
# Security & Audit
# ============================================================================
dep-audit:
@echo "🔒 Scanning dependencies for known vulnerabilities..."
@uv run pip-audit --desc --skip-editable
@echo "✅ No known vulnerabilities found!"
license-check:
@echo "📜 Checking dependency license compliance..."
@uv run pip-licenses --fail-on="GPL-3.0-or-later;AGPL-3.0-or-later" --format=plain > /dev/null
@echo "✅ All dependency licenses are compliant!"
audit: dep-audit license-check
@echo "✅ All security audits passed!"
scan-image: check-docker
@echo "🐳 Scanning Docker image for OS-level CVEs with Trivy..."
@docker build -t pragma-backend:scan -q --target production .
@if command -v trivy > /dev/null 2>&1; then \
trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
else \
echo " Trivy not found locally, using Docker to run Trivy..."; \
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
fi
@echo "✅ No HIGH/CRITICAL CVEs found in Docker image!"
validate-all: validate audit
@echo "✅ All quality + security checks passed!"
check: validate-all test
@echo "✅ Full validation pipeline complete!"
# ============================================================================ # ============================================================================
# Testing # Testing
# ============================================================================ # ============================================================================
@@ -83,15 +139,6 @@ test-cov:
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16 @IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
@echo "📊 Coverage report generated in htmlcov/index.html" @echo "📊 Coverage report generated in htmlcov/index.html"
# ============================================================================
# Integration Testing (requires running stack: make dev)
# ============================================================================
test-integration:
@echo "🧪 Running MCP integration tests..."
@echo "Note: Requires running stack (make dev from project root)"
@RUN_INTEGRATION_TESTS=true IS_TEST=True PYTHONPATH=. uv run pytest tests/integration/ -v
# ============================================================================ # ============================================================================
# E2E Testing (requires Docker) # E2E Testing (requires Docker)
# ============================================================================ # ============================================================================
@@ -124,6 +171,31 @@ test-e2e-schema: check-docker
@echo "🧪 Running Schemathesis API schema tests..." @echo "🧪 Running Schemathesis API schema tests..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0 @IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
# ============================================================================
# Performance Benchmarks
# ============================================================================
benchmark:
@echo "⏱️ Running performance benchmarks..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-sort=mean -p no:xdist --override-ini='addopts='
benchmark-save:
@echo "⏱️ Running benchmarks and saving baseline..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='
@echo "✅ Benchmark baseline saved to .benchmarks/"
benchmark-check:
@echo "⏱️ Running benchmarks and comparing against baseline..."
@if find .benchmarks -name '*_baseline*' -print -quit 2>/dev/null | grep -q .; then \
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-compare=0001_baseline --benchmark-sort=mean --benchmark-compare-fail=mean:200% -p no:xdist --override-ini='addopts='; \
echo "✅ No performance regressions detected!"; \
else \
echo "⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one."; \
echo " Running benchmarks without comparison..."; \
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='; \
echo "✅ Benchmark baseline created. Future runs of 'make benchmark-check' will compare against it."; \
fi
test-all: test-all:
@echo "🧪 Running ALL tests (unit + E2E)..." @echo "🧪 Running ALL tests (unit + E2E)..."
@$(MAKE) test @$(MAKE) test
@@ -137,7 +209,7 @@ clean:
@echo "🧹 Cleaning up..." @echo "🧹 Cleaning up..."
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name ".pyright" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true

View File

@@ -1,6 +1,6 @@
# Syndarix Backend API # PragmaStack Backend API
> The pragmatic, production-ready FastAPI backend for Syndarix. > The pragmatic, production-ready FastAPI backend for PragmaStack.
## Overview ## Overview
@@ -14,7 +14,9 @@ Features:
- **Multi-tenancy**: Organization-based access control with roles (Owner/Admin/Member) - **Multi-tenancy**: Organization-based access control with roles (Owner/Admin/Member)
- **Testing**: 97%+ coverage with security-focused test suite - **Testing**: 97%+ coverage with security-focused test suite
- **Performance**: Async throughout, connection pooling, optimized queries - **Performance**: Async throughout, connection pooling, optimized queries
- **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, mypy for type checking - **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, Pyright for type checking
- **Security Auditing**: Automated dependency vulnerability scanning, license compliance, secrets detection
- **Pre-commit Hooks**: Ruff, detect-secrets, and standard checks on every commit
## Quick Start ## Quick Start
@@ -149,7 +151,7 @@ uv pip list --outdated
# Run any Python command via uv (no activation needed) # Run any Python command via uv (no activation needed)
uv run python script.py uv run python script.py
uv run pytest uv run pytest
uv run mypy app/ uv run pyright app/
# Or activate the virtual environment # Or activate the virtual environment
source .venv/bin/activate source .venv/bin/activate
@@ -171,12 +173,22 @@ make lint # Run Ruff linter (check only)
make lint-fix # Run Ruff with auto-fix make lint-fix # Run Ruff with auto-fix
make format # Format code with Ruff make format # Format code with Ruff
make format-check # Check if code is formatted make format-check # Check if code is formatted
make type-check # Run mypy type checking make type-check # Run Pyright type checking
make validate # Run all checks (lint + format + types) make validate # Run all checks (lint + format + types)
# Security & Audit
make dep-audit # Scan dependencies for known vulnerabilities (CVEs)
make license-check # Check dependency license compliance
make audit # Run all security audits (deps + licenses)
make validate-all # Run all quality + security checks
make check # Full pipeline: quality + security + tests
# Testing # Testing
make test # Run all tests make test # Run all tests
make test-cov # Run tests with coverage report make test-cov # Run tests with coverage report
make test-e2e # Run E2E tests (PostgreSQL, requires Docker)
make test-e2e-schema # Run Schemathesis API schema tests
make test-all # Run all tests (unit + E2E)
# Utilities # Utilities
make clean # Remove cache and build artifacts make clean # Remove cache and build artifacts
@@ -252,7 +264,7 @@ app/
│ ├── database.py # Database engine setup │ ├── database.py # Database engine setup
│ ├── auth.py # JWT token handling │ ├── auth.py # JWT token handling
│ └── exceptions.py # Custom exceptions │ └── exceptions.py # Custom exceptions
├── crud/ # Database operations ├── repositories/ # Repository pattern (database operations)
├── models/ # SQLAlchemy ORM models ├── models/ # SQLAlchemy ORM models
├── schemas/ # Pydantic request/response schemas ├── schemas/ # Pydantic request/response schemas
├── services/ # Business logic layer ├── services/ # Business logic layer
@@ -352,18 +364,29 @@ open htmlcov/index.html
# Using Makefile (recommended) # Using Makefile (recommended)
make lint # Ruff linting make lint # Ruff linting
make format # Ruff formatting make format # Ruff formatting
make type-check # mypy type checking make type-check # Pyright type checking
make validate # All checks at once make validate # All checks at once
# Security audits
make dep-audit # Scan dependencies for CVEs
make license-check # Check license compliance
make audit # All security audits
make validate-all # Quality + security checks
make check # Full pipeline: quality + security + tests
# Using uv directly # Using uv directly
uv run ruff check app/ tests/ uv run ruff check app/ tests/
uv run ruff format app/ tests/ uv run ruff format app/ tests/
uv run mypy app/ uv run pyright app/
``` ```
**Tools:** **Tools:**
- **Ruff**: All-in-one linting, formatting, and import sorting (replaces Black, Flake8, isort) - **Ruff**: All-in-one linting, formatting, and import sorting (replaces Black, Flake8, isort)
- **mypy**: Static type checking with Pydantic plugin - **Pyright**: Static type checking (strict mode)
- **pip-audit**: Dependency vulnerability scanning against the OSV database
- **pip-licenses**: Dependency license compliance checking
- **detect-secrets**: Hardcoded secrets/credentials detection
- **pre-commit**: Git hook framework for automated checks on every commit
All configurations are in `pyproject.toml`. All configurations are in `pyproject.toml`.
@@ -439,7 +462,7 @@ See [docs/FEATURE_EXAMPLE.md](docs/FEATURE_EXAMPLE.md) for step-by-step guide.
Quick overview: Quick overview:
1. Create Pydantic schemas in `app/schemas/` 1. Create Pydantic schemas in `app/schemas/`
2. Create CRUD operations in `app/crud/` 2. Create repository in `app/repositories/`
3. Create route in `app/api/routes/` 3. Create route in `app/api/routes/`
4. Register router in `app/api/main.py` 4. Register router in `app/api/main.py`
5. Write tests in `tests/api/` 5. Write tests in `tests/api/`
@@ -589,13 +612,42 @@ Configured in `app/core/config.py`:
- **Security Headers**: CSP, HSTS, X-Frame-Options, etc. - **Security Headers**: CSP, HSTS, X-Frame-Options, etc.
- **Input Validation**: Pydantic schemas, SQL injection prevention (ORM) - **Input Validation**: Pydantic schemas, SQL injection prevention (ORM)
### Security Auditing
Automated, deterministic security checks are built into the development workflow:
```bash
# Scan dependencies for known vulnerabilities (CVEs)
make dep-audit
# Check dependency license compliance (blocks GPL-3.0/AGPL)
make license-check
# Run all security audits
make audit
# Full pipeline: quality + security + tests
make check
```
**Pre-commit hooks** automatically run on every commit:
- **Ruff** lint + format checks
- **detect-secrets** blocks commits containing hardcoded secrets
- **Standard checks**: trailing whitespace, YAML/TOML validation, merge conflict detection, large file prevention
Setup pre-commit hooks:
```bash
uv run pre-commit install
```
### Security Best Practices ### Security Best Practices
1. **Never commit secrets**: Use `.env` files (git-ignored) 1. **Never commit secrets**: Use `.env` files (git-ignored), enforced by detect-secrets pre-commit hook
2. **Strong SECRET_KEY**: Min 32 chars, cryptographically random 2. **Strong SECRET_KEY**: Min 32 chars, cryptographically random
3. **HTTPS in production**: Required for token security 3. **HTTPS in production**: Required for token security
4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`) 4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`), run `make dep-audit` to check for CVEs
5. **Audit logs**: Monitor authentication events 5. **Audit logs**: Monitor authentication events
6. **Run `make check` before pushing**: Validates quality, security, and tests in one command
--- ---
@@ -645,7 +697,11 @@ logging.basicConfig(level=logging.INFO)
**Built with modern Python tooling:** **Built with modern Python tooling:**
- 🚀 **uv** - 10-100x faster dependency management - 🚀 **uv** - 10-100x faster dependency management
-**Ruff** - 10-100x faster linting & formatting -**Ruff** - 10-100x faster linting & formatting
- 🔍 **mypy** - Static type checking - 🔍 **Pyright** - Static type checking (strict mode)
-**pytest** - Comprehensive test suite -**pytest** - Comprehensive test suite
- 🔒 **pip-audit** - Dependency vulnerability scanning
- 🔑 **detect-secrets** - Hardcoded secrets detection
- 📜 **pip-licenses** - License compliance checking
- 🪝 **pre-commit** - Automated git hooks
**All configured in a single `pyproject.toml` file!** **All configured in a single `pyproject.toml` file!**

View File

@@ -1,66 +0,0 @@
"""Enable pgvector extension
Revision ID: 0003
Revises: 0002
Create Date: 2025-12-30
This migration enables the pgvector extension for PostgreSQL, which provides
vector similarity search capabilities required for the RAG (Retrieval-Augmented
Generation) knowledge base system.
Vector Dimension Reference (per ADR-008 and SPIKE-006):
---------------------------------------------------------
The dimension size depends on the embedding model used:
| Model | Dimensions | Use Case |
|----------------------------|------------|------------------------------|
| text-embedding-3-small | 1536 | General docs, conversations |
| text-embedding-3-large | 256-3072 | High accuracy (configurable) |
| voyage-code-3 | 1024 | Code files (Python, JS, etc) |
| voyage-3-large | 1024 | High quality general purpose |
| nomic-embed-text (Ollama) | 768 | Local/fallback embedding |
Recommended defaults for Syndarix:
- Documentation/conversations: 1536 (text-embedding-3-small)
- Code files: 1024 (voyage-code-3)
Prerequisites:
--------------
This migration requires PostgreSQL with the pgvector extension installed.
The Docker Compose configuration uses `pgvector/pgvector:pg17` which includes
the extension pre-installed.
References:
-----------
- ADR-008: Knowledge Base and RAG Architecture
- SPIKE-006: Knowledge Base with pgvector for RAG System
- https://github.com/pgvector/pgvector
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0003"
down_revision: str | None = "0002"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Enable the pgvector extension.
The CREATE EXTENSION IF NOT EXISTS statement is idempotent - it will
succeed whether the extension already exists or not.
"""
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
def downgrade() -> None:
"""Drop the pgvector extension.
Note: This will fail if any tables with vector columns exist.
Future migrations that create vector columns should be downgraded first.
"""
op.execute("DROP EXTENSION IF EXISTS vector")

View File

@@ -0,0 +1,35 @@
"""rename oauth account token fields drop encrypted suffix
Revision ID: 0003
Revises: 0002
Create Date: 2026-02-27 01:03:18.869178
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0003"
down_revision: str | None = "0002"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.alter_column(
"oauth_accounts", "access_token_encrypted", new_column_name="access_token"
)
op.alter_column(
"oauth_accounts", "refresh_token_encrypted", new_column_name="refresh_token"
)
def downgrade() -> None:
op.alter_column(
"oauth_accounts", "access_token", new_column_name="access_token_encrypted"
)
op.alter_column(
"oauth_accounts", "refresh_token", new_column_name="refresh_token_encrypted"
)

View File

@@ -1,507 +0,0 @@
"""Add Syndarix models
Revision ID: 0004
Revises: 0003
Create Date: 2025-12-31
This migration creates the core Syndarix domain tables:
- projects: Client engagement projects
- agent_types: Agent template configurations
- agent_instances: Spawned agent instances assigned to projects
- sprints: Sprint containers for issues
- issues: Work items (epics, stories, tasks, bugs)
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "0004"
down_revision: str | None = "0003"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Create Syndarix domain tables."""
# =========================================================================
# Create projects table
# Note: ENUM types are created automatically by sa.Enum() during table creation
# =========================================================================
op.create_table(
"projects",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("slug", sa.String(255), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column(
"autonomy_level",
sa.Enum(
"full_control",
"milestone",
"autonomous",
name="autonomy_level",
),
nullable=False,
server_default="milestone",
),
sa.Column(
"status",
sa.Enum(
"active",
"paused",
"completed",
"archived",
name="project_status",
),
nullable=False,
server_default="active",
),
sa.Column(
"complexity",
sa.Enum(
"script",
"simple",
"medium",
"complex",
name="project_complexity",
),
nullable=False,
server_default="medium",
),
sa.Column(
"client_mode",
sa.Enum("technical", "auto", name="client_mode"),
nullable=False,
server_default="auto",
),
sa.Column(
"settings",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="{}",
),
sa.Column("owner_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["owner_id"], ["users.id"], ondelete="SET NULL"),
sa.UniqueConstraint("slug"),
)
# Single column indexes
op.create_index("ix_projects_name", "projects", ["name"])
op.create_index("ix_projects_slug", "projects", ["slug"])
op.create_index("ix_projects_status", "projects", ["status"])
op.create_index("ix_projects_autonomy_level", "projects", ["autonomy_level"])
op.create_index("ix_projects_complexity", "projects", ["complexity"])
op.create_index("ix_projects_client_mode", "projects", ["client_mode"])
op.create_index("ix_projects_owner_id", "projects", ["owner_id"])
# Composite indexes
op.create_index("ix_projects_slug_status", "projects", ["slug", "status"])
op.create_index("ix_projects_owner_status", "projects", ["owner_id", "status"])
op.create_index(
"ix_projects_autonomy_status", "projects", ["autonomy_level", "status"]
)
op.create_index(
"ix_projects_complexity_status", "projects", ["complexity", "status"]
)
# =========================================================================
# Create agent_types table
# =========================================================================
op.create_table(
"agent_types",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("slug", sa.String(255), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
# Areas of expertise (e.g., ["python", "fastapi", "databases"])
sa.Column(
"expertise",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="[]",
),
# System prompt defining personality and behavior (required)
sa.Column("personality_prompt", sa.Text(), nullable=False),
# LLM model configuration
sa.Column("primary_model", sa.String(100), nullable=False),
sa.Column(
"fallback_models",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="[]",
),
# Model parameters (temperature, max_tokens, etc.)
sa.Column(
"model_params",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="{}",
),
# MCP servers this agent can connect to
sa.Column(
"mcp_servers",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="[]",
),
# Tool permissions configuration
sa.Column(
"tool_permissions",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="{}",
),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("slug"),
)
# Single column indexes
op.create_index("ix_agent_types_name", "agent_types", ["name"])
op.create_index("ix_agent_types_slug", "agent_types", ["slug"])
op.create_index("ix_agent_types_is_active", "agent_types", ["is_active"])
# Composite indexes
op.create_index("ix_agent_types_slug_active", "agent_types", ["slug", "is_active"])
op.create_index("ix_agent_types_name_active", "agent_types", ["name", "is_active"])
# =========================================================================
# Create agent_instances table
# =========================================================================
op.create_table(
"agent_instances",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("name", sa.String(100), nullable=False),
sa.Column(
"status",
sa.Enum(
"idle",
"working",
"waiting",
"paused",
"terminated",
name="agent_status",
),
nullable=False,
server_default="idle",
),
sa.Column("current_task", sa.Text(), nullable=True),
# Short-term memory (conversation context, recent decisions)
sa.Column(
"short_term_memory",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="{}",
),
# Reference to long-term memory in vector store
sa.Column("long_term_memory_ref", sa.String(500), nullable=True),
# Session ID for active MCP connections
sa.Column("session_id", sa.String(255), nullable=True),
# Activity tracking
sa.Column("last_activity_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("terminated_at", sa.DateTime(timezone=True), nullable=True),
# Usage metrics
sa.Column("tasks_completed", sa.Integer(), nullable=False, server_default="0"),
sa.Column("tokens_used", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column(
"cost_incurred",
sa.Numeric(precision=10, scale=4),
nullable=False,
server_default="0",
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["agent_type_id"], ["agent_types.id"], ondelete="RESTRICT"
),
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
)
# Single column indexes
op.create_index("ix_agent_instances_name", "agent_instances", ["name"])
op.create_index("ix_agent_instances_status", "agent_instances", ["status"])
op.create_index(
"ix_agent_instances_agent_type_id", "agent_instances", ["agent_type_id"]
)
op.create_index("ix_agent_instances_project_id", "agent_instances", ["project_id"])
op.create_index("ix_agent_instances_session_id", "agent_instances", ["session_id"])
op.create_index(
"ix_agent_instances_last_activity_at", "agent_instances", ["last_activity_at"]
)
op.create_index(
"ix_agent_instances_terminated_at", "agent_instances", ["terminated_at"]
)
# Composite indexes
op.create_index(
"ix_agent_instances_project_status",
"agent_instances",
["project_id", "status"],
)
op.create_index(
"ix_agent_instances_type_status",
"agent_instances",
["agent_type_id", "status"],
)
op.create_index(
"ix_agent_instances_project_type",
"agent_instances",
["project_id", "agent_type_id"],
)
# =========================================================================
# Create sprints table (before issues for FK reference)
# =========================================================================
op.create_table(
"sprints",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("number", sa.Integer(), nullable=False),
sa.Column("goal", sa.Text(), nullable=True),
sa.Column("start_date", sa.Date(), nullable=False),
sa.Column("end_date", sa.Date(), nullable=False),
sa.Column(
"status",
sa.Enum(
"planned",
"active",
"in_review",
"completed",
"cancelled",
name="sprint_status",
),
nullable=False,
server_default="planned",
),
sa.Column("planned_points", sa.Integer(), nullable=True),
sa.Column("velocity", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
sa.UniqueConstraint("project_id", "number", name="uq_sprint_project_number"),
)
# Single column indexes
op.create_index("ix_sprints_project_id", "sprints", ["project_id"])
op.create_index("ix_sprints_status", "sprints", ["status"])
op.create_index("ix_sprints_start_date", "sprints", ["start_date"])
op.create_index("ix_sprints_end_date", "sprints", ["end_date"])
# Composite indexes
op.create_index("ix_sprints_project_status", "sprints", ["project_id", "status"])
op.create_index("ix_sprints_project_number", "sprints", ["project_id", "number"])
op.create_index("ix_sprints_date_range", "sprints", ["start_date", "end_date"])
# =========================================================================
# Create issues table
# =========================================================================
op.create_table(
"issues",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
# Parent issue for hierarchy (Epic -> Story -> Task)
sa.Column("parent_id", postgresql.UUID(as_uuid=True), nullable=True),
# Issue type (epic, story, task, bug)
sa.Column(
"type",
sa.Enum(
"epic",
"story",
"task",
"bug",
name="issue_type",
),
nullable=False,
server_default="task",
),
# Reporter (who created this issue)
sa.Column("reporter_id", postgresql.UUID(as_uuid=True), nullable=True),
# Issue content
sa.Column("title", sa.String(500), nullable=False),
sa.Column("body", sa.Text(), nullable=False, server_default=""),
# Status and priority
sa.Column(
"status",
sa.Enum(
"open",
"in_progress",
"in_review",
"blocked",
"closed",
name="issue_status",
),
nullable=False,
server_default="open",
),
sa.Column(
"priority",
sa.Enum(
"low",
"medium",
"high",
"critical",
name="issue_priority",
),
nullable=False,
server_default="medium",
),
# Labels for categorization
sa.Column(
"labels",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="[]",
),
# Assignment - agent or human (mutually exclusive)
sa.Column("assigned_agent_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("human_assignee", sa.String(255), nullable=True),
# Sprint association
sa.Column("sprint_id", postgresql.UUID(as_uuid=True), nullable=True),
# Estimation
sa.Column("story_points", sa.Integer(), nullable=True),
sa.Column("due_date", sa.Date(), nullable=True),
# External tracker integration (String for flexibility)
sa.Column("external_tracker_type", sa.String(50), nullable=True),
sa.Column("external_issue_id", sa.String(255), nullable=True),
sa.Column("remote_url", sa.String(1000), nullable=True),
sa.Column("external_issue_number", sa.Integer(), nullable=True),
# Sync status
sa.Column(
"sync_status",
sa.Enum(
"synced",
"pending",
"conflict",
"error",
name="sync_status",
),
nullable=False,
server_default="synced",
),
sa.Column("last_synced_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("external_updated_at", sa.DateTime(timezone=True), nullable=True),
# Lifecycle
sa.Column("closed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["parent_id"], ["issues.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["sprint_id"], ["sprints.id"], ondelete="SET NULL"),
sa.ForeignKeyConstraint(
["assigned_agent_id"], ["agent_instances.id"], ondelete="SET NULL"
),
)
# Single column indexes
op.create_index("ix_issues_project_id", "issues", ["project_id"])
op.create_index("ix_issues_parent_id", "issues", ["parent_id"])
op.create_index("ix_issues_type", "issues", ["type"])
op.create_index("ix_issues_reporter_id", "issues", ["reporter_id"])
op.create_index("ix_issues_status", "issues", ["status"])
op.create_index("ix_issues_priority", "issues", ["priority"])
op.create_index("ix_issues_assigned_agent_id", "issues", ["assigned_agent_id"])
op.create_index("ix_issues_human_assignee", "issues", ["human_assignee"])
op.create_index("ix_issues_sprint_id", "issues", ["sprint_id"])
op.create_index("ix_issues_due_date", "issues", ["due_date"])
op.create_index(
"ix_issues_external_tracker_type", "issues", ["external_tracker_type"]
)
op.create_index("ix_issues_sync_status", "issues", ["sync_status"])
op.create_index("ix_issues_closed_at", "issues", ["closed_at"])
# Composite indexes
op.create_index("ix_issues_project_status", "issues", ["project_id", "status"])
op.create_index("ix_issues_project_priority", "issues", ["project_id", "priority"])
op.create_index("ix_issues_project_sprint", "issues", ["project_id", "sprint_id"])
op.create_index("ix_issues_project_type", "issues", ["project_id", "type"])
op.create_index(
"ix_issues_project_agent", "issues", ["project_id", "assigned_agent_id"]
)
op.create_index(
"ix_issues_project_status_priority",
"issues",
["project_id", "status", "priority"],
)
op.create_index(
"ix_issues_external_tracker_id",
"issues",
["external_tracker_type", "external_issue_id"],
)
def downgrade() -> None:
"""Drop Syndarix domain tables."""
# Drop tables in reverse order (respecting FK constraints)
op.drop_table("issues")
op.drop_table("sprints")
op.drop_table("agent_instances")
op.drop_table("agent_types")
op.drop_table("projects")
# Drop ENUM types
op.execute("DROP TYPE IF EXISTS sprint_status")
op.execute("DROP TYPE IF EXISTS sync_status")
op.execute("DROP TYPE IF EXISTS issue_priority")
op.execute("DROP TYPE IF EXISTS issue_status")
op.execute("DROP TYPE IF EXISTS issue_type")
op.execute("DROP TYPE IF EXISTS agent_status")
op.execute("DROP TYPE IF EXISTS client_mode")
op.execute("DROP TYPE IF EXISTS project_complexity")
op.execute("DROP TYPE IF EXISTS project_status")
op.execute("DROP TYPE IF EXISTS autonomy_level")

View File

@@ -1,12 +1,12 @@
from fastapi import Depends, Header, HTTPException, status from fastapi import Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
from app.core.database import get_db from app.core.database import get_db
from app.models.user import User from app.models.user import User
from app.repositories.user import user_repo
# OAuth2 configuration # OAuth2 configuration
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
@@ -32,9 +32,8 @@ async def get_current_user(
# Decode token and get user ID # Decode token and get user ID
token_data = get_token_data(token) token_data = get_token_data(token)
# Get user from database # Get user from database via repository
result = await db.execute(select(User).where(User.id == token_data.user_id)) user = await user_repo.get(db, id=str(token_data.user_id))
user = result.scalar_one_or_none()
if not user: if not user:
raise HTTPException( raise HTTPException(
@@ -144,90 +143,9 @@ async def get_optional_current_user(
try: try:
token_data = get_token_data(token) token_data = get_token_data(token)
result = await db.execute(select(User).where(User.id == token_data.user_id)) user = await user_repo.get(db, id=str(token_data.user_id))
user = result.scalar_one_or_none()
if not user or not user.is_active: if not user or not user.is_active:
return None return None
return user return user
except (TokenExpiredError, TokenInvalidError): except (TokenExpiredError, TokenInvalidError):
return None return None
async def get_current_user_sse(
db: AsyncSession = Depends(get_db),
authorization: str | None = Header(None),
token: str | None = None, # Query parameter - passed directly from route
) -> User:
"""
Get the current authenticated user for SSE endpoints.
SSE (Server-Sent Events) via EventSource API doesn't support custom headers,
so this dependency accepts tokens from either:
1. Authorization header (preferred, for non-EventSource clients)
2. Query parameter 'token' (fallback for EventSource compatibility)
Security note: Query parameter tokens appear in server logs and browser history.
Consider implementing short-lived SSE-specific tokens for production if this
is a concern. The current approach is acceptable for internal/trusted networks.
Args:
db: Database session
authorization: Authorization header (Bearer token)
token: Query parameter token (fallback for EventSource)
Returns:
User: The authenticated user
Raises:
HTTPException: If authentication fails
"""
# Try Authorization header first (preferred)
auth_token = None
if authorization:
scheme, param = get_authorization_scheme_param(authorization)
if scheme.lower() == "bearer" and param:
auth_token = param
# Fall back to query parameter if no header token
if not auth_token and token:
auth_token = token
if not auth_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
try:
# Decode token and get user ID
token_data = get_token_data(auth_token)
# Get user from database
result = await db.execute(select(User).where(User.id == token_data.user_id))
user = result.scalar_one_or_none()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return user
except TokenExpiredError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token expired",
headers={"WWW-Authenticate": "Bearer"},
)
except TokenInvalidError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)

View File

@@ -1,36 +0,0 @@
"""
Event bus dependency for FastAPI routes.
This module provides the FastAPI dependency for injecting the EventBus
into route handlers. The event bus is a singleton that maintains
Redis pub/sub connections for real-time event streaming.
"""
from app.services.event_bus import (
EventBus,
get_connected_event_bus as _get_connected_event_bus,
)
async def get_event_bus() -> EventBus:
"""
FastAPI dependency that provides a connected EventBus instance.
The EventBus is a singleton that maintains Redis pub/sub connections.
It's lazily initialized and connected on first access, and should be
closed during application shutdown via close_event_bus().
Usage:
@router.get("/events/stream")
async def stream_events(
event_bus: EventBus = Depends(get_event_bus)
):
...
Returns:
EventBus: The global connected event bus instance
Raises:
EventBusConnectionError: If connection to Redis fails
"""
return await _get_connected_event_bus()

View File

@@ -15,9 +15,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.core.database import get_db from app.core.database import get_db
from app.crud.organization import organization as organization_crud
from app.models.user import User from app.models.user import User
from app.models.user_organization import OrganizationRole from app.models.user_organization import OrganizationRole
from app.services.organization_service import organization_service
def require_superuser(current_user: User = Depends(get_current_user)) -> User: def require_superuser(current_user: User = Depends(get_current_user)) -> User:
@@ -81,7 +81,7 @@ class OrganizationPermission:
return current_user return current_user
# Get user's role in organization # Get user's role in organization
user_role = await organization_crud.get_user_role_in_org( user_role = await organization_service.get_user_role_in_org(
db, user_id=current_user.id, organization_id=organization_id db, user_id=current_user.id, organization_id=organization_id
) )
@@ -123,7 +123,7 @@ async def require_org_membership(
if current_user.is_superuser: if current_user.is_superuser:
return current_user return current_user
user_role = await organization_crud.get_user_role_in_org( user_role = await organization_service.get_user_role_in_org(
db, user_id=current_user.id, organization_id=organization_id db, user_id=current_user.id, organization_id=organization_id
) )

View File

@@ -0,0 +1,41 @@
# app/api/dependencies/services.py
"""FastAPI dependency functions for service singletons."""
from app.services import oauth_provider_service
from app.services.auth_service import AuthService
from app.services.oauth_service import OAuthService
from app.services.organization_service import OrganizationService, organization_service
from app.services.session_service import SessionService, session_service
from app.services.user_service import UserService, user_service
def get_auth_service() -> AuthService:
"""Return the AuthService singleton for dependency injection."""
from app.services.auth_service import AuthService as _AuthService
return _AuthService()
def get_user_service() -> UserService:
"""Return the UserService singleton for dependency injection."""
return user_service
def get_organization_service() -> OrganizationService:
"""Return the OrganizationService singleton for dependency injection."""
return organization_service
def get_session_service() -> SessionService:
"""Return the SessionService singleton for dependency injection."""
return session_service
def get_oauth_service() -> OAuthService:
"""Return OAuthService for dependency injection."""
return OAuthService()
def get_oauth_provider_service():
"""Return the oauth_provider_service module for dependency injection."""
return oauth_provider_service

View File

@@ -2,19 +2,11 @@ from fastapi import APIRouter
from app.api.routes import ( from app.api.routes import (
admin, admin,
agent_types,
agents,
auth, auth,
context,
events,
issues,
mcp,
oauth, oauth,
oauth_provider, oauth_provider,
organizations, organizations,
projects,
sessions, sessions,
sprints,
users, users,
) )
@@ -30,25 +22,3 @@ api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
api_router.include_router( api_router.include_router(
organizations.router, prefix="/organizations", tags=["Organizations"] organizations.router, prefix="/organizations", tags=["Organizations"]
) )
# SSE events router - no prefix, routes define full paths
api_router.include_router(events.router, tags=["Events"])
# MCP (Model Context Protocol) router
api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"])
# Context Management Engine router
api_router.include_router(context.router, prefix="/context", tags=["Context"])
# Syndarix domain routers
api_router.include_router(projects.router, prefix="/projects", tags=["Projects"])
api_router.include_router(
agent_types.router, prefix="/agent-types", tags=["Agent Types"]
)
# Issues router - routes include /projects/{project_id}/issues paths
api_router.include_router(issues.router, tags=["Issues"])
# Agents router - routes include /projects/{project_id}/agents paths
api_router.include_router(agents.router, tags=["Agents"])
# Sprints router - routes need prefix as they use /projects/{project_id}/sprints paths
api_router.include_router(
sprints.router, prefix="/projects/{project_id}/sprints", tags=["Sprints"]
)

View File

@@ -14,7 +14,6 @@ from uuid import UUID
from fastapi import APIRouter, Depends, Query, status from fastapi import APIRouter, Depends, Query, status
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.permissions import require_superuser from app.api.dependencies.permissions import require_superuser
@@ -25,12 +24,9 @@ from app.core.exceptions import (
ErrorCode, ErrorCode,
NotFoundError, NotFoundError,
) )
from app.crud.organization import organization as organization_crud from app.core.repository_exceptions import DuplicateEntryError
from app.crud.session import session as session_crud
from app.crud.user import user as user_crud
from app.models.organization import Organization
from app.models.user import User from app.models.user import User
from app.models.user_organization import OrganizationRole, UserOrganization from app.models.user_organization import OrganizationRole
from app.schemas.common import ( from app.schemas.common import (
MessageResponse, MessageResponse,
PaginatedResponse, PaginatedResponse,
@@ -46,6 +42,9 @@ from app.schemas.organizations import (
) )
from app.schemas.sessions import AdminSessionResponse from app.schemas.sessions import AdminSessionResponse
from app.schemas.users import UserCreate, UserResponse, UserUpdate from app.schemas.users import UserCreate, UserResponse, UserUpdate
from app.services.organization_service import organization_service
from app.services.session_service import session_service
from app.services.user_service import user_service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -66,7 +65,7 @@ class BulkUserAction(BaseModel):
action: BulkAction = Field(..., description="Action to perform on selected users") action: BulkAction = Field(..., description="Action to perform on selected users")
user_ids: list[UUID] = Field( user_ids: list[UUID] = Field(
..., min_items=1, max_items=100, description="List of user IDs (max 100)" ..., min_length=1, max_length=100, description="List of user IDs (max 100)"
) )
@@ -178,38 +177,29 @@ async def admin_get_stats(
"""Get admin dashboard statistics with real data from database.""" """Get admin dashboard statistics with real data from database."""
from app.core.config import settings from app.core.config import settings
# Check if we have any data stats = await user_service.get_stats(db)
total_users_query = select(func.count()).select_from(User) total_users = stats["total_users"]
total_users = (await db.execute(total_users_query)).scalar() or 0 active_count = stats["active_count"]
inactive_count = stats["inactive_count"]
all_users = stats["all_users"]
# If database is essentially empty (only admin user), return demo data # If database is essentially empty (only admin user), return demo data
if total_users <= 1 and settings.DEMO_MODE: # pragma: no cover if total_users <= 1 and settings.DEMO_MODE: # pragma: no cover
logger.info("Returning demo stats data (empty database in demo mode)") logger.info("Returning demo stats data (empty database in demo mode)")
return _generate_demo_stats() return _generate_demo_stats()
# 1. User Growth (Last 30 days) - Improved calculation # 1. User Growth (Last 30 days)
datetime.now(UTC) - timedelta(days=30)
# Get all users with their creation dates
all_users_query = select(User).order_by(User.created_at)
result = await db.execute(all_users_query)
all_users = result.scalars().all()
# Build cumulative counts per day
user_growth = [] user_growth = []
for i in range(29, -1, -1): for i in range(29, -1, -1):
date = datetime.now(UTC) - timedelta(days=i) date = datetime.now(UTC) - timedelta(days=i)
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC) date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
date_end = date_start + timedelta(days=1) date_end = date_start + timedelta(days=1)
# Count all users created before end of this day
# Make comparison timezone-aware
total_users_on_date = sum( total_users_on_date = sum(
1 1
for u in all_users for u in all_users
if u.created_at and u.created_at.replace(tzinfo=UTC) < date_end if u.created_at and u.created_at.replace(tzinfo=UTC) < date_end
) )
# Count active users created before end of this day
active_users_on_date = sum( active_users_on_date = sum(
1 1
for u in all_users for u in all_users
@@ -227,27 +217,16 @@ async def admin_get_stats(
) )
# 2. Organization Distribution - Top 6 organizations by member count # 2. Organization Distribution - Top 6 organizations by member count
org_query = ( org_rows = await organization_service.get_org_distribution(db, limit=6)
select(Organization.name, func.count(UserOrganization.user_id).label("count")) org_dist = [OrgDistributionData(name=r["name"], value=r["value"]) for r in org_rows]
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.name)
.order_by(func.count(UserOrganization.user_id).desc())
.limit(6)
)
result = await db.execute(org_query)
org_dist = [
OrgDistributionData(name=row.name, value=row.count) for row in result.all()
]
# 3. User Registration Activity (Last 14 days) - NEW # 3. User Registration Activity (Last 14 days)
registration_activity = [] registration_activity = []
for i in range(13, -1, -1): for i in range(13, -1, -1):
date = datetime.now(UTC) - timedelta(days=i) date = datetime.now(UTC) - timedelta(days=i)
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC) date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
date_end = date_start + timedelta(days=1) date_end = date_start + timedelta(days=1)
# Count users created on this specific day
# Make comparison timezone-aware
day_registrations = sum( day_registrations = sum(
1 1
for u in all_users for u in all_users
@@ -263,16 +242,8 @@ async def admin_get_stats(
) )
# 4. User Status - Active vs Inactive # 4. User Status - Active vs Inactive
active_query = select(func.count()).select_from(User).where(User.is_active)
inactive_query = (
select(func.count()).select_from(User).where(User.is_active.is_(False))
)
active_count = (await db.execute(active_query)).scalar() or 0
inactive_count = (await db.execute(inactive_query)).scalar() or 0
logger.info( logger.info(
f"User status counts - Active: {active_count}, Inactive: {inactive_count}" "User status counts - Active: %s, Inactive: %s", active_count, inactive_count
) )
user_status = [ user_status = [
@@ -321,7 +292,7 @@ async def admin_list_users(
filters["is_superuser"] = is_superuser filters["is_superuser"] = is_superuser
# Get users with search # Get users with search
users, total = await user_crud.get_multi_with_total( users, total = await user_service.list_users(
db, db,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
@@ -341,7 +312,7 @@ async def admin_list_users(
return PaginatedResponse(data=users, pagination=pagination_meta) return PaginatedResponse(data=users, pagination=pagination_meta)
except Exception as e: except Exception as e:
logger.error(f"Error listing users (admin): {e!s}", exc_info=True) logger.exception("Error listing users (admin): %s", e)
raise raise
@@ -364,14 +335,14 @@ async def admin_create_user(
Allows setting is_superuser and other fields. Allows setting is_superuser and other fields.
""" """
try: try:
user = await user_crud.create(db, obj_in=user_in) user = await user_service.create_user(db, user_in)
logger.info(f"Admin {admin.email} created user {user.email}") logger.info("Admin %s created user %s", admin.email, user.email)
return user return user
except ValueError as e: except DuplicateEntryError as e:
logger.warning(f"Failed to create user: {e!s}") logger.warning("Failed to create user: %s", e)
raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS) raise DuplicateError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
except Exception as e: except Exception as e:
logger.error(f"Error creating user (admin): {e!s}", exc_info=True) logger.exception("Error creating user (admin): %s", e)
raise raise
@@ -388,11 +359,7 @@ async def admin_get_user(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Get detailed information about a specific user.""" """Get detailed information about a specific user."""
user = await user_crud.get(db, id=user_id) user = await user_service.get_user(db, str(user_id))
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
return user return user
@@ -411,20 +378,13 @@ async def admin_update_user(
) -> Any: ) -> Any:
"""Update user information with admin privileges.""" """Update user information with admin privileges."""
try: try:
user = await user_crud.get(db, id=user_id) user = await user_service.get_user(db, str(user_id))
if not user: updated_user = await user_service.update_user(db, user=user, obj_in=user_in)
raise NotFoundError( logger.info("Admin %s updated user %s", admin.email, updated_user.email)
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
logger.info(f"Admin {admin.email} updated user {updated_user.email}")
return updated_user return updated_user
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error updating user (admin): {e!s}", exc_info=True) logger.exception("Error updating user (admin): %s", e)
raise raise
@@ -442,11 +402,7 @@ async def admin_delete_user(
) -> Any: ) -> Any:
"""Soft delete a user (sets deleted_at timestamp).""" """Soft delete a user (sets deleted_at timestamp)."""
try: try:
user = await user_crud.get(db, id=user_id) user = await user_service.get_user(db, str(user_id))
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
# Prevent deleting yourself # Prevent deleting yourself
if user.id == admin.id: if user.id == admin.id:
@@ -456,17 +412,15 @@ async def admin_delete_user(
error_code=ErrorCode.OPERATION_FORBIDDEN, error_code=ErrorCode.OPERATION_FORBIDDEN,
) )
await user_crud.soft_delete(db, id=user_id) await user_service.soft_delete_user(db, str(user_id))
logger.info(f"Admin {admin.email} deleted user {user.email}") logger.info("Admin %s deleted user %s", admin.email, user.email)
return MessageResponse( return MessageResponse(
success=True, message=f"User {user.email} has been deleted" success=True, message=f"User {user.email} has been deleted"
) )
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error deleting user (admin): {e!s}", exc_info=True) logger.exception("Error deleting user (admin): %s", e)
raise raise
@@ -484,23 +438,16 @@ async def admin_activate_user(
) -> Any: ) -> Any:
"""Activate a user account.""" """Activate a user account."""
try: try:
user = await user_crud.get(db, id=user_id) user = await user_service.get_user(db, str(user_id))
if not user: await user_service.update_user(db, user=user, obj_in={"is_active": True})
raise NotFoundError( logger.info("Admin %s activated user %s", admin.email, user.email)
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
logger.info(f"Admin {admin.email} activated user {user.email}")
return MessageResponse( return MessageResponse(
success=True, message=f"User {user.email} has been activated" success=True, message=f"User {user.email} has been activated"
) )
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error activating user (admin): {e!s}", exc_info=True) logger.exception("Error activating user (admin): %s", e)
raise raise
@@ -518,11 +465,7 @@ async def admin_deactivate_user(
) -> Any: ) -> Any:
"""Deactivate a user account.""" """Deactivate a user account."""
try: try:
user = await user_crud.get(db, id=user_id) user = await user_service.get_user(db, str(user_id))
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
# Prevent deactivating yourself # Prevent deactivating yourself
if user.id == admin.id: if user.id == admin.id:
@@ -532,17 +475,15 @@ async def admin_deactivate_user(
error_code=ErrorCode.OPERATION_FORBIDDEN, error_code=ErrorCode.OPERATION_FORBIDDEN,
) )
await user_crud.update(db, db_obj=user, obj_in={"is_active": False}) await user_service.update_user(db, user=user, obj_in={"is_active": False})
logger.info(f"Admin {admin.email} deactivated user {user.email}") logger.info("Admin %s deactivated user %s", admin.email, user.email)
return MessageResponse( return MessageResponse(
success=True, message=f"User {user.email} has been deactivated" success=True, message=f"User {user.email} has been deactivated"
) )
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True) logger.exception("Error deactivating user (admin): %s", e)
raise raise
@@ -567,16 +508,16 @@ async def admin_bulk_user_action(
try: try:
# Use efficient bulk operations instead of loop # Use efficient bulk operations instead of loop
if bulk_action.action == BulkAction.ACTIVATE: if bulk_action.action == BulkAction.ACTIVATE:
affected_count = await user_crud.bulk_update_status( affected_count = await user_service.bulk_update_status(
db, user_ids=bulk_action.user_ids, is_active=True db, user_ids=bulk_action.user_ids, is_active=True
) )
elif bulk_action.action == BulkAction.DEACTIVATE: elif bulk_action.action == BulkAction.DEACTIVATE:
affected_count = await user_crud.bulk_update_status( affected_count = await user_service.bulk_update_status(
db, user_ids=bulk_action.user_ids, is_active=False db, user_ids=bulk_action.user_ids, is_active=False
) )
elif bulk_action.action == BulkAction.DELETE: elif bulk_action.action == BulkAction.DELETE:
# bulk_soft_delete automatically excludes the admin user # bulk_soft_delete automatically excludes the admin user
affected_count = await user_crud.bulk_soft_delete( affected_count = await user_service.bulk_soft_delete(
db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
) )
else: # pragma: no cover else: # pragma: no cover
@@ -587,8 +528,11 @@ async def admin_bulk_user_action(
failed_count = requested_count - affected_count failed_count = requested_count - affected_count
logger.info( logger.info(
f"Admin {admin.email} performed bulk {bulk_action.action.value} " "Admin %s performed bulk %s on %s users (%s skipped/failed)",
f"on {affected_count} users ({failed_count} skipped/failed)" admin.email,
bulk_action.action.value,
affected_count,
failed_count,
) )
return BulkActionResult( return BulkActionResult(
@@ -600,7 +544,7 @@ async def admin_bulk_user_action(
) )
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
logger.error(f"Error in bulk user action: {e!s}", exc_info=True) logger.exception("Error in bulk user action: %s", e)
raise raise
@@ -624,7 +568,7 @@ async def admin_list_organizations(
"""List all organizations with filtering and search.""" """List all organizations with filtering and search."""
try: try:
# Use optimized method that gets member counts in single query (no N+1) # Use optimized method that gets member counts in single query (no N+1)
orgs_with_data, total = await organization_crud.get_multi_with_member_counts( orgs_with_data, total = await organization_service.get_multi_with_member_counts(
db, db,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
@@ -661,7 +605,7 @@ async def admin_list_organizations(
return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta) return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta)
except Exception as e: except Exception as e:
logger.error(f"Error listing organizations (admin): {e!s}", exc_info=True) logger.exception("Error listing organizations (admin): %s", e)
raise raise
@@ -680,8 +624,8 @@ async def admin_create_organization(
) -> Any: ) -> Any:
"""Create a new organization.""" """Create a new organization."""
try: try:
org = await organization_crud.create(db, obj_in=org_in) org = await organization_service.create_organization(db, obj_in=org_in)
logger.info(f"Admin {admin.email} created organization {org.name}") logger.info("Admin %s created organization %s", admin.email, org.name)
# Add member count # Add member count
org_dict = { org_dict = {
@@ -697,11 +641,11 @@ async def admin_create_organization(
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
except ValueError as e: except DuplicateEntryError as e:
logger.warning(f"Failed to create organization: {e!s}") logger.warning("Failed to create organization: %s", e)
raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS) raise DuplicateError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
except Exception as e: except Exception as e:
logger.error(f"Error creating organization (admin): {e!s}", exc_info=True) logger.exception("Error creating organization (admin): %s", e)
raise raise
@@ -718,12 +662,7 @@ async def admin_get_organization(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Get detailed information about a specific organization.""" """Get detailed information about a specific organization."""
org = await organization_crud.get(db, id=org_id) org = await organization_service.get_organization(db, str(org_id))
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
)
org_dict = { org_dict = {
"id": org.id, "id": org.id,
"name": org.name, "name": org.name,
@@ -733,7 +672,7 @@ async def admin_get_organization(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count( "member_count": await organization_service.get_member_count(
db, organization_id=org.id db, organization_id=org.id
), ),
} }
@@ -755,15 +694,11 @@ async def admin_update_organization(
) -> Any: ) -> Any:
"""Update organization information.""" """Update organization information."""
try: try:
org = await organization_crud.get(db, id=org_id) org = await organization_service.get_organization(db, str(org_id))
if not org: updated_org = await organization_service.update_organization(
raise NotFoundError( db, org=org, obj_in=org_in
message=f"Organization {org_id} not found", )
error_code=ErrorCode.NOT_FOUND, logger.info("Admin %s updated organization %s", admin.email, updated_org.name)
)
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
org_dict = { org_dict = {
"id": updated_org.id, "id": updated_org.id,
@@ -774,16 +709,14 @@ async def admin_update_organization(
"settings": updated_org.settings, "settings": updated_org.settings,
"created_at": updated_org.created_at, "created_at": updated_org.created_at,
"updated_at": updated_org.updated_at, "updated_at": updated_org.updated_at,
"member_count": await organization_crud.get_member_count( "member_count": await organization_service.get_member_count(
db, organization_id=updated_org.id db, organization_id=updated_org.id
), ),
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error updating organization (admin): {e!s}", exc_info=True) logger.exception("Error updating organization (admin): %s", e)
raise raise
@@ -801,24 +734,16 @@ async def admin_delete_organization(
) -> Any: ) -> Any:
"""Delete an organization and all its relationships.""" """Delete an organization and all its relationships."""
try: try:
org = await organization_crud.get(db, id=org_id) org = await organization_service.get_organization(db, str(org_id))
if not org: await organization_service.remove_organization(db, str(org_id))
raise NotFoundError( logger.info("Admin %s deleted organization %s", admin.email, org.name)
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
await organization_crud.remove(db, id=org_id)
logger.info(f"Admin {admin.email} deleted organization {org.name}")
return MessageResponse( return MessageResponse(
success=True, message=f"Organization {org.name} has been deleted" success=True, message=f"Organization {org.name} has been deleted"
) )
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True) logger.exception("Error deleting organization (admin): %s", e)
raise raise
@@ -838,14 +763,8 @@ async def admin_list_organization_members(
) -> Any: ) -> Any:
"""List all members of an organization.""" """List all members of an organization."""
try: try:
org = await organization_crud.get(db, id=org_id) await organization_service.get_organization(db, str(org_id)) # validates exists
if not org: members, total = await organization_service.get_organization_members(
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
members, total = await organization_crud.get_organization_members(
db, db,
organization_id=org_id, organization_id=org_id,
skip=pagination.offset, skip=pagination.offset,
@@ -868,9 +787,7 @@ async def admin_list_organization_members(
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: except Exception as e:
logger.error( logger.exception("Error listing organization members (admin): %s", e)
f"Error listing organization members (admin): {e!s}", exc_info=True
)
raise raise
@@ -898,45 +815,32 @@ async def admin_add_organization_member(
) -> Any: ) -> Any:
"""Add a user to an organization.""" """Add a user to an organization."""
try: try:
org = await organization_crud.get(db, id=org_id) org = await organization_service.get_organization(db, str(org_id))
if not org: user = await user_service.get_user(db, str(request.user_id))
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
user = await user_crud.get(db, id=request.user_id) await organization_service.add_member(
if not user:
raise NotFoundError(
message=f"User {request.user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND,
)
await organization_crud.add_user(
db, organization_id=org_id, user_id=request.user_id, role=request.role db, organization_id=org_id, user_id=request.user_id, role=request.role
) )
logger.info( logger.info(
f"Admin {admin.email} added user {user.email} to organization {org.name} " "Admin %s added user %s to organization %s with role %s",
f"with role {request.role.value}" admin.email,
user.email,
org.name,
request.role.value,
) )
return MessageResponse( return MessageResponse(
success=True, message=f"User {user.email} added to organization {org.name}" success=True, message=f"User {user.email} added to organization {org.name}"
) )
except ValueError as e: except DuplicateEntryError as e:
logger.warning(f"Failed to add user to organization: {e!s}") logger.warning("Failed to add user to organization: %s", e)
# Use DuplicateError for "already exists" scenarios
raise DuplicateError( raise DuplicateError(
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id" message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
) )
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error( logger.exception("Error adding member to organization (admin): %s", e)
f"Error adding member to organization (admin): {e!s}", exc_info=True
)
raise raise
@@ -955,20 +859,10 @@ async def admin_remove_organization_member(
) -> Any: ) -> Any:
"""Remove a user from an organization.""" """Remove a user from an organization."""
try: try:
org = await organization_crud.get(db, id=org_id) org = await organization_service.get_organization(db, str(org_id))
if not org: user = await user_service.get_user(db, str(user_id))
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
user = await user_crud.get(db, id=user_id) success = await organization_service.remove_member(
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
success = await organization_crud.remove_user(
db, organization_id=org_id, user_id=user_id db, organization_id=org_id, user_id=user_id
) )
@@ -979,7 +873,10 @@ async def admin_remove_organization_member(
) )
logger.info( logger.info(
f"Admin {admin.email} removed user {user.email} from organization {org.name}" "Admin %s removed user %s from organization %s",
admin.email,
user.email,
org.name,
) )
return MessageResponse( return MessageResponse(
@@ -990,9 +887,7 @@ async def admin_remove_organization_member(
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
logger.error( logger.exception("Error removing member from organization (admin): %s", e)
f"Error removing member from organization (admin): {e!s}", exc_info=True
)
raise raise
@@ -1022,7 +917,7 @@ async def admin_list_sessions(
"""List all sessions across all users with filtering and pagination.""" """List all sessions across all users with filtering and pagination."""
try: try:
# Get sessions with user info (eager loaded to prevent N+1) # Get sessions with user info (eager loaded to prevent N+1)
sessions, total = await session_crud.get_all_sessions( sessions, total = await session_service.get_all_sessions(
db, db,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
@@ -1061,7 +956,10 @@ async def admin_list_sessions(
session_responses.append(session_response) session_responses.append(session_response)
logger.info( logger.info(
f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})" "Admin %s listed %s sessions (total: %s)",
admin.email,
len(session_responses),
total,
) )
pagination_meta = create_pagination_meta( pagination_meta = create_pagination_meta(
@@ -1074,5 +972,5 @@ async def admin_list_sessions(
return PaginatedResponse(data=session_responses, pagination=pagination_meta) return PaginatedResponse(data=session_responses, pagination=pagination_meta)
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
logger.error(f"Error listing sessions (admin): {e!s}", exc_info=True) logger.exception("Error listing sessions (admin): %s", e)
raise raise

View File

@@ -1,462 +0,0 @@
# app/api/routes/agent_types.py
"""
AgentType configuration API endpoints.
Provides CRUD operations for managing AI agent type templates.
Agent types define the base configuration (model, personality, expertise)
from which agent instances are spawned for projects.
Authorization:
- Read endpoints: Any authenticated user
- Write endpoints (create, update, delete): Superusers only
"""
import logging
import os
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, Query, Request, status
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.api.dependencies.permissions import require_superuser
from app.core.database import get_db
from app.core.exceptions import (
DuplicateError,
ErrorCode,
NotFoundError,
)
from app.crud.syndarix.agent_type import agent_type as agent_type_crud
from app.models.user import User
from app.schemas.common import (
MessageResponse,
PaginatedResponse,
PaginationParams,
create_pagination_meta,
)
from app.schemas.syndarix import (
AgentTypeCreate,
AgentTypeResponse,
AgentTypeUpdate,
)
router = APIRouter()
logger = logging.getLogger(__name__)
# Initialize limiter for this router
limiter = Limiter(key_func=get_remote_address)
# Use higher rate limits in test environment
IS_TEST = os.getenv("IS_TEST", "False") == "True"
RATE_MULTIPLIER = 100 if IS_TEST else 1
def _build_agent_type_response(
agent_type: Any,
instance_count: int = 0,
) -> AgentTypeResponse:
"""
Build an AgentTypeResponse from a database model.
Args:
agent_type: AgentType model instance
instance_count: Number of agent instances for this type
Returns:
AgentTypeResponse schema
"""
return AgentTypeResponse(
id=agent_type.id,
name=agent_type.name,
slug=agent_type.slug,
description=agent_type.description,
expertise=agent_type.expertise,
personality_prompt=agent_type.personality_prompt,
primary_model=agent_type.primary_model,
fallback_models=agent_type.fallback_models,
model_params=agent_type.model_params,
mcp_servers=agent_type.mcp_servers,
tool_permissions=agent_type.tool_permissions,
is_active=agent_type.is_active,
created_at=agent_type.created_at,
updated_at=agent_type.updated_at,
instance_count=instance_count,
)
# ===== Write Endpoints (Admin Only) =====
@router.post(
"",
response_model=AgentTypeResponse,
status_code=status.HTTP_201_CREATED,
summary="Create Agent Type",
description="Create a new agent type configuration (admin only)",
operation_id="create_agent_type",
)
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
async def create_agent_type(
request: Request,
agent_type_in: AgentTypeCreate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Create a new agent type configuration.
Agent types define templates for AI agents including:
- Model configuration (primary model, fallback models, parameters)
- Personality and expertise areas
- MCP server integrations and tool permissions
Requires superuser privileges.
Args:
request: FastAPI request object
agent_type_in: Agent type creation data
admin: Authenticated superuser
db: Database session
Returns:
The created agent type configuration
Raises:
DuplicateError: If slug already exists
"""
try:
agent_type = await agent_type_crud.create(db, obj_in=agent_type_in)
logger.info(
f"Admin {admin.email} created agent type: {agent_type.name} "
f"(slug: {agent_type.slug})"
)
return _build_agent_type_response(agent_type, instance_count=0)
except ValueError as e:
logger.warning(f"Failed to create agent type: {e!s}")
raise DuplicateError(
message=str(e),
error_code=ErrorCode.ALREADY_EXISTS,
field="slug",
)
except Exception as e:
logger.error(f"Error creating agent type: {e!s}", exc_info=True)
raise
@router.patch(
"/{agent_type_id}",
response_model=AgentTypeResponse,
summary="Update Agent Type",
description="Update an existing agent type configuration (admin only)",
operation_id="update_agent_type",
)
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
async def update_agent_type(
request: Request,
agent_type_id: UUID,
agent_type_in: AgentTypeUpdate,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Update an existing agent type configuration.
Partial updates are supported - only provided fields will be updated.
Requires superuser privileges.
Args:
request: FastAPI request object
agent_type_id: UUID of the agent type to update
agent_type_in: Agent type update data
admin: Authenticated superuser
db: Database session
Returns:
The updated agent type configuration
Raises:
NotFoundError: If agent type not found
DuplicateError: If new slug already exists
"""
try:
# Verify agent type exists
result = await agent_type_crud.get_with_instance_count(
db, agent_type_id=agent_type_id
)
if not result:
raise NotFoundError(
message=f"Agent type {agent_type_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
existing_type = result["agent_type"]
instance_count = result["instance_count"]
# Perform update
updated_type = await agent_type_crud.update(
db, db_obj=existing_type, obj_in=agent_type_in
)
logger.info(
f"Admin {admin.email} updated agent type: {updated_type.name} "
f"(id: {agent_type_id})"
)
return _build_agent_type_response(updated_type, instance_count=instance_count)
except NotFoundError:
raise
except ValueError as e:
logger.warning(f"Failed to update agent type {agent_type_id}: {e!s}")
raise DuplicateError(
message=str(e),
error_code=ErrorCode.ALREADY_EXISTS,
field="slug",
)
except Exception as e:
logger.error(f"Error updating agent type {agent_type_id}: {e!s}", exc_info=True)
raise
@router.delete(
"/{agent_type_id}",
response_model=MessageResponse,
summary="Deactivate Agent Type",
description="Deactivate an agent type (soft delete, admin only)",
operation_id="deactivate_agent_type",
)
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def deactivate_agent_type(
request: Request,
agent_type_id: UUID,
admin: User = Depends(require_superuser),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Deactivate an agent type (soft delete).
This sets is_active=False rather than deleting the record,
preserving referential integrity with existing agent instances.
Requires superuser privileges.
Args:
request: FastAPI request object
agent_type_id: UUID of the agent type to deactivate
admin: Authenticated superuser
db: Database session
Returns:
Success message
Raises:
NotFoundError: If agent type not found
"""
try:
deactivated = await agent_type_crud.deactivate(db, agent_type_id=agent_type_id)
if not deactivated:
raise NotFoundError(
message=f"Agent type {agent_type_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
logger.info(
f"Admin {admin.email} deactivated agent type: {deactivated.name} "
f"(id: {agent_type_id})"
)
return MessageResponse(
success=True,
message=f"Agent type '{deactivated.name}' has been deactivated",
)
except NotFoundError:
raise
except Exception as e:
logger.error(
f"Error deactivating agent type {agent_type_id}: {e!s}", exc_info=True
)
raise
# ===== Read Endpoints (Authenticated Users) =====
@router.get(
"",
response_model=PaginatedResponse[AgentTypeResponse],
summary="List Agent Types",
description="Get paginated list of active agent types",
operation_id="list_agent_types",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def list_agent_types(
request: Request,
pagination: PaginationParams = Depends(),
is_active: bool = Query(True, description="Filter by active status"),
search: str | None = Query(None, description="Search by name, slug, description"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List all agent types with pagination and filtering.
By default, returns only active agent types. Set is_active=false
to include deactivated types (useful for admin views).
Args:
request: FastAPI request object
pagination: Pagination parameters (page, limit)
is_active: Filter by active status (default: True)
search: Optional search term for name, slug, description
current_user: Authenticated user
db: Database session
Returns:
Paginated list of agent types with instance counts
"""
try:
# Get agent types with instance counts
results, total = await agent_type_crud.get_multi_with_instance_counts(
db,
skip=pagination.offset,
limit=pagination.limit,
is_active=is_active,
search=search,
)
# Build response objects
agent_types_response = [
_build_agent_type_response(
item["agent_type"],
instance_count=item["instance_count"],
)
for item in results
]
pagination_meta = create_pagination_meta(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(agent_types_response),
)
return PaginatedResponse(data=agent_types_response, pagination=pagination_meta)
except Exception as e:
logger.error(f"Error listing agent types: {e!s}", exc_info=True)
raise
@router.get(
"/{agent_type_id}",
response_model=AgentTypeResponse,
summary="Get Agent Type",
description="Get agent type details by ID",
operation_id="get_agent_type",
)
@limiter.limit(f"{100 * RATE_MULTIPLIER}/minute")
async def get_agent_type(
request: Request,
agent_type_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get detailed information about a specific agent type.
Args:
request: FastAPI request object
agent_type_id: UUID of the agent type
current_user: Authenticated user
db: Database session
Returns:
Agent type details with instance count
Raises:
NotFoundError: If agent type not found
"""
try:
result = await agent_type_crud.get_with_instance_count(
db, agent_type_id=agent_type_id
)
if not result:
raise NotFoundError(
message=f"Agent type {agent_type_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
return _build_agent_type_response(
result["agent_type"],
instance_count=result["instance_count"],
)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error getting agent type {agent_type_id}: {e!s}", exc_info=True)
raise
@router.get(
"/slug/{slug}",
response_model=AgentTypeResponse,
summary="Get Agent Type by Slug",
description="Get agent type details by slug",
operation_id="get_agent_type_by_slug",
)
@limiter.limit(f"{100 * RATE_MULTIPLIER}/minute")
async def get_agent_type_by_slug(
request: Request,
slug: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get detailed information about an agent type by its slug.
Slugs are human-readable identifiers like "product-owner" or "backend-engineer".
Useful for referencing agent types in configuration files or APIs.
Args:
request: FastAPI request object
slug: Slug identifier of the agent type
current_user: Authenticated user
db: Database session
Returns:
Agent type details with instance count
Raises:
NotFoundError: If agent type not found
"""
try:
agent_type = await agent_type_crud.get_by_slug(db, slug=slug)
if not agent_type:
raise NotFoundError(
message=f"Agent type with slug '{slug}' not found",
error_code=ErrorCode.NOT_FOUND,
)
# Get instance count separately
result = await agent_type_crud.get_with_instance_count(
db, agent_type_id=agent_type.id
)
instance_count = result["instance_count"] if result else 0
return _build_agent_type_response(agent_type, instance_count=instance_count)
except NotFoundError:
raise
except Exception as e:
logger.error(f"Error getting agent type by slug '{slug}': {e!s}", exc_info=True)
raise

View File

@@ -1,984 +0,0 @@
# app/api/routes/agents.py
"""
Agent Instance management endpoints for Syndarix projects.
These endpoints allow project owners and superusers to manage AI agent instances
within their projects, including spawning, pausing, resuming, and terminating agents.
"""
import logging
import os
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, Query, Request, status
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.database import get_db
from app.core.exceptions import (
AuthorizationError,
NotFoundError,
ValidationException,
)
from app.crud.syndarix.agent_instance import agent_instance as agent_instance_crud
from app.crud.syndarix.agent_type import agent_type as agent_type_crud
from app.crud.syndarix.project import project as project_crud
from app.models.syndarix import AgentInstance, Project
from app.models.syndarix.enums import AgentStatus
from app.models.user import User
from app.schemas.common import (
MessageResponse,
PaginatedResponse,
PaginationParams,
create_pagination_meta,
)
from app.schemas.errors import ErrorCode
from app.schemas.syndarix.agent_instance import (
AgentInstanceCreate,
AgentInstanceMetrics,
AgentInstanceResponse,
AgentInstanceUpdate,
)
router = APIRouter()
logger = logging.getLogger(__name__)
# Initialize limiter for this router
limiter = Limiter(key_func=get_remote_address)
# Use higher rate limits in test environment
IS_TEST = os.getenv("IS_TEST", "False") == "True"
RATE_MULTIPLIER = 100 if IS_TEST else 1
# Valid status transitions for agent lifecycle management
VALID_STATUS_TRANSITIONS: dict[AgentStatus, set[AgentStatus]] = {
AgentStatus.IDLE: {AgentStatus.WORKING, AgentStatus.PAUSED, AgentStatus.TERMINATED},
AgentStatus.WORKING: {
AgentStatus.IDLE,
AgentStatus.WAITING,
AgentStatus.PAUSED,
AgentStatus.TERMINATED,
},
AgentStatus.WAITING: {
AgentStatus.IDLE,
AgentStatus.WORKING,
AgentStatus.PAUSED,
AgentStatus.TERMINATED,
},
AgentStatus.PAUSED: {AgentStatus.IDLE, AgentStatus.TERMINATED},
AgentStatus.TERMINATED: set(), # Terminal state, no transitions allowed
}
async def verify_project_access(
db: AsyncSession,
project_id: UUID,
user: User,
) -> Project:
"""
Verify user has access to a project.
Args:
db: Database session
project_id: UUID of the project to verify
user: Current authenticated user
Returns:
Project: The project if access is granted
Raises:
NotFoundError: If the project does not exist
AuthorizationError: If the user does not have access to the project
"""
project = await project_crud.get(db, id=project_id)
if not project:
raise NotFoundError(
message=f"Project {project_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
if not user.is_superuser and project.owner_id != user.id:
raise AuthorizationError(
message="You do not have access to this project",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
return project
def validate_status_transition(
current_status: AgentStatus,
target_status: AgentStatus,
) -> None:
"""
Validate that a status transition is allowed.
Args:
current_status: The agent's current status
target_status: The desired target status
Raises:
ValidationException: If the transition is not allowed
"""
valid_targets = VALID_STATUS_TRANSITIONS.get(current_status, set())
if target_status not in valid_targets:
raise ValidationException(
message=f"Cannot transition from {current_status.value} to {target_status.value}",
error_code=ErrorCode.VALIDATION_ERROR,
field="status",
)
def build_agent_response(
agent: AgentInstance,
agent_type_name: str | None = None,
agent_type_slug: str | None = None,
project_name: str | None = None,
project_slug: str | None = None,
assigned_issues_count: int = 0,
) -> AgentInstanceResponse:
"""
Build an AgentInstanceResponse from an AgentInstance model.
Args:
agent: The agent instance model
agent_type_name: Name of the agent type
agent_type_slug: Slug of the agent type
project_name: Name of the project
project_slug: Slug of the project
assigned_issues_count: Number of issues assigned to this agent
Returns:
AgentInstanceResponse: The response schema
"""
return AgentInstanceResponse(
id=agent.id,
agent_type_id=agent.agent_type_id,
project_id=agent.project_id,
name=agent.name,
status=agent.status,
current_task=agent.current_task,
short_term_memory=agent.short_term_memory or {},
long_term_memory_ref=agent.long_term_memory_ref,
session_id=agent.session_id,
last_activity_at=agent.last_activity_at,
terminated_at=agent.terminated_at,
tasks_completed=agent.tasks_completed,
tokens_used=agent.tokens_used,
cost_incurred=agent.cost_incurred,
created_at=agent.created_at,
updated_at=agent.updated_at,
agent_type_name=agent_type_name,
agent_type_slug=agent_type_slug,
project_name=project_name,
project_slug=project_slug,
assigned_issues_count=assigned_issues_count,
)
# ===== Agent Instance Management Endpoints =====
@router.post(
"/projects/{project_id}/agents",
response_model=AgentInstanceResponse,
status_code=status.HTTP_201_CREATED,
summary="Spawn Agent Instance",
description="Spawn a new agent instance in a project. Requires project ownership or superuser.",
operation_id="spawn_agent",
)
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
async def spawn_agent(
request: Request,
project_id: UUID,
agent_in: AgentInstanceCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Spawn a new agent instance in a project.
Creates a new agent instance from an agent type template and assigns it
to the specified project. The agent starts in IDLE status by default.
Args:
request: FastAPI request object (for rate limiting)
project_id: UUID of the project to spawn the agent in
agent_in: Agent instance creation data
current_user: Current authenticated user
db: Database session
Returns:
AgentInstanceResponse: The newly created agent instance
Raises:
NotFoundError: If the project is not found
AuthorizationError: If the user lacks access to the project
ValidationException: If the agent creation data is invalid
"""
try:
# Verify project access
project = await verify_project_access(db, project_id, current_user)
# Ensure the agent is being created for the correct project
if agent_in.project_id != project_id:
raise ValidationException(
message="Agent project_id must match the URL project_id",
error_code=ErrorCode.VALIDATION_ERROR,
field="project_id",
)
# Validate that the agent type exists and is active
agent_type = await agent_type_crud.get(db, id=agent_in.agent_type_id)
if not agent_type:
raise NotFoundError(
message=f"Agent type {agent_in.agent_type_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
if not agent_type.is_active:
raise ValidationException(
message=f"Agent type '{agent_type.name}' is inactive and cannot be used",
error_code=ErrorCode.VALIDATION_ERROR,
field="agent_type_id",
)
# Create the agent instance
agent = await agent_instance_crud.create(db, obj_in=agent_in)
logger.info(
f"User {current_user.email} spawned agent '{agent.name}' "
f"(id={agent.id}) in project {project.slug}"
)
# Get agent details for response
details = await agent_instance_crud.get_with_details(db, instance_id=agent.id)
if details:
return build_agent_response(
agent=details["instance"],
agent_type_name=details.get("agent_type_name"),
agent_type_slug=details.get("agent_type_slug"),
project_name=details.get("project_name"),
project_slug=details.get("project_slug"),
assigned_issues_count=details.get("assigned_issues_count", 0),
)
return build_agent_response(agent)
except (NotFoundError, AuthorizationError, ValidationException):
raise
except ValueError as e:
logger.warning(f"Failed to spawn agent: {e!s}")
raise ValidationException(
message=str(e),
error_code=ErrorCode.VALIDATION_ERROR,
)
except Exception as e:
logger.error(f"Error spawning agent: {e!s}", exc_info=True)
raise
@router.get(
"/projects/{project_id}/agents",
response_model=PaginatedResponse[AgentInstanceResponse],
summary="List Project Agents",
description="List all agent instances in a project with optional filtering.",
operation_id="list_project_agents",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def list_project_agents(
request: Request,
project_id: UUID,
pagination: PaginationParams = Depends(),
status_filter: AgentStatus | None = Query(
None, alias="status", description="Filter by agent status"
),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List all agent instances in a project.
Returns a paginated list of agents with optional status filtering.
Results are ordered by creation date (newest first).
Args:
request: FastAPI request object (for rate limiting)
project_id: UUID of the project
pagination: Pagination parameters
status_filter: Optional filter by agent status
current_user: Current authenticated user
db: Database session
Returns:
PaginatedResponse[AgentInstanceResponse]: Paginated list of agents
Raises:
NotFoundError: If the project is not found
AuthorizationError: If the user lacks access to the project
"""
try:
# Verify project access
project = await verify_project_access(db, project_id, current_user)
# Get agents for the project
agents, total = await agent_instance_crud.get_by_project(
db,
project_id=project_id,
status=status_filter,
skip=pagination.offset,
limit=pagination.limit,
)
# Build response objects
agent_responses = []
for agent in agents:
# Get details for each agent (could be optimized with bulk query)
details = await agent_instance_crud.get_with_details(
db, instance_id=agent.id
)
if details:
agent_responses.append(
build_agent_response(
agent=details["instance"],
agent_type_name=details.get("agent_type_name"),
agent_type_slug=details.get("agent_type_slug"),
project_name=details.get("project_name"),
project_slug=details.get("project_slug"),
assigned_issues_count=details.get("assigned_issues_count", 0),
)
)
else:
agent_responses.append(build_agent_response(agent))
pagination_meta = create_pagination_meta(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(agent_responses),
)
logger.debug(
f"User {current_user.email} listed {len(agent_responses)} agents "
f"in project {project.slug}"
)
return PaginatedResponse(data=agent_responses, pagination=pagination_meta)
except (NotFoundError, AuthorizationError):
raise
except Exception as e:
logger.error(f"Error listing project agents: {e!s}", exc_info=True)
raise
# ===== Project Agent Metrics Endpoint =====
# NOTE: This endpoint MUST be defined before /{agent_id} routes
# to prevent FastAPI from trying to parse "metrics" as a UUID
@router.get(
"/projects/{project_id}/agents/metrics",
response_model=AgentInstanceMetrics,
summary="Get Project Agent Metrics",
description="Get aggregated usage metrics for all agents in a project.",
operation_id="get_project_agent_metrics",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def get_project_agent_metrics(
request: Request,
project_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get aggregated usage metrics for all agents in a project.
Returns aggregated metrics across all agents including total
tasks completed, tokens used, and cost incurred.
Args:
request: FastAPI request object (for rate limiting)
project_id: UUID of the project
current_user: Current authenticated user
db: Database session
Returns:
AgentInstanceMetrics: Aggregated project agent metrics
Raises:
NotFoundError: If the project is not found
AuthorizationError: If the user lacks access to the project
"""
try:
# Verify project access
project = await verify_project_access(db, project_id, current_user)
# Get aggregated metrics for the project
metrics = await agent_instance_crud.get_project_metrics(
db, project_id=project_id
)
logger.debug(
f"User {current_user.email} retrieved project metrics for {project.slug}"
)
return AgentInstanceMetrics(
total_instances=metrics["total_instances"],
active_instances=metrics["active_instances"],
idle_instances=metrics["idle_instances"],
total_tasks_completed=metrics["total_tasks_completed"],
total_tokens_used=metrics["total_tokens_used"],
total_cost_incurred=metrics["total_cost_incurred"],
)
except (NotFoundError, AuthorizationError):
raise
except Exception as e:
logger.error(f"Error getting project agent metrics: {e!s}", exc_info=True)
raise
@router.get(
"/projects/{project_id}/agents/{agent_id}",
response_model=AgentInstanceResponse,
summary="Get Agent Details",
description="Get detailed information about a specific agent instance.",
operation_id="get_agent",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def get_agent(
request: Request,
project_id: UUID,
agent_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get detailed information about a specific agent instance.
Returns full agent details including related entity information
(agent type name, project name) and assigned issues count.
Args:
request: FastAPI request object (for rate limiting)
project_id: UUID of the project
agent_id: UUID of the agent instance
current_user: Current authenticated user
db: Database session
Returns:
AgentInstanceResponse: The agent instance details
Raises:
NotFoundError: If the project or agent is not found
AuthorizationError: If the user lacks access to the project
"""
try:
# Verify project access
await verify_project_access(db, project_id, current_user)
# Get agent with full details
details = await agent_instance_crud.get_with_details(db, instance_id=agent_id)
if not details:
raise NotFoundError(
message=f"Agent {agent_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
agent = details["instance"]
# Verify agent belongs to the specified project
if agent.project_id != project_id:
raise NotFoundError(
message=f"Agent {agent_id} not found in project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
logger.debug(
f"User {current_user.email} retrieved agent {agent.name} (id={agent_id})"
)
return build_agent_response(
agent=agent,
agent_type_name=details.get("agent_type_name"),
agent_type_slug=details.get("agent_type_slug"),
project_name=details.get("project_name"),
project_slug=details.get("project_slug"),
assigned_issues_count=details.get("assigned_issues_count", 0),
)
except (NotFoundError, AuthorizationError):
raise
except Exception as e:
logger.error(f"Error getting agent details: {e!s}", exc_info=True)
raise
@router.patch(
"/projects/{project_id}/agents/{agent_id}",
response_model=AgentInstanceResponse,
summary="Update Agent",
description="Update an agent instance's configuration and state.",
operation_id="update_agent",
)
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
async def update_agent(
request: Request,
project_id: UUID,
agent_id: UUID,
agent_in: AgentInstanceUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Update an agent instance's configuration and state.
Allows updating agent status, current task, memory, and other
configurable fields. Status transitions are validated according
to the agent lifecycle state machine.
Args:
request: FastAPI request object (for rate limiting)
project_id: UUID of the project
agent_id: UUID of the agent instance
agent_in: Agent update data
current_user: Current authenticated user
db: Database session
Returns:
AgentInstanceResponse: The updated agent instance
Raises:
NotFoundError: If the project or agent is not found
AuthorizationError: If the user lacks access to the project
ValidationException: If the status transition is invalid
"""
try:
# Verify project access
await verify_project_access(db, project_id, current_user)
# Get current agent
agent = await agent_instance_crud.get(db, id=agent_id)
if not agent:
raise NotFoundError(
message=f"Agent {agent_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
# Verify agent belongs to the specified project
if agent.project_id != project_id:
raise NotFoundError(
message=f"Agent {agent_id} not found in project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
# Validate status transition if status is being changed
if agent_in.status is not None and agent_in.status != agent.status:
validate_status_transition(agent.status, agent_in.status)
# Update the agent
updated_agent = await agent_instance_crud.update(
db, db_obj=agent, obj_in=agent_in
)
logger.info(
f"User {current_user.email} updated agent {updated_agent.name} "
f"(id={agent_id})"
)
# Get updated details
details = await agent_instance_crud.get_with_details(
db, instance_id=updated_agent.id
)
if details:
return build_agent_response(
agent=details["instance"],
agent_type_name=details.get("agent_type_name"),
agent_type_slug=details.get("agent_type_slug"),
project_name=details.get("project_name"),
project_slug=details.get("project_slug"),
assigned_issues_count=details.get("assigned_issues_count", 0),
)
return build_agent_response(updated_agent)
except (NotFoundError, AuthorizationError, ValidationException):
raise
except Exception as e:
logger.error(f"Error updating agent: {e!s}", exc_info=True)
raise
@router.post(
"/projects/{project_id}/agents/{agent_id}/pause",
response_model=AgentInstanceResponse,
summary="Pause Agent",
description="Pause an agent instance, temporarily stopping its work.",
operation_id="pause_agent",
)
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
async def pause_agent(
request: Request,
project_id: UUID,
agent_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Pause an agent instance.
Transitions the agent to PAUSED status, temporarily stopping
its work. The agent can be resumed later with the resume endpoint.
Args:
request: FastAPI request object (for rate limiting)
project_id: UUID of the project
agent_id: UUID of the agent instance
current_user: Current authenticated user
db: Database session
Returns:
AgentInstanceResponse: The paused agent instance
Raises:
NotFoundError: If the project or agent is not found
AuthorizationError: If the user lacks access to the project
ValidationException: If the agent cannot be paused from its current state
"""
try:
# Verify project access
await verify_project_access(db, project_id, current_user)
# Get current agent
agent = await agent_instance_crud.get(db, id=agent_id)
if not agent:
raise NotFoundError(
message=f"Agent {agent_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
# Verify agent belongs to the specified project
if agent.project_id != project_id:
raise NotFoundError(
message=f"Agent {agent_id} not found in project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
# Validate the transition to PAUSED
validate_status_transition(agent.status, AgentStatus.PAUSED)
# Update status to PAUSED
paused_agent = await agent_instance_crud.update_status(
db,
instance_id=agent_id,
status=AgentStatus.PAUSED,
)
if not paused_agent:
raise NotFoundError(
message=f"Agent {agent_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
logger.info(
f"User {current_user.email} paused agent {paused_agent.name} "
f"(id={agent_id})"
)
# Get updated details
details = await agent_instance_crud.get_with_details(
db, instance_id=paused_agent.id
)
if details:
return build_agent_response(
agent=details["instance"],
agent_type_name=details.get("agent_type_name"),
agent_type_slug=details.get("agent_type_slug"),
project_name=details.get("project_name"),
project_slug=details.get("project_slug"),
assigned_issues_count=details.get("assigned_issues_count", 0),
)
return build_agent_response(paused_agent)
except (NotFoundError, AuthorizationError, ValidationException):
raise
except Exception as e:
logger.error(f"Error pausing agent: {e!s}", exc_info=True)
raise
@router.post(
"/projects/{project_id}/agents/{agent_id}/resume",
response_model=AgentInstanceResponse,
summary="Resume Agent",
description="Resume a paused agent instance.",
operation_id="resume_agent",
)
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
async def resume_agent(
request: Request,
project_id: UUID,
agent_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Resume a paused agent instance.
Transitions the agent from PAUSED back to IDLE status,
allowing it to accept new work.
Args:
request: FastAPI request object (for rate limiting)
project_id: UUID of the project
agent_id: UUID of the agent instance
current_user: Current authenticated user
db: Database session
Returns:
AgentInstanceResponse: The resumed agent instance
Raises:
NotFoundError: If the project or agent is not found
AuthorizationError: If the user lacks access to the project
ValidationException: If the agent cannot be resumed from its current state
"""
try:
# Verify project access
await verify_project_access(db, project_id, current_user)
# Get current agent
agent = await agent_instance_crud.get(db, id=agent_id)
if not agent:
raise NotFoundError(
message=f"Agent {agent_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
# Verify agent belongs to the specified project
if agent.project_id != project_id:
raise NotFoundError(
message=f"Agent {agent_id} not found in project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
# Validate the transition to IDLE (resume)
validate_status_transition(agent.status, AgentStatus.IDLE)
# Update status to IDLE
resumed_agent = await agent_instance_crud.update_status(
db,
instance_id=agent_id,
status=AgentStatus.IDLE,
)
if not resumed_agent:
raise NotFoundError(
message=f"Agent {agent_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
logger.info(
f"User {current_user.email} resumed agent {resumed_agent.name} "
f"(id={agent_id})"
)
# Get updated details
details = await agent_instance_crud.get_with_details(
db, instance_id=resumed_agent.id
)
if details:
return build_agent_response(
agent=details["instance"],
agent_type_name=details.get("agent_type_name"),
agent_type_slug=details.get("agent_type_slug"),
project_name=details.get("project_name"),
project_slug=details.get("project_slug"),
assigned_issues_count=details.get("assigned_issues_count", 0),
)
return build_agent_response(resumed_agent)
except (NotFoundError, AuthorizationError, ValidationException):
raise
except Exception as e:
logger.error(f"Error resuming agent: {e!s}", exc_info=True)
raise
@router.delete(
"/projects/{project_id}/agents/{agent_id}",
response_model=MessageResponse,
summary="Terminate Agent",
description="Terminate an agent instance, permanently stopping it.",
operation_id="terminate_agent",
)
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def terminate_agent(
request: Request,
project_id: UUID,
agent_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Terminate an agent instance.
Permanently terminates the agent, setting its status to TERMINATED.
This action cannot be undone - a new agent must be spawned if needed.
The agent's session and current task are cleared.
Args:
request: FastAPI request object (for rate limiting)
project_id: UUID of the project
agent_id: UUID of the agent instance
current_user: Current authenticated user
db: Database session
Returns:
MessageResponse: Confirmation message
Raises:
NotFoundError: If the project or agent is not found
AuthorizationError: If the user lacks access to the project
ValidationException: If the agent is already terminated
"""
try:
# Verify project access
await verify_project_access(db, project_id, current_user)
# Get current agent
agent = await agent_instance_crud.get(db, id=agent_id)
if not agent:
raise NotFoundError(
message=f"Agent {agent_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
# Verify agent belongs to the specified project
if agent.project_id != project_id:
raise NotFoundError(
message=f"Agent {agent_id} not found in project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
# Check if already terminated
if agent.status == AgentStatus.TERMINATED:
raise ValidationException(
message="Agent is already terminated",
error_code=ErrorCode.VALIDATION_ERROR,
field="status",
)
# Validate the transition to TERMINATED
validate_status_transition(agent.status, AgentStatus.TERMINATED)
agent_name = agent.name
# Terminate the agent
terminated_agent = await agent_instance_crud.terminate(db, instance_id=agent_id)
if not terminated_agent:
raise NotFoundError(
message=f"Agent {agent_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
logger.info(
f"User {current_user.email} terminated agent {agent_name} (id={agent_id})"
)
return MessageResponse(
success=True,
message=f"Agent '{agent_name}' has been terminated",
)
except (NotFoundError, AuthorizationError, ValidationException):
raise
except Exception as e:
logger.error(f"Error terminating agent: {e!s}", exc_info=True)
raise
@router.get(
"/projects/{project_id}/agents/{agent_id}/metrics",
response_model=AgentInstanceMetrics,
summary="Get Agent Metrics",
description="Get usage metrics for a specific agent instance.",
operation_id="get_agent_metrics",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def get_agent_metrics(
request: Request,
project_id: UUID,
agent_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get usage metrics for a specific agent instance.
Returns metrics including tasks completed, tokens used,
and cost incurred for the specified agent.
Args:
request: FastAPI request object (for rate limiting)
project_id: UUID of the project
agent_id: UUID of the agent instance
current_user: Current authenticated user
db: Database session
Returns:
AgentInstanceMetrics: Agent usage metrics
Raises:
NotFoundError: If the project or agent is not found
AuthorizationError: If the user lacks access to the project
"""
try:
# Verify project access
await verify_project_access(db, project_id, current_user)
# Get agent
agent = await agent_instance_crud.get(db, id=agent_id)
if not agent:
raise NotFoundError(
message=f"Agent {agent_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
# Verify agent belongs to the specified project
if agent.project_id != project_id:
raise NotFoundError(
message=f"Agent {agent_id} not found in project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
# Calculate metrics for this single agent
# For a single agent, we report its individual metrics
is_active = agent.status == AgentStatus.WORKING
is_idle = agent.status == AgentStatus.IDLE
logger.debug(
f"User {current_user.email} retrieved metrics for agent {agent.name} "
f"(id={agent_id})"
)
return AgentInstanceMetrics(
total_instances=1,
active_instances=1 if is_active else 0,
idle_instances=1 if is_idle else 0,
total_tasks_completed=agent.tasks_completed,
total_tokens_used=agent.tokens_used,
total_cost_incurred=agent.cost_incurred,
)
except (NotFoundError, AuthorizationError):
raise
except Exception as e:
logger.error(f"Error getting agent metrics: {e!s}", exc_info=True)
raise

View File

@@ -15,16 +15,14 @@ from app.core.auth import (
TokenExpiredError, TokenExpiredError,
TokenInvalidError, TokenInvalidError,
decode_token, decode_token,
get_password_hash,
) )
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import ( from app.core.exceptions import (
AuthenticationError as AuthError, AuthenticationError as AuthError,
DatabaseError, DatabaseError,
DuplicateError,
ErrorCode, ErrorCode,
) )
from app.crud.session import session as session_crud
from app.crud.user import user as user_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import MessageResponse from app.schemas.common import MessageResponse
from app.schemas.sessions import LogoutRequest, SessionCreate from app.schemas.sessions import LogoutRequest, SessionCreate
@@ -39,6 +37,8 @@ from app.schemas.users import (
) )
from app.services.auth_service import AuthenticationError, AuthService from app.services.auth_service import AuthenticationError, AuthService
from app.services.email_service import email_service from app.services.email_service import email_service
from app.services.session_service import session_service
from app.services.user_service import user_service
from app.utils.device import extract_device_info from app.utils.device import extract_device_info
from app.utils.security import create_password_reset_token, verify_password_reset_token from app.utils.security import create_password_reset_token, verify_password_reset_token
@@ -91,17 +91,18 @@ async def _create_login_session(
location_country=device_info.location_country, location_country=device_info.location_country,
) )
await session_crud.create_session(db, obj_in=session_data) await session_service.create_session(db, obj_in=session_data)
logger.info( logger.info(
f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} " "%s successful: %s from %s (IP: %s)",
f"(IP: {device_info.ip_address})" login_type.capitalize(),
user.email,
device_info.device_name,
device_info.ip_address,
) )
except Exception as session_err: except Exception as session_err:
# Log but don't fail login if session creation fails # Log but don't fail login if session creation fails
logger.error( logger.exception("Failed to create session for %s: %s", user.email, session_err)
f"Failed to create session for {user.email}: {session_err!s}", exc_info=True
)
@router.post( @router.post(
@@ -123,15 +124,21 @@ async def register_user(
try: try:
user = await AuthService.create_user(db, user_data) user = await AuthService.create_user(db, user_data)
return user return user
except AuthenticationError as e: except DuplicateError:
# SECURITY: Don't reveal if email exists - generic error message # SECURITY: Don't reveal if email exists - generic error message
logger.warning(f"Registration failed: {e!s}") logger.warning("Registration failed: duplicate email %s", user_data.email)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Registration failed. Please check your information and try again.",
)
except AuthError as e:
logger.warning("Registration failed: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Registration failed. Please check your information and try again.", detail="Registration failed. Please check your information and try again.",
) )
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during registration: {e!s}", exc_info=True) logger.exception("Unexpected error during registration: %s", e)
raise DatabaseError( raise DatabaseError(
message="An unexpected error occurred. Please try again later.", message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR, error_code=ErrorCode.INTERNAL_ERROR,
@@ -159,7 +166,7 @@ async def login(
# Explicitly check for None result and raise correct exception # Explicitly check for None result and raise correct exception
if user is None: if user is None:
logger.warning(f"Invalid login attempt for: {login_data.email}") logger.warning("Invalid login attempt for: %s", login_data.email)
raise AuthError( raise AuthError(
message="Invalid email or password", message="Invalid email or password",
error_code=ErrorCode.INVALID_CREDENTIALS, error_code=ErrorCode.INVALID_CREDENTIALS,
@@ -175,14 +182,11 @@ async def login(
except AuthenticationError as e: except AuthenticationError as e:
# Handle specific authentication errors like inactive accounts # Handle specific authentication errors like inactive accounts
logger.warning(f"Authentication failed: {e!s}") logger.warning("Authentication failed: %s", e)
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS) raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
except AuthError:
# Re-raise custom auth exceptions without modification
raise
except Exception as e: except Exception as e:
# Handle unexpected errors # Handle unexpected errors
logger.error(f"Unexpected error during login: {e!s}", exc_info=True) logger.exception("Unexpected error during login: %s", e)
raise DatabaseError( raise DatabaseError(
message="An unexpected error occurred. Please try again later.", message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR, error_code=ErrorCode.INTERNAL_ERROR,
@@ -224,13 +228,10 @@ async def login_oauth(
# Return full token response with user data # Return full token response with user data
return tokens return tokens
except AuthenticationError as e: except AuthenticationError as e:
logger.warning(f"OAuth authentication failed: {e!s}") logger.warning("OAuth authentication failed: %s", e)
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS) raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
except AuthError:
# Re-raise custom auth exceptions without modification
raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True) logger.exception("Unexpected error during OAuth login: %s", e)
raise DatabaseError( raise DatabaseError(
message="An unexpected error occurred. Please try again later.", message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR, error_code=ErrorCode.INTERNAL_ERROR,
@@ -259,11 +260,12 @@ async def refresh_token(
) )
# Check if session exists and is active # Check if session exists and is active
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti) session = await session_service.get_active_by_jti(db, jti=refresh_payload.jti)
if not session: if not session:
logger.warning( logger.warning(
f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}" "Refresh token used for inactive or non-existent session: %s",
refresh_payload.jti,
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@@ -279,16 +281,14 @@ async def refresh_token(
# Update session with new refresh token JTI and expiration # Update session with new refresh token JTI and expiration
try: try:
await session_crud.update_refresh_token( await session_service.update_refresh_token(
db, db,
session=session, session=session,
new_jti=new_refresh_payload.jti, new_jti=new_refresh_payload.jti,
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC), new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
) )
except Exception as session_err: except Exception as session_err:
logger.error( logger.exception("Failed to update session %s: %s", session.id, session_err)
f"Failed to update session {session.id}: {session_err!s}", exc_info=True
)
# Continue anyway - tokens are already issued # Continue anyway - tokens are already issued
return tokens return tokens
@@ -311,7 +311,7 @@ async def refresh_token(
# Re-raise HTTP exceptions (like session revoked) # Re-raise HTTP exceptions (like session revoked)
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during token refresh: {e!s}") logger.error("Unexpected error during token refresh: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later.", detail="An unexpected error occurred. Please try again later.",
@@ -347,7 +347,7 @@ async def request_password_reset(
""" """
try: try:
# Look up user by email # Look up user by email
user = await user_crud.get_by_email(db, email=reset_request.email) user = await user_service.get_by_email(db, email=reset_request.email)
# Only send email if user exists and is active # Only send email if user exists and is active
if user and user.is_active: if user and user.is_active:
@@ -358,11 +358,12 @@ async def request_password_reset(
await email_service.send_password_reset_email( await email_service.send_password_reset_email(
to_email=user.email, reset_token=reset_token, user_name=user.first_name to_email=user.email, reset_token=reset_token, user_name=user.first_name
) )
logger.info(f"Password reset requested for {user.email}") logger.info("Password reset requested for %s", user.email)
else: else:
# Log attempt but don't reveal if email exists # Log attempt but don't reveal if email exists
logger.warning( logger.warning(
f"Password reset requested for non-existent or inactive email: {reset_request.email}" "Password reset requested for non-existent or inactive email: %s",
reset_request.email,
) )
# Always return success to prevent email enumeration # Always return success to prevent email enumeration
@@ -371,7 +372,7 @@ async def request_password_reset(
message="If your email is registered, you will receive a password reset link shortly", message="If your email is registered, you will receive a password reset link shortly",
) )
except Exception as e: except Exception as e:
logger.error(f"Error processing password reset request: {e!s}", exc_info=True) logger.exception("Error processing password reset request: %s", e)
# Still return success to prevent information leakage # Still return success to prevent information leakage
return MessageResponse( return MessageResponse(
success=True, success=True,
@@ -412,40 +413,34 @@ async def confirm_password_reset(
detail="Invalid or expired password reset token", detail="Invalid or expired password reset token",
) )
# Look up user # Reset password via service (validates user exists and is active)
user = await user_crud.get_by_email(db, email=email) try:
user = await AuthService.reset_password(
if not user: db, email=email, new_password=reset_confirm.new_password
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
) )
except AuthenticationError as e:
if not user.is_active: err_msg = str(e)
raise HTTPException( if "inactive" in err_msg.lower():
status_code=status.HTTP_400_BAD_REQUEST, raise HTTPException(
detail="User account is inactive", status_code=status.HTTP_400_BAD_REQUEST, detail=err_msg
) )
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=err_msg)
# Update password
user.password_hash = get_password_hash(reset_confirm.new_password)
db.add(user)
await db.commit()
# SECURITY: Invalidate all existing sessions after password reset # SECURITY: Invalidate all existing sessions after password reset
# This prevents stolen sessions from being used after password change # This prevents stolen sessions from being used after password change
from app.crud.session import session as session_crud
try: try:
deactivated_count = await session_crud.deactivate_all_user_sessions( deactivated_count = await session_service.deactivate_all_user_sessions(
db, user_id=str(user.id) db, user_id=str(user.id)
) )
logger.info( logger.info(
f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions" "Password reset successful for %s, invalidated %s sessions",
user.email,
deactivated_count,
) )
except Exception as session_error: except Exception as session_error:
# Log but don't fail password reset if session invalidation fails # Log but don't fail password reset if session invalidation fails
logger.error( logger.error(
f"Failed to invalidate sessions after password reset: {session_error!s}" "Failed to invalidate sessions after password reset: %s", session_error
) )
return MessageResponse( return MessageResponse(
@@ -456,7 +451,7 @@ async def confirm_password_reset(
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error confirming password reset: {e!s}", exc_info=True) logger.exception("Error confirming password reset: %s", e)
await db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -506,19 +501,21 @@ async def logout(
) )
except (TokenExpiredError, TokenInvalidError) as e: except (TokenExpiredError, TokenInvalidError) as e:
# Even if token is expired/invalid, try to deactivate session # Even if token is expired/invalid, try to deactivate session
logger.warning(f"Logout with invalid/expired token: {e!s}") logger.warning("Logout with invalid/expired token: %s", e)
# Don't fail - return success anyway # Don't fail - return success anyway
return MessageResponse(success=True, message="Logged out successfully") return MessageResponse(success=True, message="Logged out successfully")
# Find the session by JTI # Find the session by JTI
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti) session = await session_service.get_by_jti(db, jti=refresh_payload.jti)
if session: if session:
# Verify session belongs to current user (security check) # Verify session belongs to current user (security check)
if str(session.user_id) != str(current_user.id): if str(session.user_id) != str(current_user.id):
logger.warning( logger.warning(
f"User {current_user.id} attempted to logout session {session.id} " "User %s attempted to logout session %s belonging to user %s",
f"belonging to user {session.user_id}" current_user.id,
session.id,
session.user_id,
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@@ -526,17 +523,20 @@ async def logout(
) )
# Deactivate the session # Deactivate the session
await session_crud.deactivate(db, session_id=str(session.id)) await session_service.deactivate(db, session_id=str(session.id))
logger.info( logger.info(
f"User {current_user.id} logged out from {session.device_name} " "User %s logged out from %s (session %s)",
f"(session {session.id})" current_user.id,
session.device_name,
session.id,
) )
else: else:
# Session not found - maybe already deleted or never existed # Session not found - maybe already deleted or never existed
# Return success anyway (idempotent) # Return success anyway (idempotent)
logger.info( logger.info(
f"Logout requested for non-existent session (JTI: {refresh_payload.jti})" "Logout requested for non-existent session (JTI: %s)",
refresh_payload.jti,
) )
return MessageResponse(success=True, message="Logged out successfully") return MessageResponse(success=True, message="Logged out successfully")
@@ -544,9 +544,7 @@ async def logout(
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error( logger.exception("Error during logout for user %s: %s", current_user.id, e)
f"Error during logout for user {current_user.id}: {e!s}", exc_info=True
)
# Don't expose error details # Don't expose error details
return MessageResponse(success=True, message="Logged out successfully") return MessageResponse(success=True, message="Logged out successfully")
@@ -584,12 +582,12 @@ async def logout_all(
""" """
try: try:
# Deactivate all sessions for this user # Deactivate all sessions for this user
count = await session_crud.deactivate_all_user_sessions( count = await session_service.deactivate_all_user_sessions(
db, user_id=str(current_user.id) db, user_id=str(current_user.id)
) )
logger.info( logger.info(
f"User {current_user.id} logged out from all devices ({count} sessions)" "User %s logged out from all devices (%s sessions)", current_user.id, count
) )
return MessageResponse( return MessageResponse(
@@ -598,9 +596,7 @@ async def logout_all(
) )
except Exception as e: except Exception as e:
logger.error( logger.exception("Error during logout-all for user %s: %s", current_user.id, e)
f"Error during logout-all for user {current_user.id}: {e!s}", exc_info=True
)
await db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

View File

@@ -1,411 +0,0 @@
"""
Context Management API Endpoints.
Provides REST endpoints for context assembly and optimization
for LLM requests using the ContextEngine.
"""
import logging
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
from app.api.dependencies.permissions import require_superuser
from app.models.user import User
from app.services.context import (
AssemblyTimeoutError,
BudgetExceededError,
ContextEngine,
ContextSettings,
create_context_engine,
get_context_settings,
)
from app.services.mcp import MCPClientManager, get_mcp_client
logger = logging.getLogger(__name__)
router = APIRouter()
# ============================================================================
# Singleton Engine Management
# ============================================================================
_context_engine: ContextEngine | None = None
def _get_or_create_engine(
mcp: MCPClientManager,
settings: ContextSettings | None = None,
) -> ContextEngine:
"""Get or create the singleton ContextEngine."""
global _context_engine
if _context_engine is None:
_context_engine = create_context_engine(
mcp_manager=mcp,
redis=None, # Optional: add Redis caching later
settings=settings or get_context_settings(),
)
logger.info("ContextEngine initialized")
else:
# Ensure MCP manager is up to date
_context_engine.set_mcp_manager(mcp)
return _context_engine
async def get_context_engine(
mcp: MCPClientManager = Depends(get_mcp_client),
) -> ContextEngine:
"""FastAPI dependency to get the ContextEngine."""
return _get_or_create_engine(mcp)
# ============================================================================
# Request/Response Schemas
# ============================================================================
class ConversationTurn(BaseModel):
"""A single conversation turn."""
role: str = Field(..., description="Role: 'user' or 'assistant'")
content: str = Field(..., description="Message content")
class ToolResult(BaseModel):
"""A tool execution result."""
tool_name: str = Field(..., description="Name of the tool")
content: str | dict[str, Any] = Field(..., description="Tool result content")
status: str = Field(default="success", description="Execution status")
class AssembleContextRequest(BaseModel):
"""Request to assemble context for an LLM request."""
project_id: str = Field(..., description="Project identifier")
agent_id: str = Field(..., description="Agent identifier")
query: str = Field(..., description="User's query or current request")
model: str = Field(
default="claude-3-sonnet",
description="Target model name",
)
max_tokens: int | None = Field(
None,
description="Maximum context tokens (uses model default if None)",
)
system_prompt: str | None = Field(
None,
description="System prompt/instructions",
)
task_description: str | None = Field(
None,
description="Current task description",
)
knowledge_query: str | None = Field(
None,
description="Query for knowledge base search",
)
knowledge_limit: int = Field(
default=10,
ge=1,
le=50,
description="Max number of knowledge results",
)
conversation_history: list[ConversationTurn] | None = Field(
None,
description="Previous conversation turns",
)
tool_results: list[ToolResult] | None = Field(
None,
description="Tool execution results to include",
)
compress: bool = Field(
default=True,
description="Whether to apply compression",
)
use_cache: bool = Field(
default=True,
description="Whether to use caching",
)
class AssembledContextResponse(BaseModel):
"""Response containing assembled context."""
content: str = Field(..., description="Assembled context content")
total_tokens: int = Field(..., description="Total token count")
context_count: int = Field(..., description="Number of context items included")
compressed: bool = Field(..., description="Whether compression was applied")
budget_used_percent: float = Field(
...,
description="Percentage of token budget used",
)
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Additional metadata",
)
class TokenCountRequest(BaseModel):
"""Request to count tokens in content."""
content: str = Field(..., description="Content to count tokens in")
model: str | None = Field(
None,
description="Model for model-specific tokenization",
)
class TokenCountResponse(BaseModel):
"""Response containing token count."""
token_count: int = Field(..., description="Number of tokens")
model: str | None = Field(None, description="Model used for counting")
class BudgetInfoResponse(BaseModel):
"""Response containing budget information for a model."""
model: str = Field(..., description="Model name")
total_tokens: int = Field(..., description="Total token budget")
system_tokens: int = Field(..., description="Tokens reserved for system")
knowledge_tokens: int = Field(..., description="Tokens for knowledge")
conversation_tokens: int = Field(..., description="Tokens for conversation")
tool_tokens: int = Field(..., description="Tokens for tool results")
response_reserve: int = Field(..., description="Tokens reserved for response")
class ContextEngineStatsResponse(BaseModel):
"""Response containing engine statistics."""
cache: dict[str, Any] = Field(..., description="Cache statistics")
settings: dict[str, Any] = Field(..., description="Current settings")
class HealthResponse(BaseModel):
"""Health check response."""
status: str = Field(..., description="Health status")
mcp_connected: bool = Field(..., description="Whether MCP is connected")
cache_enabled: bool = Field(..., description="Whether caching is enabled")
# ============================================================================
# Endpoints
# ============================================================================
@router.get(
"/health",
response_model=HealthResponse,
summary="Context Engine Health",
description="Check health status of the context engine.",
)
async def health_check(
engine: ContextEngine = Depends(get_context_engine),
) -> HealthResponse:
"""Check context engine health."""
stats = await engine.get_stats()
return HealthResponse(
status="healthy",
mcp_connected=engine._mcp is not None,
cache_enabled=stats.get("settings", {}).get("cache_enabled", False),
)
@router.post(
"/assemble",
response_model=AssembledContextResponse,
summary="Assemble Context",
description="Assemble optimized context for an LLM request.",
)
async def assemble_context(
request: AssembleContextRequest,
current_user: User = Depends(require_superuser),
engine: ContextEngine = Depends(get_context_engine),
) -> AssembledContextResponse:
"""
Assemble optimized context for an LLM request.
This endpoint gathers context from various sources, scores and ranks them,
compresses if needed, and formats for the target model.
"""
logger.info(
"Context assembly for project=%s agent=%s by user=%s",
request.project_id,
request.agent_id,
current_user.id,
)
# Convert conversation history to dict format
conversation_history = None
if request.conversation_history:
conversation_history = [
{"role": turn.role, "content": turn.content}
for turn in request.conversation_history
]
# Convert tool results to dict format
tool_results = None
if request.tool_results:
tool_results = [
{
"tool_name": tr.tool_name,
"content": tr.content,
"status": tr.status,
}
for tr in request.tool_results
]
try:
result = await engine.assemble_context(
project_id=request.project_id,
agent_id=request.agent_id,
query=request.query,
model=request.model,
max_tokens=request.max_tokens,
system_prompt=request.system_prompt,
task_description=request.task_description,
knowledge_query=request.knowledge_query,
knowledge_limit=request.knowledge_limit,
conversation_history=conversation_history,
tool_results=tool_results,
compress=request.compress,
use_cache=request.use_cache,
)
# Calculate budget usage percentage
budget = await engine.get_budget_for_model(request.model, request.max_tokens)
budget_used_percent = (result.total_tokens / budget.total) * 100
# Check if compression was applied (from metadata if available)
was_compressed = result.metadata.get("compressed_contexts", 0) > 0
return AssembledContextResponse(
content=result.content,
total_tokens=result.total_tokens,
context_count=result.context_count,
compressed=was_compressed,
budget_used_percent=round(budget_used_percent, 2),
metadata={
"model": request.model,
"query": request.query,
"knowledge_included": bool(request.knowledge_query),
"conversation_turns": len(request.conversation_history or []),
"excluded_count": result.excluded_count,
"assembly_time_ms": result.assembly_time_ms,
},
)
except AssemblyTimeoutError as e:
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
detail=f"Context assembly timed out: {e}",
) from e
except BudgetExceededError as e:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"Token budget exceeded: {e}",
) from e
except Exception as e:
logger.exception("Context assembly failed")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Context assembly failed: {e}",
) from e
@router.post(
"/count-tokens",
response_model=TokenCountResponse,
summary="Count Tokens",
description="Count tokens in content using the LLM Gateway.",
)
async def count_tokens(
request: TokenCountRequest,
engine: ContextEngine = Depends(get_context_engine),
) -> TokenCountResponse:
"""Count tokens in content."""
try:
count = await engine.count_tokens(
content=request.content,
model=request.model,
)
return TokenCountResponse(
token_count=count,
model=request.model,
)
except Exception as e:
logger.warning(f"Token counting failed: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Token counting failed: {e}",
) from e
@router.get(
"/budget/{model}",
response_model=BudgetInfoResponse,
summary="Get Token Budget",
description="Get token budget allocation for a specific model.",
)
async def get_budget(
model: str,
max_tokens: Annotated[int | None, Query(description="Custom max tokens")] = None,
engine: ContextEngine = Depends(get_context_engine),
) -> BudgetInfoResponse:
"""Get token budget information for a model."""
budget = await engine.get_budget_for_model(model, max_tokens)
return BudgetInfoResponse(
model=model,
total_tokens=budget.total,
system_tokens=budget.system,
knowledge_tokens=budget.knowledge,
conversation_tokens=budget.conversation,
tool_tokens=budget.tools,
response_reserve=budget.response_reserve,
)
@router.get(
"/stats",
response_model=ContextEngineStatsResponse,
summary="Engine Statistics",
description="Get context engine statistics and configuration.",
)
async def get_stats(
current_user: User = Depends(require_superuser),
engine: ContextEngine = Depends(get_context_engine),
) -> ContextEngineStatsResponse:
"""Get engine statistics."""
stats = await engine.get_stats()
return ContextEngineStatsResponse(
cache=stats.get("cache", {}),
settings=stats.get("settings", {}),
)
@router.post(
"/cache/invalidate",
status_code=status.HTTP_204_NO_CONTENT,
summary="Invalidate Cache (Admin Only)",
description="Invalidate context cache entries.",
)
async def invalidate_cache(
project_id: Annotated[
str | None, Query(description="Project to invalidate")
] = None,
pattern: Annotated[str | None, Query(description="Pattern to match")] = None,
current_user: User = Depends(require_superuser),
engine: ContextEngine = Depends(get_context_engine),
) -> None:
"""Invalidate cache entries."""
logger.info(
"Cache invalidation by user %s: project=%s pattern=%s",
current_user.id,
project_id,
pattern,
)
await engine.invalidate_cache(project_id=project_id, pattern=pattern)

View File

@@ -1,316 +0,0 @@
"""
SSE endpoint for real-time project event streaming.
This module provides Server-Sent Events (SSE) endpoints for streaming
project events to connected clients. Events are scoped to projects,
with authorization checks to ensure clients only receive events
for projects they have access to.
Features:
- Real-time event streaming via SSE
- Project-scoped authorization
- Automatic reconnection support (Last-Event-ID)
- Keepalive messages every 30 seconds
- Graceful connection cleanup
"""
import asyncio
import json
import logging
from typing import TYPE_CHECKING
from uuid import UUID
from fastapi import APIRouter, Depends, Header, Query, Request
from slowapi import Limiter
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
from app.api.dependencies.auth import get_current_user, get_current_user_sse
from app.api.dependencies.event_bus import get_event_bus
from app.core.database import get_db
from app.core.exceptions import AuthorizationError
from app.models.user import User
from app.schemas.errors import ErrorCode
from app.schemas.events import EventType
from app.services.event_bus import EventBus
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
router = APIRouter()
limiter = Limiter(key_func=get_remote_address)
# Keepalive interval in seconds
KEEPALIVE_INTERVAL = 30
async def check_project_access(
project_id: UUID,
user: User,
db: "AsyncSession",
) -> bool:
"""
Check if a user has access to a project's events.
Authorization rules:
- Superusers can access all projects
- Project owners can access their own projects
Args:
project_id: The project to check access for
user: The authenticated user
db: Database session for project lookup
Returns:
bool: True if user has access, False otherwise
"""
# Superusers can access all projects
if user.is_superuser:
logger.debug(
f"Project access granted for superuser {user.id} on project {project_id}"
)
return True
# Check if user owns the project
from app.crud.syndarix import project as project_crud
project = await project_crud.get(db, id=project_id)
if not project:
logger.debug(f"Project {project_id} not found for access check")
return False
has_access = bool(project.owner_id == user.id)
logger.debug(
f"Project access {'granted' if has_access else 'denied'} "
f"for user {user.id} on project {project_id} (owner: {project.owner_id})"
)
return has_access
async def event_generator(
project_id: UUID,
event_bus: EventBus,
last_event_id: str | None = None,
):
"""
Generate SSE events for a project.
This async generator yields SSE-formatted events from the event bus,
including keepalive comments to maintain the connection.
Args:
project_id: The project to stream events for
event_bus: The EventBus instance
last_event_id: Optional last received event ID for reconnection
Yields:
dict: SSE event data with 'event', 'data', and optional 'id' fields
"""
try:
async for event_data in event_bus.subscribe_sse(
project_id=project_id,
last_event_id=last_event_id,
keepalive_interval=KEEPALIVE_INTERVAL,
):
if event_data == "":
# Keepalive - yield SSE comment
yield {"comment": "keepalive"}
else:
# Parse event to extract type and id
try:
event_dict = json.loads(event_data)
event_type = event_dict.get("type", "message")
event_id = event_dict.get("id")
yield {
"event": event_type,
"data": event_data,
"id": event_id,
}
except json.JSONDecodeError:
# If we can't parse, send as generic message
yield {
"event": "message",
"data": event_data,
}
except asyncio.CancelledError:
logger.info(f"Event stream cancelled for project {project_id}")
raise
except Exception as e:
logger.error(f"Error in event stream for project {project_id}: {e}")
raise
@router.get(
"/projects/{project_id}/events/stream",
summary="Stream Project Events",
description="""
Stream real-time events for a project via Server-Sent Events (SSE).
**Authentication**: Required (Bearer token OR query parameter)
**Authorization**: Must have access to the project
**Authentication Methods**:
- Bearer token in Authorization header (preferred)
- Query parameter `token` (for EventSource compatibility)
Note: EventSource API doesn't support custom headers, so the query parameter
option is provided for browser-based SSE clients.
**SSE Event Format**:
```
event: agent.status_changed
id: 550e8400-e29b-41d4-a716-446655440000
data: {"id": "...", "type": "agent.status_changed", "project_id": "...", ...}
: keepalive
event: issue.created
id: 550e8400-e29b-41d4-a716-446655440001
data: {...}
```
**Reconnection**: Include the `Last-Event-ID` header with the last received
event ID to resume from where you left off.
**Keepalive**: The server sends a comment (`: keepalive`) every 30 seconds
to keep the connection alive.
**Rate Limit**: 10 connections/minute per IP
""",
response_class=EventSourceResponse,
responses={
200: {
"description": "SSE stream established",
"content": {"text/event-stream": {}},
},
401: {"description": "Not authenticated"},
403: {"description": "Not authorized to access this project"},
404: {"description": "Project not found"},
},
operation_id="stream_project_events",
)
@limiter.limit("10/minute")
async def stream_project_events(
request: Request,
project_id: UUID,
db: "AsyncSession" = Depends(get_db),
event_bus: EventBus = Depends(get_event_bus),
token: str | None = Query(
None, description="Auth token (for EventSource compatibility)"
),
authorization: str | None = Header(None, alias="Authorization"),
last_event_id: str | None = Header(None, alias="Last-Event-ID"),
):
"""
Stream real-time events for a project via SSE.
This endpoint establishes a persistent SSE connection that streams
project events to the client in real-time. The connection includes:
- Event streaming: All project events (agent updates, issues, etc.)
- Keepalive: Comment every 30 seconds to maintain connection
- Reconnection: Use Last-Event-ID header to resume after disconnect
The connection is automatically cleaned up when the client disconnects.
"""
# Authenticate user (supports both header and query param tokens)
current_user = await get_current_user_sse(
db=db, authorization=authorization, token=token
)
logger.info(
f"SSE connection request for project {project_id} "
f"by user {current_user.id} "
f"(last_event_id={last_event_id})"
)
# Check project access
has_access = await check_project_access(project_id, current_user, db)
if not has_access:
raise AuthorizationError(
message=f"You don't have access to project {project_id}",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
# Return SSE response
return EventSourceResponse(
event_generator(
project_id=project_id,
event_bus=event_bus,
last_event_id=last_event_id,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
},
)
@router.post(
"/projects/{project_id}/events/test",
summary="Send Test Event (Development Only)",
description="""
Send a test event to a project's event stream. This endpoint is
intended for development and testing purposes.
**Authentication**: Required (Bearer token)
**Authorization**: Must have access to the project
**Note**: This endpoint should be disabled or restricted in production.
""",
response_model=dict,
responses={
200: {"description": "Test event sent"},
401: {"description": "Not authenticated"},
403: {"description": "Not authorized to access this project"},
},
operation_id="send_test_event",
)
async def send_test_event(
project_id: UUID,
current_user: User = Depends(get_current_user),
event_bus: EventBus = Depends(get_event_bus),
db: "AsyncSession" = Depends(get_db),
):
"""
Send a test event to the project's event stream.
This is useful for testing SSE connections during development.
"""
# Check project access
has_access = await check_project_access(project_id, current_user, db)
if not has_access:
raise AuthorizationError(
message=f"You don't have access to project {project_id}",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
# Create and publish test event using the Event schema
event = EventBus.create_event(
event_type=EventType.AGENT_MESSAGE,
project_id=project_id,
actor_type="user",
actor_id=current_user.id,
payload={
"message": "Test event from SSE endpoint",
"message_type": "info",
},
)
channel = event_bus.get_project_channel(project_id)
await event_bus.publish(channel, event)
logger.info(f"Test event sent to project {project_id}: {event.id}")
return {
"success": True,
"event_id": event.id,
"event_type": event.type.value,
"message": "Test event sent successfully",
}

View File

@@ -1,968 +0,0 @@
# app/api/routes/issues.py
"""
Issue CRUD API endpoints for Syndarix projects.
Provides endpoints for managing issues within projects, including:
- Create, read, update, delete operations
- Filtering by status, priority, labels, sprint, assigned agent
- Search across title and body
- Assignment to agents
- External issue tracker sync triggers
"""
import logging
import os
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, Query, Request, status
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.database import get_db
from app.core.exceptions import (
AuthorizationError,
NotFoundError,
ValidationException,
)
from app.crud.syndarix.agent_instance import agent_instance as agent_instance_crud
from app.crud.syndarix.issue import issue as issue_crud
from app.crud.syndarix.project import project as project_crud
from app.crud.syndarix.sprint import sprint as sprint_crud
from app.models.syndarix.enums import (
AgentStatus,
IssuePriority,
IssueStatus,
SprintStatus,
SyncStatus,
)
from app.models.user import User
from app.schemas.common import (
MessageResponse,
PaginatedResponse,
PaginationParams,
SortOrder,
create_pagination_meta,
)
from app.schemas.errors import ErrorCode
from app.schemas.syndarix.issue import (
IssueAssign,
IssueCreate,
IssueResponse,
IssueStats,
IssueUpdate,
)
router = APIRouter()
logger = logging.getLogger(__name__)
# Initialize limiter for this router
limiter = Limiter(key_func=get_remote_address)
# Use higher rate limits in test environment
IS_TEST = os.getenv("IS_TEST", "False") == "True"
RATE_MULTIPLIER = 100 if IS_TEST else 1
async def verify_project_ownership(
db: AsyncSession,
project_id: UUID,
user: User,
) -> None:
"""
Verify that the user owns the project or is a superuser.
Args:
db: Database session
project_id: Project UUID to verify
user: Current authenticated user
Raises:
NotFoundError: If project does not exist
AuthorizationError: If user does not own the project
"""
project = await project_crud.get(db, id=project_id)
if not project:
raise NotFoundError(
message=f"Project {project_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
if not user.is_superuser and project.owner_id != user.id:
raise AuthorizationError(
message="You do not have access to this project",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
def _build_issue_response(
issue: Any,
project_name: str | None = None,
project_slug: str | None = None,
sprint_name: str | None = None,
assigned_agent_type_name: str | None = None,
) -> IssueResponse:
"""
Build an IssueResponse from an Issue model instance.
Args:
issue: Issue model instance
project_name: Optional project name from relationship
project_slug: Optional project slug from relationship
sprint_name: Optional sprint name from relationship
assigned_agent_type_name: Optional agent type name from relationship
Returns:
IssueResponse schema instance
"""
return IssueResponse(
id=issue.id,
project_id=issue.project_id,
title=issue.title,
body=issue.body,
status=issue.status,
priority=issue.priority,
labels=issue.labels or [],
assigned_agent_id=issue.assigned_agent_id,
human_assignee=issue.human_assignee,
sprint_id=issue.sprint_id,
story_points=issue.story_points,
external_tracker_type=issue.external_tracker_type,
external_issue_id=issue.external_issue_id,
remote_url=issue.remote_url,
external_issue_number=issue.external_issue_number,
sync_status=issue.sync_status,
last_synced_at=issue.last_synced_at,
external_updated_at=issue.external_updated_at,
closed_at=issue.closed_at,
created_at=issue.created_at,
updated_at=issue.updated_at,
project_name=project_name,
project_slug=project_slug,
sprint_name=sprint_name,
assigned_agent_type_name=assigned_agent_type_name,
)
# ===== Issue CRUD Endpoints =====
@router.post(
"/projects/{project_id}/issues",
response_model=IssueResponse,
status_code=status.HTTP_201_CREATED,
summary="Create Issue",
description="Create a new issue in a project",
operation_id="create_issue",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def create_issue(
request: Request,
project_id: UUID,
issue_in: IssueCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Create a new issue within a project.
The user must own the project or be a superuser.
The project_id in the path takes precedence over any project_id in the body.
Args:
request: FastAPI request object (for rate limiting)
project_id: UUID of the project to create the issue in
issue_in: Issue creation data
current_user: Authenticated user
db: Database session
Returns:
Created issue with full details
Raises:
NotFoundError: If project not found
AuthorizationError: If user lacks access
ValidationException: If assigned agent not in project
"""
# Verify project access
await verify_project_ownership(db, project_id, current_user)
# Override project_id from path
issue_in.project_id = project_id
# Validate assigned agent if provided
if issue_in.assigned_agent_id:
agent = await agent_instance_crud.get(db, id=issue_in.assigned_agent_id)
if not agent:
raise NotFoundError(
message=f"Agent instance {issue_in.assigned_agent_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
if agent.project_id != project_id:
raise ValidationException(
message="Agent instance does not belong to this project",
error_code=ErrorCode.VALIDATION_ERROR,
field="assigned_agent_id",
)
if agent.status == AgentStatus.TERMINATED:
raise ValidationException(
message="Cannot assign issue to a terminated agent",
error_code=ErrorCode.VALIDATION_ERROR,
field="assigned_agent_id",
)
# Validate sprint if provided (IDOR prevention)
if issue_in.sprint_id:
sprint = await sprint_crud.get(db, id=issue_in.sprint_id)
if not sprint:
raise NotFoundError(
message=f"Sprint {issue_in.sprint_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
if sprint.project_id != project_id:
raise ValidationException(
message="Sprint does not belong to this project",
error_code=ErrorCode.VALIDATION_ERROR,
field="sprint_id",
)
try:
issue = await issue_crud.create(db, obj_in=issue_in)
logger.info(
f"User {current_user.email} created issue '{issue.title}' "
f"in project {project_id}"
)
# Get project details for response
project = await project_crud.get(db, id=project_id)
return _build_issue_response(
issue,
project_name=project.name if project else None,
project_slug=project.slug if project else None,
)
except ValueError as e:
logger.warning(f"Failed to create issue: {e!s}")
raise ValidationException(
message=str(e),
error_code=ErrorCode.VALIDATION_ERROR,
)
except Exception as e:
logger.error(f"Error creating issue: {e!s}", exc_info=True)
raise
@router.get(
"/projects/{project_id}/issues",
response_model=PaginatedResponse[IssueResponse],
summary="List Issues",
description="Get paginated list of issues in a project with filtering",
operation_id="list_issues",
)
@limiter.limit(f"{120 * RATE_MULTIPLIER}/minute")
async def list_issues(
request: Request,
project_id: UUID,
pagination: PaginationParams = Depends(),
status_filter: IssueStatus | None = Query(
None, alias="status", description="Filter by issue status"
),
priority: IssuePriority | None = Query(None, description="Filter by priority"),
labels: list[str] | None = Query(
None, description="Filter by labels (comma-separated)"
),
sprint_id: UUID | None = Query(None, description="Filter by sprint ID"),
assigned_agent_id: UUID | None = Query(
None, description="Filter by assigned agent ID"
),
sync_status: SyncStatus | None = Query(None, description="Filter by sync status"),
search: str | None = Query(
None, min_length=1, max_length=100, description="Search in title and body"
),
sort_by: str = Query(
"created_at",
description="Field to sort by (created_at, updated_at, priority, status, title)",
),
sort_order: SortOrder = Query(SortOrder.DESC, description="Sort order"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List issues in a project with comprehensive filtering options.
Supports filtering by:
- status: Issue status (open, in_progress, in_review, blocked, closed)
- priority: Issue priority (low, medium, high, critical)
- labels: Match issues containing any of the provided labels
- sprint_id: Issues in a specific sprint
- assigned_agent_id: Issues assigned to a specific agent
- sync_status: External tracker sync status
- search: Full-text search in title and body
Args:
request: FastAPI request object
project_id: Project UUID
pagination: Pagination parameters
status_filter: Optional status filter
priority: Optional priority filter
labels: Optional labels filter
sprint_id: Optional sprint filter
assigned_agent_id: Optional agent assignment filter
sync_status: Optional sync status filter
search: Optional search query
sort_by: Field to sort by
sort_order: Sort direction
current_user: Authenticated user
db: Database session
Returns:
Paginated list of issues matching filters
"""
# Verify project access
await verify_project_ownership(db, project_id, current_user)
try:
# Get filtered issues
issues, total = await issue_crud.get_by_project(
db,
project_id=project_id,
status=status_filter,
priority=priority,
sprint_id=sprint_id,
assigned_agent_id=assigned_agent_id,
labels=labels,
search=search,
skip=pagination.offset,
limit=pagination.limit,
sort_by=sort_by,
sort_order=sort_order.value,
)
# Build response objects
issue_responses = [_build_issue_response(issue) for issue in issues]
pagination_meta = create_pagination_meta(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(issue_responses),
)
return PaginatedResponse(data=issue_responses, pagination=pagination_meta)
except Exception as e:
logger.error(
f"Error listing issues for project {project_id}: {e!s}", exc_info=True
)
raise
# ===== Issue Statistics Endpoint =====
# NOTE: This endpoint MUST be defined before /{issue_id} routes
# to prevent FastAPI from trying to parse "stats" as a UUID
@router.get(
"/projects/{project_id}/issues/stats",
response_model=IssueStats,
summary="Get Issue Statistics",
description="Get aggregated issue statistics for a project",
operation_id="get_issue_stats",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def get_issue_stats(
request: Request,
project_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get aggregated statistics for issues in a project.
Returns counts by status and priority, along with story point totals.
Args:
request: FastAPI request object
project_id: Project UUID
current_user: Authenticated user
db: Database session
Returns:
Issue statistics including counts by status/priority and story points
Raises:
NotFoundError: If project not found
AuthorizationError: If user lacks access
"""
# Verify project access
await verify_project_ownership(db, project_id, current_user)
try:
stats = await issue_crud.get_project_stats(db, project_id=project_id)
return IssueStats(**stats)
except Exception as e:
logger.error(
f"Error getting issue stats for project {project_id}: {e!s}",
exc_info=True,
)
raise
@router.get(
"/projects/{project_id}/issues/{issue_id}",
response_model=IssueResponse,
summary="Get Issue",
description="Get detailed information about a specific issue",
operation_id="get_issue",
)
@limiter.limit(f"{120 * RATE_MULTIPLIER}/minute")
async def get_issue(
request: Request,
project_id: UUID,
issue_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get detailed information about a specific issue.
Returns the issue with expanded relationship data including
project name, sprint name, and assigned agent type name.
Args:
request: FastAPI request object
project_id: Project UUID
issue_id: Issue UUID
current_user: Authenticated user
db: Database session
Returns:
Issue details with relationship data
Raises:
NotFoundError: If project or issue not found
AuthorizationError: If user lacks access
"""
# Verify project access
await verify_project_ownership(db, project_id, current_user)
# Get issue with details
issue_data = await issue_crud.get_with_details(db, issue_id=issue_id)
if not issue_data:
raise NotFoundError(
message=f"Issue {issue_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
issue = issue_data["issue"]
# Verify issue belongs to the project
if issue.project_id != project_id:
raise NotFoundError(
message=f"Issue {issue_id} not found in project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
return _build_issue_response(
issue,
project_name=issue_data.get("project_name"),
project_slug=issue_data.get("project_slug"),
sprint_name=issue_data.get("sprint_name"),
assigned_agent_type_name=issue_data.get("assigned_agent_type_name"),
)
@router.patch(
"/projects/{project_id}/issues/{issue_id}",
response_model=IssueResponse,
summary="Update Issue",
description="Update an existing issue",
operation_id="update_issue",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def update_issue(
request: Request,
project_id: UUID,
issue_id: UUID,
issue_in: IssueUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Update an existing issue.
All fields are optional - only provided fields will be updated.
Validates that assigned agent belongs to the same project.
Args:
request: FastAPI request object
project_id: Project UUID
issue_id: Issue UUID
issue_in: Fields to update
current_user: Authenticated user
db: Database session
Returns:
Updated issue details
Raises:
NotFoundError: If project or issue not found
AuthorizationError: If user lacks access
ValidationException: If validation fails
"""
# Verify project access
await verify_project_ownership(db, project_id, current_user)
# Get existing issue
issue = await issue_crud.get(db, id=issue_id)
if not issue:
raise NotFoundError(
message=f"Issue {issue_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
# Verify issue belongs to the project
if issue.project_id != project_id:
raise NotFoundError(
message=f"Issue {issue_id} not found in project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
# Validate assigned agent if being updated
if issue_in.assigned_agent_id is not None:
agent = await agent_instance_crud.get(db, id=issue_in.assigned_agent_id)
if not agent:
raise NotFoundError(
message=f"Agent instance {issue_in.assigned_agent_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
if agent.project_id != project_id:
raise ValidationException(
message="Agent instance does not belong to this project",
error_code=ErrorCode.VALIDATION_ERROR,
field="assigned_agent_id",
)
if agent.status == AgentStatus.TERMINATED:
raise ValidationException(
message="Cannot assign issue to a terminated agent",
error_code=ErrorCode.VALIDATION_ERROR,
field="assigned_agent_id",
)
# Validate sprint if being updated (IDOR prevention and status validation)
if issue_in.sprint_id is not None:
sprint = await sprint_crud.get(db, id=issue_in.sprint_id)
if not sprint:
raise NotFoundError(
message=f"Sprint {issue_in.sprint_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
if sprint.project_id != project_id:
raise ValidationException(
message="Sprint does not belong to this project",
error_code=ErrorCode.VALIDATION_ERROR,
field="sprint_id",
)
# Cannot add issues to completed or cancelled sprints
if sprint.status in [SprintStatus.COMPLETED, SprintStatus.CANCELLED]:
raise ValidationException(
message=f"Cannot add issues to sprint with status '{sprint.status.value}'",
error_code=ErrorCode.VALIDATION_ERROR,
field="sprint_id",
)
try:
updated_issue = await issue_crud.update(db, db_obj=issue, obj_in=issue_in)
logger.info(
f"User {current_user.email} updated issue {issue_id} in project {project_id}"
)
# Get full details for response
issue_data = await issue_crud.get_with_details(db, issue_id=issue_id)
return _build_issue_response(
updated_issue,
project_name=issue_data.get("project_name") if issue_data else None,
project_slug=issue_data.get("project_slug") if issue_data else None,
sprint_name=issue_data.get("sprint_name") if issue_data else None,
assigned_agent_type_name=issue_data.get("assigned_agent_type_name")
if issue_data
else None,
)
except ValueError as e:
logger.warning(f"Failed to update issue {issue_id}: {e!s}")
raise ValidationException(
message=str(e),
error_code=ErrorCode.VALIDATION_ERROR,
)
except Exception as e:
logger.error(f"Error updating issue {issue_id}: {e!s}", exc_info=True)
raise
@router.delete(
"/projects/{project_id}/issues/{issue_id}",
response_model=MessageResponse,
summary="Delete Issue",
description="Delete an issue permanently",
operation_id="delete_issue",
)
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
async def delete_issue(
request: Request,
project_id: UUID,
issue_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Delete an issue permanently.
The issue will be permanently removed from the database.
Args:
request: FastAPI request object
project_id: Project UUID
issue_id: Issue UUID
current_user: Authenticated user
db: Database session
Returns:
Success message
Raises:
NotFoundError: If project or issue not found
AuthorizationError: If user lacks access
"""
# Verify project access
await verify_project_ownership(db, project_id, current_user)
# Get existing issue
issue = await issue_crud.get(db, id=issue_id)
if not issue:
raise NotFoundError(
message=f"Issue {issue_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
# Verify issue belongs to the project
if issue.project_id != project_id:
raise NotFoundError(
message=f"Issue {issue_id} not found in project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
try:
issue_title = issue.title
await issue_crud.remove(db, id=issue_id)
logger.info(
f"User {current_user.email} deleted issue {issue_id} "
f"('{issue_title}') from project {project_id}"
)
return MessageResponse(
success=True,
message=f"Issue '{issue_title}' has been deleted",
)
except Exception as e:
logger.error(f"Error deleting issue {issue_id}: {e!s}", exc_info=True)
raise
# ===== Issue Assignment Endpoint =====
@router.post(
"/projects/{project_id}/issues/{issue_id}/assign",
response_model=IssueResponse,
summary="Assign Issue",
description="Assign an issue to an agent or human",
operation_id="assign_issue",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def assign_issue(
request: Request,
project_id: UUID,
issue_id: UUID,
assignment: IssueAssign,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Assign an issue to an agent or human.
Only one type of assignment is allowed at a time:
- assigned_agent_id: Assign to an AI agent instance
- human_assignee: Assign to a human (name/email string)
To unassign, pass both as null/None.
Args:
request: FastAPI request object
project_id: Project UUID
issue_id: Issue UUID
assignment: Assignment data
current_user: Authenticated user
db: Database session
Returns:
Updated issue with assignment
Raises:
NotFoundError: If project, issue, or agent not found
AuthorizationError: If user lacks access
ValidationException: If agent not in project
"""
# Verify project access
await verify_project_ownership(db, project_id, current_user)
# Get existing issue
issue = await issue_crud.get(db, id=issue_id)
if not issue:
raise NotFoundError(
message=f"Issue {issue_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
# Verify issue belongs to the project
if issue.project_id != project_id:
raise NotFoundError(
message=f"Issue {issue_id} not found in project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
# Process assignment based on type
if assignment.assigned_agent_id:
# Validate agent exists and belongs to project
agent = await agent_instance_crud.get(db, id=assignment.assigned_agent_id)
if not agent:
raise NotFoundError(
message=f"Agent instance {assignment.assigned_agent_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
if agent.project_id != project_id:
raise ValidationException(
message="Agent instance does not belong to this project",
error_code=ErrorCode.VALIDATION_ERROR,
field="assigned_agent_id",
)
if agent.status == AgentStatus.TERMINATED:
raise ValidationException(
message="Cannot assign issue to a terminated agent",
error_code=ErrorCode.VALIDATION_ERROR,
field="assigned_agent_id",
)
updated_issue = await issue_crud.assign_to_agent(
db, issue_id=issue_id, agent_id=assignment.assigned_agent_id
)
logger.info(
f"User {current_user.email} assigned issue {issue_id} to agent {agent.name}"
)
elif assignment.human_assignee:
updated_issue = await issue_crud.assign_to_human(
db, issue_id=issue_id, human_assignee=assignment.human_assignee
)
logger.info(
f"User {current_user.email} assigned issue {issue_id} "
f"to human '{assignment.human_assignee}'"
)
else:
# Unassign - clear both agent and human
updated_issue = await issue_crud.assign_to_agent(
db, issue_id=issue_id, agent_id=None
)
logger.info(f"User {current_user.email} unassigned issue {issue_id}")
if not updated_issue:
raise NotFoundError(
message=f"Issue {issue_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
# Get full details for response
issue_data = await issue_crud.get_with_details(db, issue_id=issue_id)
return _build_issue_response(
updated_issue,
project_name=issue_data.get("project_name") if issue_data else None,
project_slug=issue_data.get("project_slug") if issue_data else None,
sprint_name=issue_data.get("sprint_name") if issue_data else None,
assigned_agent_type_name=issue_data.get("assigned_agent_type_name")
if issue_data
else None,
)
@router.delete(
"/projects/{project_id}/issues/{issue_id}/assignment",
response_model=IssueResponse,
summary="Unassign Issue",
description="""
Remove agent/human assignment from an issue.
**Authentication**: Required (Bearer token)
**Authorization**: Project owner or superuser
This clears both agent and human assignee fields.
**Rate Limit**: 60 requests/minute
""",
operation_id="unassign_issue",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def unassign_issue(
request: Request,
project_id: UUID,
issue_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Remove assignment from an issue.
Clears both assigned_agent_id and human_assignee fields.
"""
# Verify project access
await verify_project_ownership(db, project_id, current_user)
# Get existing issue
issue = await issue_crud.get(db, id=issue_id)
if not issue:
raise NotFoundError(
message=f"Issue {issue_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
# Verify issue belongs to project (IDOR prevention)
if issue.project_id != project_id:
raise NotFoundError(
message=f"Issue {issue_id} not found in project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
# Unassign the issue
updated_issue = await issue_crud.unassign(db, issue_id=issue_id)
if not updated_issue:
raise NotFoundError(
message=f"Issue {issue_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
logger.info(f"User {current_user.email} unassigned issue {issue_id}")
# Get full details for response
issue_data = await issue_crud.get_with_details(db, issue_id=issue_id)
return _build_issue_response(
updated_issue,
project_name=issue_data.get("project_name") if issue_data else None,
project_slug=issue_data.get("project_slug") if issue_data else None,
sprint_name=issue_data.get("sprint_name") if issue_data else None,
assigned_agent_type_name=issue_data.get("assigned_agent_type_name")
if issue_data
else None,
)
# ===== Issue Sync Endpoint =====
@router.post(
"/projects/{project_id}/issues/{issue_id}/sync",
response_model=MessageResponse,
summary="Trigger Issue Sync",
description="Trigger synchronization with external issue tracker",
operation_id="sync_issue",
)
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
async def sync_issue(
request: Request,
project_id: UUID,
issue_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Trigger synchronization of an issue with its external tracker.
This endpoint queues a sync task for the issue. The actual synchronization
happens asynchronously via Celery.
Prerequisites:
- Issue must have external_tracker_type configured
- Project must have integration settings for the tracker
Args:
request: FastAPI request object
project_id: Project UUID
issue_id: Issue UUID
current_user: Authenticated user
db: Database session
Returns:
Message indicating sync has been triggered
Raises:
NotFoundError: If project or issue not found
AuthorizationError: If user lacks access
ValidationException: If issue has no external tracker
"""
# Verify project access
await verify_project_ownership(db, project_id, current_user)
# Get existing issue
issue = await issue_crud.get(db, id=issue_id)
if not issue:
raise NotFoundError(
message=f"Issue {issue_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
# Verify issue belongs to the project
if issue.project_id != project_id:
raise NotFoundError(
message=f"Issue {issue_id} not found in project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
# Check if issue has external tracker configured
if not issue.external_tracker_type:
raise ValidationException(
message="Issue does not have an external tracker configured",
error_code=ErrorCode.VALIDATION_ERROR,
field="external_tracker_type",
)
# Update sync status to pending
await issue_crud.update_sync_status(
db,
issue_id=issue_id,
sync_status=SyncStatus.PENDING,
)
# TODO: Queue Celery task for actual sync
# When Celery is set up, this will be:
# from app.tasks.sync import sync_issue_task
# sync_issue_task.delay(str(issue_id))
logger.info(
f"User {current_user.email} triggered sync for issue {issue_id} "
f"(tracker: {issue.external_tracker_type})"
)
return MessageResponse(
success=True,
message=f"Sync triggered for issue '{issue.title}'. "
f"Status will update when complete.",
)

View File

@@ -1,446 +0,0 @@
"""
MCP (Model Context Protocol) API Endpoints
Provides REST endpoints for managing MCP server connections
and executing tool calls.
"""
import logging
import re
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Path, status
from pydantic import BaseModel, Field
from app.api.dependencies.permissions import require_superuser
from app.models.user import User
from app.services.mcp import (
MCPCircuitOpenError,
MCPClientManager,
MCPConnectionError,
MCPError,
MCPServerNotFoundError,
MCPTimeoutError,
MCPToolError,
MCPToolNotFoundError,
get_mcp_client,
)
logger = logging.getLogger(__name__)
router = APIRouter()
# Server name validation pattern: alphanumeric, hyphens, underscores, 1-64 chars
SERVER_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
# Type alias for validated server name path parameter
ServerNamePath = Annotated[
str,
Path(
description="MCP server name",
min_length=1,
max_length=64,
pattern=r"^[a-zA-Z0-9_-]+$",
),
]
# ============================================================================
# Request/Response Schemas
# ============================================================================
class ServerInfo(BaseModel):
"""Information about an MCP server."""
name: str = Field(..., description="Server name")
url: str = Field(..., description="Server URL")
enabled: bool = Field(..., description="Whether server is enabled")
timeout: int = Field(..., description="Request timeout in seconds")
transport: str = Field(..., description="Transport type (http, stdio, sse)")
description: str | None = Field(None, description="Server description")
class ServerListResponse(BaseModel):
"""Response containing list of MCP servers."""
servers: list[ServerInfo]
total: int
class ToolInfoResponse(BaseModel):
"""Information about an MCP tool."""
name: str = Field(..., description="Tool name")
description: str | None = Field(None, description="Tool description")
server_name: str | None = Field(None, description="Server providing the tool")
input_schema: dict[str, Any] | None = Field(
None, description="JSON schema for input"
)
class ToolListResponse(BaseModel):
"""Response containing list of tools."""
tools: list[ToolInfoResponse]
total: int
class ServerHealthStatus(BaseModel):
"""Health status for a server."""
name: str
healthy: bool
state: str
url: str
error: str | None = None
tools_count: int = 0
class HealthCheckResponse(BaseModel):
"""Response containing health status of all servers."""
servers: dict[str, ServerHealthStatus]
healthy_count: int
unhealthy_count: int
total: int
class ToolCallRequest(BaseModel):
"""Request to execute a tool."""
server: str = Field(..., description="MCP server name")
tool: str = Field(..., description="Tool name to execute")
arguments: dict[str, Any] = Field(
default_factory=dict,
description="Tool arguments",
)
timeout: float | None = Field(
None,
description="Optional timeout override in seconds",
)
class ToolCallResponse(BaseModel):
"""Response from tool execution."""
success: bool
data: Any | None = None
error: str | None = None
error_code: str | None = None
tool_name: str | None = None
server_name: str | None = None
execution_time_ms: float = 0.0
request_id: str | None = None
class CircuitBreakerStatus(BaseModel):
"""Status of a circuit breaker."""
server_name: str
state: str
failure_count: int
class CircuitBreakerListResponse(BaseModel):
"""Response containing circuit breaker statuses."""
circuit_breakers: list[CircuitBreakerStatus]
# ============================================================================
# Endpoints
# ============================================================================
@router.get(
"/servers",
response_model=ServerListResponse,
summary="List MCP Servers",
description="Get list of all registered MCP servers with their configurations.",
)
async def list_servers(
mcp: MCPClientManager = Depends(get_mcp_client),
) -> ServerListResponse:
"""List all registered MCP servers."""
servers = []
for name in mcp.list_servers():
try:
config = mcp.get_server_config(name)
servers.append(
ServerInfo(
name=name,
url=config.url,
enabled=config.enabled,
timeout=config.timeout,
transport=config.transport.value,
description=config.description,
)
)
except MCPServerNotFoundError:
continue
return ServerListResponse(
servers=servers,
total=len(servers),
)
@router.get(
"/servers/{server_name}/tools",
response_model=ToolListResponse,
summary="List Server Tools",
description="Get list of tools available on a specific MCP server.",
)
async def list_server_tools(
server_name: ServerNamePath,
mcp: MCPClientManager = Depends(get_mcp_client),
) -> ToolListResponse:
"""List all tools available on a specific server."""
try:
tools = await mcp.list_tools(server_name)
return ToolListResponse(
tools=[
ToolInfoResponse(
name=t.name,
description=t.description,
server_name=t.server_name,
input_schema=t.input_schema,
)
for t in tools
],
total=len(tools),
)
except MCPServerNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Server not found: {server_name}",
) from e
@router.get(
"/tools",
response_model=ToolListResponse,
summary="List All Tools",
description="Get list of all tools from all MCP servers.",
)
async def list_all_tools(
mcp: MCPClientManager = Depends(get_mcp_client),
) -> ToolListResponse:
"""List all tools from all servers."""
tools = await mcp.list_all_tools()
return ToolListResponse(
tools=[
ToolInfoResponse(
name=t.name,
description=t.description,
server_name=t.server_name,
input_schema=t.input_schema,
)
for t in tools
],
total=len(tools),
)
@router.get(
"/health",
response_model=HealthCheckResponse,
summary="Health Check",
description="Check health status of all MCP servers.",
)
async def health_check(
mcp: MCPClientManager = Depends(get_mcp_client),
) -> HealthCheckResponse:
"""Perform health check on all MCP servers."""
health_results = await mcp.health_check()
servers = {
name: ServerHealthStatus(
name=status.name,
healthy=status.healthy,
state=status.state,
url=status.url,
error=status.error,
tools_count=status.tools_count,
)
for name, status in health_results.items()
}
healthy_count = sum(1 for s in servers.values() if s.healthy)
unhealthy_count = len(servers) - healthy_count
return HealthCheckResponse(
servers=servers,
healthy_count=healthy_count,
unhealthy_count=unhealthy_count,
total=len(servers),
)
@router.post(
"/call",
response_model=ToolCallResponse,
summary="Execute Tool (Admin Only)",
description="Execute a tool on an MCP server. Requires superuser privileges.",
)
async def call_tool(
request: ToolCallRequest,
current_user: User = Depends(require_superuser),
mcp: MCPClientManager = Depends(get_mcp_client),
) -> ToolCallResponse:
"""
Execute a tool on an MCP server.
This endpoint is restricted to superusers for direct tool execution.
Normal tool execution should go through agent workflows.
"""
logger.info(
"Tool call by user %s: %s.%s",
current_user.id,
request.server,
request.tool,
)
try:
result = await mcp.call_tool(
server=request.server,
tool=request.tool,
args=request.arguments,
timeout=request.timeout,
)
return ToolCallResponse(
success=result.success,
data=result.data,
error=result.error,
error_code=result.error_code,
tool_name=result.tool_name,
server_name=result.server_name,
execution_time_ms=result.execution_time_ms,
request_id=result.request_id,
)
except MCPCircuitOpenError as e:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Server temporarily unavailable: {e.server_name}",
) from e
except MCPToolNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Tool not found: {e.tool_name}",
) from e
except MCPServerNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Server not found: {e.server_name}",
) from e
except MCPTimeoutError as e:
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
detail=str(e),
) from e
except MCPConnectionError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=str(e),
) from e
except MCPToolError as e:
# Tool errors are returned in the response, not as HTTP errors
return ToolCallResponse(
success=False,
error=str(e),
error_code=e.error_code,
tool_name=e.tool_name,
server_name=e.server_name,
)
except MCPError as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
) from e
@router.get(
"/circuit-breakers",
response_model=CircuitBreakerListResponse,
summary="List Circuit Breakers",
description="Get status of all circuit breakers.",
)
async def list_circuit_breakers(
mcp: MCPClientManager = Depends(get_mcp_client),
) -> CircuitBreakerListResponse:
"""Get status of all circuit breakers."""
status_dict = mcp.get_circuit_breaker_status()
return CircuitBreakerListResponse(
circuit_breakers=[
CircuitBreakerStatus(
server_name=name,
state=info.get("state", "unknown"),
failure_count=info.get("failure_count", 0),
)
for name, info in status_dict.items()
]
)
@router.post(
"/circuit-breakers/{server_name}/reset",
status_code=status.HTTP_204_NO_CONTENT,
summary="Reset Circuit Breaker (Admin Only)",
description="Manually reset a circuit breaker for a server.",
)
async def reset_circuit_breaker(
server_name: ServerNamePath,
current_user: User = Depends(require_superuser),
mcp: MCPClientManager = Depends(get_mcp_client),
) -> None:
"""Manually reset a circuit breaker."""
logger.info(
"Circuit breaker reset by user %s for server %s",
current_user.id,
server_name,
)
success = await mcp.reset_circuit_breaker(server_name)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No circuit breaker found for server: {server_name}",
)
@router.post(
"/servers/{server_name}/reconnect",
status_code=status.HTTP_204_NO_CONTENT,
summary="Reconnect to Server (Admin Only)",
description="Force reconnection to an MCP server.",
)
async def reconnect_server(
server_name: ServerNamePath,
current_user: User = Depends(require_superuser),
mcp: MCPClientManager = Depends(get_mcp_client),
) -> None:
"""Force reconnection to an MCP server."""
logger.info(
"Reconnect requested by user %s for server %s",
current_user.id,
server_name,
)
try:
await mcp.disconnect(server_name)
await mcp.connect(server_name)
except MCPServerNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Server not found: {server_name}",
) from e
except MCPConnectionError as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"Failed to reconnect: {e}",
) from e

View File

@@ -25,8 +25,6 @@ from app.core.auth import decode_token
from app.core.config import settings from app.core.config import settings
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import AuthenticationError as AuthError from app.core.exceptions import AuthenticationError as AuthError
from app.crud import oauth_account
from app.crud.session import session as session_crud
from app.models.user import User from app.models.user import User
from app.schemas.oauth import ( from app.schemas.oauth import (
OAuthAccountsListResponse, OAuthAccountsListResponse,
@@ -38,6 +36,7 @@ from app.schemas.oauth import (
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
from app.schemas.users import Token from app.schemas.users import Token
from app.services.oauth_service import OAuthService from app.services.oauth_service import OAuthService
from app.services.session_service import session_service
from app.utils.device import extract_device_info from app.utils.device import extract_device_info
router = APIRouter() router = APIRouter()
@@ -82,17 +81,19 @@ async def _create_oauth_login_session(
location_country=device_info.location_country, location_country=device_info.location_country,
) )
await session_crud.create_session(db, obj_in=session_data) await session_service.create_session(db, obj_in=session_data)
logger.info( logger.info(
f"OAuth login successful: {user.email} via {provider} " "OAuth login successful: %s via %s from %s (IP: %s)",
f"from {device_info.device_name} (IP: {device_info.ip_address})" user.email,
provider,
device_info.device_name,
device_info.ip_address,
) )
except Exception as session_err: except Exception as session_err:
# Log but don't fail login if session creation fails # Log but don't fail login if session creation fails
logger.error( logger.exception(
f"Failed to create session for OAuth login {user.email}: {session_err!s}", "Failed to create session for OAuth login %s: %s", user.email, session_err
exc_info=True,
) )
@@ -177,13 +178,13 @@ async def get_authorization_url(
} }
except AuthError as e: except AuthError as e:
logger.warning(f"OAuth authorization failed: {e!s}") logger.warning("OAuth authorization failed: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e), detail=str(e),
) )
except Exception as e: except Exception as e:
logger.error(f"OAuth authorization error: {e!s}", exc_info=True) logger.exception("OAuth authorization error: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create authorization URL", detail="Failed to create authorization URL",
@@ -251,13 +252,13 @@ async def handle_callback(
return result return result
except AuthError as e: except AuthError as e:
logger.warning(f"OAuth callback failed: {e!s}") logger.warning("OAuth callback failed: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e), detail=str(e),
) )
except Exception as e: except Exception as e:
logger.error(f"OAuth callback error: {e!s}", exc_info=True) logger.exception("OAuth callback error: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth authentication failed", detail="OAuth authentication failed",
@@ -289,7 +290,7 @@ async def list_accounts(
Returns: Returns:
List of linked OAuth accounts List of linked OAuth accounts
""" """
accounts = await oauth_account.get_user_accounts(db, user_id=current_user.id) accounts = await OAuthService.get_user_accounts(db, user_id=current_user.id)
return OAuthAccountsListResponse(accounts=accounts) return OAuthAccountsListResponse(accounts=accounts)
@@ -338,13 +339,13 @@ async def unlink_account(
) )
except AuthError as e: except AuthError as e:
logger.warning(f"OAuth unlink failed for {current_user.email}: {e!s}") logger.warning("OAuth unlink failed for %s: %s", current_user.email, e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e), detail=str(e),
) )
except Exception as e: except Exception as e:
logger.error(f"OAuth unlink error: {e!s}", exc_info=True) logger.exception("OAuth unlink error: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to unlink OAuth account", detail="Failed to unlink OAuth account",
@@ -397,7 +398,7 @@ async def start_link(
) )
# Check if user already has this provider linked # Check if user already has this provider linked
existing = await oauth_account.get_user_account_by_provider( existing = await OAuthService.get_user_account_by_provider(
db, user_id=current_user.id, provider=provider db, user_id=current_user.id, provider=provider
) )
if existing: if existing:
@@ -420,13 +421,13 @@ async def start_link(
} }
except AuthError as e: except AuthError as e:
logger.warning(f"OAuth link authorization failed: {e!s}") logger.warning("OAuth link authorization failed: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e), detail=str(e),
) )
except Exception as e: except Exception as e:
logger.error(f"OAuth link error: {e!s}", exc_info=True) logger.exception("OAuth link error: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create authorization URL", detail="Failed to create authorization URL",

View File

@@ -34,7 +34,6 @@ from app.api.dependencies.auth import (
) )
from app.core.config import settings from app.core.config import settings
from app.core.database import get_db from app.core.database import get_db
from app.crud import oauth_client as oauth_client_crud
from app.models.user import User from app.models.user import User
from app.schemas.oauth import ( from app.schemas.oauth import (
OAuthClientCreate, OAuthClientCreate,
@@ -453,7 +452,7 @@ async def token(
except Exception as e: except Exception as e:
# Log malformed Basic auth for security monitoring # Log malformed Basic auth for security monitoring
logger.warning( logger.warning(
f"Malformed Basic auth header in token request: {type(e).__name__}" "Malformed Basic auth header in token request: %s", type(e).__name__
) )
# Fall back to form body # Fall back to form body
@@ -564,7 +563,8 @@ async def revoke(
except Exception as e: except Exception as e:
# Log malformed Basic auth for security monitoring # Log malformed Basic auth for security monitoring
logger.warning( logger.warning(
f"Malformed Basic auth header in revoke request: {type(e).__name__}" "Malformed Basic auth header in revoke request: %s",
type(e).__name__,
) )
# Fall back to form body # Fall back to form body
@@ -586,7 +586,7 @@ async def revoke(
) )
except Exception as e: except Exception as e:
# Log but don't expose errors per RFC 7009 # Log but don't expose errors per RFC 7009
logger.warning(f"Token revocation error: {e}") logger.warning("Token revocation error: %s", e)
# Always return 200 OK per RFC 7009 # Always return 200 OK per RFC 7009
return {"status": "ok"} return {"status": "ok"}
@@ -635,7 +635,8 @@ async def introspect(
except Exception as e: except Exception as e:
# Log malformed Basic auth for security monitoring # Log malformed Basic auth for security monitoring
logger.warning( logger.warning(
f"Malformed Basic auth header in introspect request: {type(e).__name__}" "Malformed Basic auth header in introspect request: %s",
type(e).__name__,
) )
# Fall back to form body # Fall back to form body
@@ -655,8 +656,8 @@ async def introspect(
headers={"WWW-Authenticate": "Basic"}, headers={"WWW-Authenticate": "Basic"},
) )
except Exception as e: except Exception as e:
logger.warning(f"Token introspection error: {e}") logger.warning("Token introspection error: %s", e)
return OAuthTokenIntrospectionResponse(active=False) return OAuthTokenIntrospectionResponse(active=False) # pyright: ignore[reportCallIssue]
# ============================================================================ # ============================================================================
@@ -712,7 +713,7 @@ async def register_client(
client_type=client_type, client_type=client_type,
) )
client, secret = await oauth_client_crud.create_client(db, obj_in=client_data) client, secret = await provider_service.register_client(db, client_data)
# Update MCP server URL if provided # Update MCP server URL if provided
if mcp_server_url: if mcp_server_url:
@@ -750,7 +751,7 @@ async def list_clients(
current_user: User = Depends(get_current_superuser), current_user: User = Depends(get_current_superuser),
) -> list[OAuthClientResponse]: ) -> list[OAuthClientResponse]:
"""List all OAuth clients.""" """List all OAuth clients."""
clients = await oauth_client_crud.get_all_clients(db) clients = await provider_service.list_clients(db)
return [OAuthClientResponse.model_validate(c) for c in clients] return [OAuthClientResponse.model_validate(c) for c in clients]
@@ -776,7 +777,7 @@ async def delete_client(
detail="Client not found", detail="Client not found",
) )
await oauth_client_crud.delete_client(db, client_id=client_id) await provider_service.delete_client_by_id(db, client_id=client_id)
# ============================================================================ # ============================================================================
@@ -797,30 +798,7 @@ async def list_my_consents(
current_user: User = Depends(get_current_active_user), current_user: User = Depends(get_current_active_user),
) -> list[dict]: ) -> list[dict]:
"""List applications the user has authorized.""" """List applications the user has authorized."""
from sqlalchemy import select return await provider_service.list_user_consents(db, user_id=current_user.id)
from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent
result = await db.execute(
select(OAuthConsent, OAuthClient)
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
.where(OAuthConsent.user_id == current_user.id)
)
rows = result.all()
return [
{
"client_id": consent.client_id,
"client_name": client.client_name,
"client_description": client.client_description,
"granted_scopes": consent.granted_scopes.split()
if consent.granted_scopes
else [],
"granted_at": consent.created_at.isoformat(),
}
for consent, client in rows
]
@router.delete( @router.delete(

View File

@@ -15,8 +15,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.api.dependencies.permissions import require_org_admin, require_org_membership from app.api.dependencies.permissions import require_org_admin, require_org_membership
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import ErrorCode, NotFoundError
from app.crud.organization import organization as organization_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import ( from app.schemas.common import (
PaginatedResponse, PaginatedResponse,
@@ -28,6 +26,7 @@ from app.schemas.organizations import (
OrganizationResponse, OrganizationResponse,
OrganizationUpdate, OrganizationUpdate,
) )
from app.services.organization_service import organization_service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -54,7 +53,7 @@ async def get_my_organizations(
""" """
try: try:
# Get all org data in single query with JOIN and subquery # Get all org data in single query with JOIN and subquery
orgs_data = await organization_crud.get_user_organizations_with_details( orgs_data = await organization_service.get_user_organizations_with_details(
db, user_id=current_user.id, is_active=is_active db, user_id=current_user.id, is_active=is_active
) )
@@ -78,7 +77,7 @@ async def get_my_organizations(
return orgs_with_data return orgs_with_data
except Exception as e: except Exception as e:
logger.error(f"Error getting user organizations: {e!s}", exc_info=True) logger.exception("Error getting user organizations: %s", e)
raise raise
@@ -100,13 +99,7 @@ async def get_organization(
User must be a member of the organization. User must be a member of the organization.
""" """
try: try:
org = await organization_crud.get(db, id=organization_id) org = await organization_service.get_organization(db, str(organization_id))
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
raise NotFoundError(
detail=f"Organization {organization_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
org_dict = { org_dict = {
"id": org.id, "id": org.id,
"name": org.name, "name": org.name,
@@ -116,16 +109,14 @@ async def get_organization(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count( "member_count": await organization_service.get_member_count(
db, organization_id=org.id db, organization_id=org.id
), ),
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
except NotFoundError: # pragma: no cover - See above
raise
except Exception as e: except Exception as e:
logger.error(f"Error getting organization: {e!s}", exc_info=True) logger.exception("Error getting organization: %s", e)
raise raise
@@ -149,7 +140,7 @@ async def get_organization_members(
User must be a member of the organization to view members. User must be a member of the organization to view members.
""" """
try: try:
members, total = await organization_crud.get_organization_members( members, total = await organization_service.get_organization_members(
db, db,
organization_id=organization_id, organization_id=organization_id,
skip=pagination.offset, skip=pagination.offset,
@@ -169,7 +160,7 @@ async def get_organization_members(
return PaginatedResponse(data=member_responses, pagination=pagination_meta) return PaginatedResponse(data=member_responses, pagination=pagination_meta)
except Exception as e: except Exception as e:
logger.error(f"Error getting organization members: {e!s}", exc_info=True) logger.exception("Error getting organization members: %s", e)
raise raise
@@ -192,16 +183,12 @@ async def update_organization(
Requires owner or admin role in the organization. Requires owner or admin role in the organization.
""" """
try: try:
org = await organization_crud.get(db, id=organization_id) org = await organization_service.get_organization(db, str(organization_id))
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md) updated_org = await organization_service.update_organization(
raise NotFoundError( db, org=org, obj_in=org_in
detail=f"Organization {organization_id} not found", )
error_code=ErrorCode.NOT_FOUND,
)
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
logger.info( logger.info(
f"User {current_user.email} updated organization {updated_org.name}" "User %s updated organization %s", current_user.email, updated_org.name
) )
org_dict = { org_dict = {
@@ -213,14 +200,12 @@ async def update_organization(
"settings": updated_org.settings, "settings": updated_org.settings,
"created_at": updated_org.created_at, "created_at": updated_org.created_at,
"updated_at": updated_org.updated_at, "updated_at": updated_org.updated_at,
"member_count": await organization_crud.get_member_count( "member_count": await organization_service.get_member_count(
db, organization_id=updated_org.id db, organization_id=updated_org.id
), ),
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
except NotFoundError: # pragma: no cover - See above
raise
except Exception as e: except Exception as e:
logger.error(f"Error updating organization: {e!s}", exc_info=True) logger.exception("Error updating organization: %s", e)
raise raise

View File

@@ -1,659 +0,0 @@
# app/api/routes/projects.py
"""
Project management API endpoints for Syndarix.
These endpoints allow users to manage their AI-powered software consulting projects.
Users can create, read, update, and manage the lifecycle of their projects.
"""
import logging
import os
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, Query, Request, status
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.core.database import get_db
from app.core.exceptions import (
AuthorizationError,
DuplicateError,
ErrorCode,
NotFoundError,
ValidationException,
)
from app.crud.syndarix.project import project as project_crud
from app.models.syndarix.enums import ProjectStatus
from app.models.user import User
from app.schemas.common import (
MessageResponse,
PaginatedResponse,
PaginationParams,
create_pagination_meta,
)
from app.schemas.syndarix.project import (
ProjectCreate,
ProjectResponse,
ProjectUpdate,
)
router = APIRouter()
logger = logging.getLogger(__name__)
# Initialize rate limiter
limiter = Limiter(key_func=get_remote_address)
# Use higher rate limits in test environment
IS_TEST = os.getenv("IS_TEST", "False") == "True"
RATE_MULTIPLIER = 100 if IS_TEST else 1
def _build_project_response(project_data: dict[str, Any]) -> ProjectResponse:
"""
Build a ProjectResponse from project data dictionary.
Args:
project_data: Dictionary containing project and related counts
Returns:
ProjectResponse with all fields populated
"""
project = project_data["project"]
return ProjectResponse(
id=project.id,
name=project.name,
slug=project.slug,
description=project.description,
autonomy_level=project.autonomy_level,
status=project.status,
settings=project.settings,
owner_id=project.owner_id,
created_at=project.created_at,
updated_at=project.updated_at,
agent_count=project_data.get("agent_count", 0),
issue_count=project_data.get("issue_count", 0),
active_sprint_name=project_data.get("active_sprint_name"),
)
def _check_project_ownership(project: Any, current_user: User) -> None:
"""
Check if the current user owns the project or is a superuser.
Args:
project: The project to check ownership of
current_user: The authenticated user
Raises:
AuthorizationError: If user doesn't own the project and isn't a superuser
"""
if not current_user.is_superuser and project.owner_id != current_user.id:
raise AuthorizationError(
message="You do not have permission to access this project",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
)
# =============================================================================
# Project CRUD Endpoints
# =============================================================================
@router.post(
"",
response_model=ProjectResponse,
status_code=status.HTTP_201_CREATED,
summary="Create Project",
description="""
Create a new project for the current user.
The project will be owned by the authenticated user.
A unique slug is required for URL-friendly project identification.
**Rate Limit**: 10 requests/minute
""",
operation_id="create_project",
)
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def create_project(
request: Request,
project_in: ProjectCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Create a new project.
The authenticated user becomes the owner of the project.
"""
try:
# Set the owner to the current user
project_data = ProjectCreate(
name=project_in.name,
slug=project_in.slug,
description=project_in.description,
autonomy_level=project_in.autonomy_level,
status=project_in.status,
settings=project_in.settings,
owner_id=current_user.id,
)
project = await project_crud.create(db, obj_in=project_data)
logger.info(f"User {current_user.email} created project {project.slug}")
return ProjectResponse(
id=project.id,
name=project.name,
slug=project.slug,
description=project.description,
autonomy_level=project.autonomy_level,
status=project.status,
settings=project.settings,
owner_id=project.owner_id,
created_at=project.created_at,
updated_at=project.updated_at,
agent_count=0,
issue_count=0,
active_sprint_name=None,
)
except ValueError as e:
error_msg = str(e)
if "already exists" in error_msg.lower():
logger.warning(f"Duplicate project slug attempted: {project_in.slug}")
raise DuplicateError(
message=error_msg,
error_code=ErrorCode.DUPLICATE_ENTRY,
field="slug",
)
logger.error(f"Error creating project: {error_msg}", exc_info=True)
raise
except Exception as e:
logger.error(f"Unexpected error creating project: {e!s}", exc_info=True)
raise
@router.get(
"",
response_model=PaginatedResponse[ProjectResponse],
summary="List Projects",
description="""
List projects for the current user with filtering and pagination.
Regular users see only their own projects.
Superusers can see all projects by setting `all_projects=true`.
**Rate Limit**: 30 requests/minute
""",
operation_id="list_projects",
)
@limiter.limit(f"{30 * RATE_MULTIPLIER}/minute")
async def list_projects(
request: Request,
pagination: PaginationParams = Depends(),
status_filter: ProjectStatus | None = Query(
None, alias="status", description="Filter by project status"
),
search: str | None = Query(
None, description="Search by name, slug, or description"
),
all_projects: bool = Query(False, description="Show all projects (superuser only)"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List projects with filtering, search, and pagination.
Regular users only see their own projects.
Superusers can view all projects if all_projects is true.
"""
try:
# Determine owner filter based on user role and request
owner_id = (
None if (current_user.is_superuser and all_projects) else current_user.id
)
projects_data, total = await project_crud.get_multi_with_counts(
db,
skip=pagination.offset,
limit=pagination.limit,
status=status_filter,
owner_id=owner_id,
search=search,
)
# Build response objects
project_responses = [_build_project_response(data) for data in projects_data]
pagination_meta = create_pagination_meta(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(project_responses),
)
return PaginatedResponse(data=project_responses, pagination=pagination_meta)
except Exception as e:
logger.error(f"Error listing projects: {e!s}", exc_info=True)
raise
@router.get(
"/{project_id}",
response_model=ProjectResponse,
summary="Get Project",
description="""
Get detailed information about a specific project.
Users can only access their own projects unless they are superusers.
**Rate Limit**: 60 requests/minute
""",
operation_id="get_project",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def get_project(
request: Request,
project_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get detailed information about a project by ID.
Includes agent count, issue count, and active sprint name.
"""
try:
project_data = await project_crud.get_with_counts(db, project_id=project_id)
if not project_data:
raise NotFoundError(
message=f"Project {project_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
project = project_data["project"]
_check_project_ownership(project, current_user)
return _build_project_response(project_data)
except (NotFoundError, AuthorizationError):
raise
except Exception as e:
logger.error(f"Error getting project {project_id}: {e!s}", exc_info=True)
raise
@router.get(
"/slug/{slug}",
response_model=ProjectResponse,
summary="Get Project by Slug",
description="""
Get detailed information about a project by its slug.
Users can only access their own projects unless they are superusers.
**Rate Limit**: 60 requests/minute
""",
operation_id="get_project_by_slug",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def get_project_by_slug(
request: Request,
slug: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get detailed information about a project by slug.
Includes agent count, issue count, and active sprint name.
"""
try:
project = await project_crud.get_by_slug(db, slug=slug)
if not project:
raise NotFoundError(
message=f"Project with slug '{slug}' not found",
error_code=ErrorCode.NOT_FOUND,
)
_check_project_ownership(project, current_user)
# Get project with counts
project_data = await project_crud.get_with_counts(db, project_id=project.id)
if not project_data:
raise NotFoundError(
message=f"Project with slug '{slug}' not found",
error_code=ErrorCode.NOT_FOUND,
)
return _build_project_response(project_data)
except (NotFoundError, AuthorizationError):
raise
except Exception as e:
logger.error(f"Error getting project by slug {slug}: {e!s}", exc_info=True)
raise
@router.patch(
"/{project_id}",
response_model=ProjectResponse,
summary="Update Project",
description="""
Update an existing project.
Only the project owner or a superuser can update a project.
Only provided fields will be updated.
**Rate Limit**: 20 requests/minute
""",
operation_id="update_project",
)
@limiter.limit(f"{20 * RATE_MULTIPLIER}/minute")
async def update_project(
request: Request,
project_id: UUID,
project_in: ProjectUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Update a project's information.
Only the project owner or superusers can perform updates.
"""
try:
project = await project_crud.get(db, id=project_id)
if not project:
raise NotFoundError(
message=f"Project {project_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
_check_project_ownership(project, current_user)
# Update the project
updated_project = await project_crud.update(
db, db_obj=project, obj_in=project_in
)
logger.info(f"User {current_user.email} updated project {updated_project.slug}")
# Get updated project with counts
project_data = await project_crud.get_with_counts(
db, project_id=updated_project.id
)
if not project_data:
# This shouldn't happen, but handle gracefully
raise NotFoundError(
message=f"Project {project_id} not found after update",
error_code=ErrorCode.NOT_FOUND,
)
return _build_project_response(project_data)
except (NotFoundError, AuthorizationError):
raise
except ValueError as e:
error_msg = str(e)
if "already exists" in error_msg.lower():
logger.warning(f"Duplicate project slug attempted: {project_in.slug}")
raise DuplicateError(
message=error_msg,
error_code=ErrorCode.DUPLICATE_ENTRY,
field="slug",
)
logger.error(f"Error updating project: {error_msg}", exc_info=True)
raise
except Exception as e:
logger.error(f"Error updating project {project_id}: {e!s}", exc_info=True)
raise
@router.delete(
"/{project_id}",
response_model=MessageResponse,
summary="Archive Project",
description="""
Archive a project (soft delete).
Only the project owner or a superuser can archive a project.
Archived projects are not deleted but are no longer accessible for active work.
**Rate Limit**: 10 requests/minute
""",
operation_id="archive_project",
)
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def archive_project(
request: Request,
project_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Archive a project by setting its status to ARCHIVED.
This is a soft delete operation. The project data is preserved.
"""
try:
project = await project_crud.get(db, id=project_id)
if not project:
raise NotFoundError(
message=f"Project {project_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
_check_project_ownership(project, current_user)
# Check if project is already archived
if project.status == ProjectStatus.ARCHIVED:
return MessageResponse(
success=True,
message=f"Project '{project.name}' is already archived",
)
archived_project = await project_crud.archive_project(db, project_id=project_id)
if not archived_project:
raise NotFoundError(
message=f"Failed to archive project {project_id}",
error_code=ErrorCode.NOT_FOUND,
)
logger.info(f"User {current_user.email} archived project {project.slug}")
return MessageResponse(
success=True,
message=f"Project '{archived_project.name}' has been archived",
)
except (NotFoundError, AuthorizationError):
raise
except Exception as e:
logger.error(f"Error archiving project {project_id}: {e!s}", exc_info=True)
raise
# =============================================================================
# Project Lifecycle Endpoints
# =============================================================================
@router.post(
"/{project_id}/pause",
response_model=ProjectResponse,
summary="Pause Project",
description="""
Pause an active project.
Only ACTIVE projects can be paused.
Only the project owner or a superuser can pause a project.
**Rate Limit**: 10 requests/minute
""",
operation_id="pause_project",
)
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def pause_project(
request: Request,
project_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Pause an active project.
Sets the project status to PAUSED. Only ACTIVE projects can be paused.
"""
try:
project = await project_crud.get(db, id=project_id)
if not project:
raise NotFoundError(
message=f"Project {project_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
_check_project_ownership(project, current_user)
# Validate current status (business logic validation, not authorization)
if project.status == ProjectStatus.PAUSED:
raise ValidationException(
message="Project is already paused",
error_code=ErrorCode.VALIDATION_ERROR,
field="status",
)
if project.status == ProjectStatus.ARCHIVED:
raise ValidationException(
message="Cannot pause an archived project",
error_code=ErrorCode.VALIDATION_ERROR,
field="status",
)
if project.status == ProjectStatus.COMPLETED:
raise ValidationException(
message="Cannot pause a completed project",
error_code=ErrorCode.VALIDATION_ERROR,
field="status",
)
# Update status to PAUSED
updated_project = await project_crud.update(
db, db_obj=project, obj_in=ProjectUpdate(status=ProjectStatus.PAUSED)
)
logger.info(f"User {current_user.email} paused project {project.slug}")
# Get project with counts
project_data = await project_crud.get_with_counts(
db, project_id=updated_project.id
)
if not project_data:
raise NotFoundError(
message=f"Project {project_id} not found after update",
error_code=ErrorCode.NOT_FOUND,
)
return _build_project_response(project_data)
except (NotFoundError, AuthorizationError, ValidationException):
raise
except Exception as e:
logger.error(f"Error pausing project {project_id}: {e!s}", exc_info=True)
raise
@router.post(
"/{project_id}/resume",
response_model=ProjectResponse,
summary="Resume Project",
description="""
Resume a paused project.
Only PAUSED projects can be resumed.
Only the project owner or a superuser can resume a project.
**Rate Limit**: 10 requests/minute
""",
operation_id="resume_project",
)
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def resume_project(
request: Request,
project_id: UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Resume a paused project.
Sets the project status back to ACTIVE. Only PAUSED projects can be resumed.
"""
try:
project = await project_crud.get(db, id=project_id)
if not project:
raise NotFoundError(
message=f"Project {project_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
_check_project_ownership(project, current_user)
# Validate current status (business logic validation, not authorization)
if project.status == ProjectStatus.ACTIVE:
raise ValidationException(
message="Project is already active",
error_code=ErrorCode.VALIDATION_ERROR,
field="status",
)
if project.status == ProjectStatus.ARCHIVED:
raise ValidationException(
message="Cannot resume an archived project",
error_code=ErrorCode.VALIDATION_ERROR,
field="status",
)
if project.status == ProjectStatus.COMPLETED:
raise ValidationException(
message="Cannot resume a completed project",
error_code=ErrorCode.VALIDATION_ERROR,
field="status",
)
# Update status to ACTIVE
updated_project = await project_crud.update(
db, db_obj=project, obj_in=ProjectUpdate(status=ProjectStatus.ACTIVE)
)
logger.info(f"User {current_user.email} resumed project {project.slug}")
# Get project with counts
project_data = await project_crud.get_with_counts(
db, project_id=updated_project.id
)
if not project_data:
raise NotFoundError(
message=f"Project {project_id} not found after update",
error_code=ErrorCode.NOT_FOUND,
)
return _build_project_response(project_data)
except (NotFoundError, AuthorizationError, ValidationException):
raise
except Exception as e:
logger.error(f"Error resuming project {project_id}: {e!s}", exc_info=True)
raise

View File

@@ -17,10 +17,10 @@ from app.api.dependencies.auth import get_current_user
from app.core.auth import decode_token from app.core.auth import decode_token
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
from app.crud.session import session as session_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import MessageResponse from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionListResponse, SessionResponse from app.schemas.sessions import SessionListResponse, SessionResponse
from app.services.session_service import session_service
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -60,7 +60,7 @@ async def list_my_sessions(
""" """
try: try:
# Get all active sessions for user # Get all active sessions for user
sessions = await session_crud.get_user_sessions( sessions = await session_service.get_user_sessions(
db, user_id=str(current_user.id), active_only=True db, user_id=str(current_user.id), active_only=True
) )
@@ -74,9 +74,7 @@ async def list_my_sessions(
# For now, we'll mark current based on most recent activity # For now, we'll mark current based on most recent activity
except Exception as e: except Exception as e:
# Optional token parsing - silently ignore failures # Optional token parsing - silently ignore failures
logger.debug( logger.debug("Failed to decode access token for session marking: %s", e)
f"Failed to decode access token for session marking: {e!s}"
)
# Convert to response format # Convert to response format
session_responses = [] session_responses = []
@@ -98,7 +96,7 @@ async def list_my_sessions(
session_responses.append(session_response) session_responses.append(session_response)
logger.info( logger.info(
f"User {current_user.id} listed {len(session_responses)} active sessions" "User %s listed %s active sessions", current_user.id, len(session_responses)
) )
return SessionListResponse( return SessionListResponse(
@@ -106,9 +104,7 @@ async def list_my_sessions(
) )
except Exception as e: except Exception as e:
logger.error( logger.exception("Error listing sessions for user %s: %s", current_user.id, e)
f"Error listing sessions for user {current_user.id}: {e!s}", exc_info=True
)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve sessions", detail="Failed to retrieve sessions",
@@ -150,7 +146,7 @@ async def revoke_session(
""" """
try: try:
# Get the session # Get the session
session = await session_crud.get(db, id=str(session_id)) session = await session_service.get_session(db, str(session_id))
if not session: if not session:
raise NotFoundError( raise NotFoundError(
@@ -161,8 +157,10 @@ async def revoke_session(
# Verify session belongs to current user # Verify session belongs to current user
if str(session.user_id) != str(current_user.id): if str(session.user_id) != str(current_user.id):
logger.warning( logger.warning(
f"User {current_user.id} attempted to revoke session {session_id} " "User %s attempted to revoke session %s belonging to user %s",
f"belonging to user {session.user_id}" current_user.id,
session_id,
session.user_id,
) )
raise AuthorizationError( raise AuthorizationError(
message="You can only revoke your own sessions", message="You can only revoke your own sessions",
@@ -170,11 +168,13 @@ async def revoke_session(
) )
# Deactivate the session # Deactivate the session
await session_crud.deactivate(db, session_id=str(session_id)) await session_service.deactivate(db, session_id=str(session_id))
logger.info( logger.info(
f"User {current_user.id} revoked session {session_id} " "User %s revoked session %s (%s)",
f"({session.device_name})" current_user.id,
session_id,
session.device_name,
) )
return MessageResponse( return MessageResponse(
@@ -185,7 +185,7 @@ async def revoke_session(
except (NotFoundError, AuthorizationError): except (NotFoundError, AuthorizationError):
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error revoking session {session_id}: {e!s}", exc_info=True) logger.exception("Error revoking session %s: %s", session_id, e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to revoke session", detail="Failed to revoke session",
@@ -224,12 +224,12 @@ async def cleanup_expired_sessions(
""" """
try: try:
# Use optimized bulk DELETE instead of N individual deletes # Use optimized bulk DELETE instead of N individual deletes
deleted_count = await session_crud.cleanup_expired_for_user( deleted_count = await session_service.cleanup_expired_for_user(
db, user_id=str(current_user.id) db, user_id=str(current_user.id)
) )
logger.info( logger.info(
f"User {current_user.id} cleaned up {deleted_count} expired sessions" "User %s cleaned up %s expired sessions", current_user.id, deleted_count
) )
return MessageResponse( return MessageResponse(
@@ -237,9 +237,8 @@ async def cleanup_expired_sessions(
) )
except Exception as e: except Exception as e:
logger.error( logger.exception(
f"Error cleaning up sessions for user {current_user.id}: {e!s}", "Error cleaning up sessions for user %s: %s", current_user.id, e
exc_info=True,
) )
await db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
""" """
User management endpoints for CRUD operations. User management endpoints for database operations.
""" """
import logging import logging
@@ -13,8 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_superuser, get_current_user from app.api.dependencies.auth import get_current_superuser, get_current_user
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError from app.core.exceptions import AuthorizationError, ErrorCode
from app.crud.user import user as user_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import ( from app.schemas.common import (
MessageResponse, MessageResponse,
@@ -25,6 +24,7 @@ from app.schemas.common import (
) )
from app.schemas.users import PasswordChange, UserResponse, UserUpdate from app.schemas.users import PasswordChange, UserResponse, UserUpdate
from app.services.auth_service import AuthenticationError, AuthService from app.services.auth_service import AuthenticationError, AuthService
from app.services.user_service import user_service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -71,7 +71,7 @@ async def list_users(
filters["is_superuser"] = is_superuser filters["is_superuser"] = is_superuser
# Get paginated users with total count # Get paginated users with total count
users, total = await user_crud.get_multi_with_total( users, total = await user_service.list_users(
db, db,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
@@ -90,7 +90,7 @@ async def list_users(
return PaginatedResponse(data=users, pagination=pagination_meta) return PaginatedResponse(data=users, pagination=pagination_meta)
except Exception as e: except Exception as e:
logger.error(f"Error listing users: {e!s}", exc_info=True) logger.exception("Error listing users: %s", e)
raise raise
@@ -107,7 +107,9 @@ async def list_users(
""", """,
operation_id="get_current_user_profile", operation_id="get_current_user_profile",
) )
def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any: async def get_current_user_profile(
current_user: User = Depends(get_current_user),
) -> Any:
"""Get current user's profile.""" """Get current user's profile."""
return current_user return current_user
@@ -138,18 +140,16 @@ async def update_current_user(
Users cannot elevate their own permissions (protected by UserUpdate schema validator). Users cannot elevate their own permissions (protected by UserUpdate schema validator).
""" """
try: try:
updated_user = await user_crud.update( updated_user = await user_service.update_user(
db, db_obj=current_user, obj_in=user_update db, user=current_user, obj_in=user_update
) )
logger.info(f"User {current_user.id} updated their profile") logger.info("User %s updated their profile", current_user.id)
return updated_user return updated_user
except ValueError as e: except ValueError as e:
logger.error(f"Error updating user {current_user.id}: {e!s}") logger.error("Error updating user %s: %s", current_user.id, e)
raise raise
except Exception as e: except Exception as e:
logger.error( logger.exception("Unexpected error updating user %s: %s", current_user.id, e)
f"Unexpected error updating user {current_user.id}: {e!s}", exc_info=True
)
raise raise
@@ -182,7 +182,9 @@ async def get_user_by_id(
# Check permissions # Check permissions
if str(user_id) != str(current_user.id) and not current_user.is_superuser: if str(user_id) != str(current_user.id) and not current_user.is_superuser:
logger.warning( logger.warning(
f"User {current_user.id} attempted to access user {user_id} without permission" "User %s attempted to access user %s without permission",
current_user.id,
user_id,
) )
raise AuthorizationError( raise AuthorizationError(
message="Not enough permissions to view this user", message="Not enough permissions to view this user",
@@ -190,13 +192,7 @@ async def get_user_by_id(
) )
# Get user # Get user
user = await user_crud.get(db, id=str(user_id)) user = await user_service.get_user(db, str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND,
)
return user return user
@@ -233,7 +229,9 @@ async def update_user(
if not is_own_profile and not current_user.is_superuser: if not is_own_profile and not current_user.is_superuser:
logger.warning( logger.warning(
f"User {current_user.id} attempted to update user {user_id} without permission" "User %s attempted to update user %s without permission",
current_user.id,
user_id,
) )
raise AuthorizationError( raise AuthorizationError(
message="Not enough permissions to update this user", message="Not enough permissions to update this user",
@@ -241,22 +239,17 @@ async def update_user(
) )
# Get user # Get user
user = await user_crud.get(db, id=str(user_id)) user = await user_service.get_user(db, str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND,
)
try: try:
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update) updated_user = await user_service.update_user(db, user=user, obj_in=user_update)
logger.info(f"User {user_id} updated by {current_user.id}") logger.info("User %s updated by %s", user_id, current_user.id)
return updated_user return updated_user
except ValueError as e: except ValueError as e:
logger.error(f"Error updating user {user_id}: {e!s}") logger.error("Error updating user %s: %s", user_id, e)
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error updating user {user_id}: {e!s}", exc_info=True) logger.exception("Unexpected error updating user %s: %s", user_id, e)
raise raise
@@ -296,19 +289,19 @@ async def change_current_user_password(
) )
if success: if success:
logger.info(f"User {current_user.id} changed their password") logger.info("User %s changed their password", current_user.id)
return MessageResponse( return MessageResponse(
success=True, message="Password changed successfully" success=True, message="Password changed successfully"
) )
except AuthenticationError as e: except AuthenticationError as e:
logger.warning( logger.warning(
f"Failed password change attempt for user {current_user.id}: {e!s}" "Failed password change attempt for user %s: %s", current_user.id, e
) )
raise AuthorizationError( raise AuthorizationError(
message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
) )
except Exception as e: except Exception as e:
logger.error(f"Error changing password for user {current_user.id}: {e!s}") logger.error("Error changing password for user %s: %s", current_user.id, e)
raise raise
@@ -346,24 +339,19 @@ async def delete_user(
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS, error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
) )
# Get user # Get user (raises NotFoundError if not found)
user = await user_crud.get(db, id=str(user_id)) await user_service.get_user(db, str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND,
)
try: try:
# Use soft delete instead of hard delete # Use soft delete instead of hard delete
await user_crud.soft_delete(db, id=str(user_id)) await user_service.soft_delete_user(db, str(user_id))
logger.info(f"User {user_id} soft-deleted by {current_user.id}") logger.info("User %s soft-deleted by %s", user_id, current_user.id)
return MessageResponse( return MessageResponse(
success=True, message=f"User {user_id} deleted successfully" success=True, message=f"User {user_id} deleted successfully"
) )
except ValueError as e: except ValueError as e:
logger.error(f"Error deleting user {user_id}: {e!s}") logger.error("Error deleting user %s: %s", user_id, e)
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error deleting user {user_id}: {e!s}", exc_info=True) logger.exception("Unexpected error deleting user %s: %s", user_id, e)
raise raise

View File

@@ -1,116 +0,0 @@
# app/celery_app.py
"""
Celery application configuration for Syndarix.
This module configures the Celery app for background task processing:
- Agent execution tasks (LLM calls, tool execution)
- Git operations (clone, commit, push, PR creation)
- Issue synchronization with external trackers
- Workflow state management
- Cost tracking and budget monitoring
Architecture:
- Redis as message broker and result backend
- Queue routing for task isolation
- JSON serialization for cross-language compatibility
- Beat scheduler for periodic tasks
"""
from celery import Celery
from app.core.config import settings
# Create Celery application instance
celery_app = Celery(
"syndarix",
broker=settings.celery_broker_url,
backend=settings.celery_result_backend,
)
# Define task queues with their own exchanges and routing keys
TASK_QUEUES = {
"agent": {"exchange": "agent", "routing_key": "agent"},
"git": {"exchange": "git", "routing_key": "git"},
"sync": {"exchange": "sync", "routing_key": "sync"},
"default": {"exchange": "default", "routing_key": "default"},
}
# Configure Celery
celery_app.conf.update(
# Serialization
task_serializer="json",
accept_content=["json"],
result_serializer="json",
# Timezone
timezone="UTC",
enable_utc=True,
# Task imports for auto-discovery
imports=("app.tasks",),
# Default queue
task_default_queue="default",
# Task queues configuration
task_queues=TASK_QUEUES,
# Task routing - route tasks to appropriate queues
task_routes={
"app.tasks.agent.*": {"queue": "agent"},
"app.tasks.git.*": {"queue": "git"},
"app.tasks.sync.*": {"queue": "sync"},
"app.tasks.*": {"queue": "default"},
},
# Time limits per ADR-003
task_soft_time_limit=300, # 5 minutes soft limit
task_time_limit=600, # 10 minutes hard limit
# Result expiration - 24 hours
result_expires=86400,
# Broker connection retry
broker_connection_retry_on_startup=True,
# Retry configuration per ADR-003 (built-in retry with backoff)
task_autoretry_for=(Exception,), # Retry on all exceptions
task_retry_kwargs={"max_retries": 3, "countdown": 5}, # Initial 5s delay
task_retry_backoff=True, # Enable exponential backoff
task_retry_backoff_max=600, # Max 10 minutes between retries
task_retry_jitter=True, # Add jitter to prevent thundering herd
# Beat schedule for periodic tasks
beat_schedule={
# Cost aggregation every hour per ADR-012
"aggregate-daily-costs": {
"task": "app.tasks.cost.aggregate_daily_costs",
"schedule": 3600.0, # 1 hour in seconds
},
# Reset daily budget counters at midnight UTC
"reset-daily-budget-counters": {
"task": "app.tasks.cost.reset_daily_budget_counters",
"schedule": 86400.0, # 24 hours in seconds
},
# Check for stale workflows every 5 minutes
"recover-stale-workflows": {
"task": "app.tasks.workflow.recover_stale_workflows",
"schedule": 300.0, # 5 minutes in seconds
},
# Incremental issue sync every minute per ADR-011
"sync-issues-incremental": {
"task": "app.tasks.sync.sync_issues_incremental",
"schedule": 60.0, # 1 minute in seconds
},
# Full issue reconciliation every 15 minutes per ADR-011
"sync-issues-full": {
"task": "app.tasks.sync.sync_issues_full",
"schedule": 900.0, # 15 minutes in seconds
},
},
# Task execution settings
task_acks_late=True, # Acknowledge tasks after execution
task_reject_on_worker_lost=True, # Reject tasks if worker dies
worker_prefetch_multiplier=1, # Fair task distribution
)
# Auto-discover tasks from task modules
celery_app.autodiscover_tasks(
[
"app.tasks.agent",
"app.tasks.git",
"app.tasks.sync",
"app.tasks.workflow",
"app.tasks.cost",
]
)

View File

@@ -1,23 +1,21 @@
import asyncio import asyncio
import logging
import uuid import uuid
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from functools import partial from functools import partial
from typing import Any from typing import Any
from jose import JWTError, jwt import bcrypt
from passlib.context import CryptContext import jwt
from jwt.exceptions import (
ExpiredSignatureError,
InvalidTokenError,
MissingRequiredClaimError,
)
from pydantic import ValidationError from pydantic import ValidationError
from app.core.config import settings from app.core.config import settings
from app.schemas.users import TokenData, TokenPayload from app.schemas.users import TokenData, TokenPayload
# Suppress passlib bcrypt warnings about ident
logging.getLogger("passlib").setLevel(logging.ERROR)
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Custom exceptions for auth # Custom exceptions for auth
class AuthError(Exception): class AuthError(Exception):
@@ -37,13 +35,16 @@ class TokenMissingClaimError(AuthError):
def verify_password(plain_password: str, hashed_password: str) -> bool: def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against a hash.""" """Verify a password against a bcrypt hash."""
return pwd_context.verify(plain_password, hashed_password) return bcrypt.checkpw(
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
)
def get_password_hash(password: str) -> str: def get_password_hash(password: str) -> str:
"""Generate a password hash.""" """Generate a bcrypt password hash."""
return pwd_context.hash(password) salt = bcrypt.gensalt()
return bcrypt.hashpw(password.encode("utf-8"), salt).decode("utf-8")
async def verify_password_async(plain_password: str, hashed_password: str) -> bool: async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
@@ -60,9 +61,9 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
Returns: Returns:
True if password matches, False otherwise True if password matches, False otherwise
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
return await loop.run_in_executor( return await loop.run_in_executor(
None, partial(pwd_context.verify, plain_password, hashed_password) None, partial(verify_password, plain_password, hashed_password)
) )
@@ -80,8 +81,8 @@ async def get_password_hash_async(password: str) -> str:
Returns: Returns:
Hashed password string Hashed password string
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, pwd_context.hash, password) return await loop.run_in_executor(None, get_password_hash, password)
def create_access_token( def create_access_token(
@@ -121,11 +122,7 @@ def create_access_token(
to_encode.update(claims) to_encode.update(claims)
# Create the JWT # Create the JWT
encoded_jwt = jwt.encode( return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token( def create_refresh_token(
@@ -154,11 +151,7 @@ def create_refresh_token(
"type": "refresh", "type": "refresh",
} }
encoded_jwt = jwt.encode( return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload: def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
@@ -198,7 +191,7 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
# Reject weak or unexpected algorithms # Reject weak or unexpected algorithms
# NOTE: These are defensive checks that provide defense-in-depth. # NOTE: These are defensive checks that provide defense-in-depth.
# The python-jose library rejects these tokens BEFORE we reach here, # PyJWT rejects these tokens BEFORE we reach here,
# but we keep these checks in case the library changes or is misconfigured. # but we keep these checks in case the library changes or is misconfigured.
# Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py) # Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
if token_algorithm == "NONE": # pragma: no cover if token_algorithm == "NONE": # pragma: no cover
@@ -219,10 +212,11 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
token_data = TokenPayload(**payload) token_data = TokenPayload(**payload)
return token_data return token_data
except JWTError as e: except ExpiredSignatureError:
# Check if the error is due to an expired token raise TokenExpiredError("Token has expired")
if "expired" in str(e).lower(): except MissingRequiredClaimError as e:
raise TokenExpiredError("Token has expired") raise TokenMissingClaimError(f"Token missing required claim: {e}")
except InvalidTokenError:
raise TokenInvalidError("Invalid authentication token") raise TokenInvalidError("Invalid authentication token")
except ValidationError: except ValidationError:
raise TokenInvalidError("Invalid token payload") raise TokenInvalidError("Invalid token payload")

View File

@@ -5,7 +5,7 @@ from pydantic_settings import BaseSettings
class Settings(BaseSettings): class Settings(BaseSettings):
PROJECT_NAME: str = "Syndarix" PROJECT_NAME: str = "PragmaStack"
VERSION: str = "1.0.0" VERSION: str = "1.0.0"
API_V1_STR: str = "/api/v1" API_V1_STR: str = "/api/v1"
@@ -39,32 +39,6 @@ class Settings(BaseSettings):
db_pool_timeout: int = 30 # Seconds to wait for a connection db_pool_timeout: int = 30 # Seconds to wait for a connection
db_pool_recycle: int = 3600 # Recycle connections after 1 hour db_pool_recycle: int = 3600 # Recycle connections after 1 hour
# Redis configuration (Syndarix: cache, pub/sub, Celery broker)
REDIS_URL: str = Field(
default="redis://localhost:6379/0",
description="Redis URL for cache, pub/sub, and Celery broker",
)
# Celery configuration (Syndarix: background task processing)
CELERY_BROKER_URL: str | None = Field(
default=None,
description="Celery broker URL (defaults to REDIS_URL if not set)",
)
CELERY_RESULT_BACKEND: str | None = Field(
default=None,
description="Celery result backend URL (defaults to REDIS_URL if not set)",
)
@property
def celery_broker_url(self) -> str:
"""Get Celery broker URL, defaulting to Redis."""
return self.CELERY_BROKER_URL or self.REDIS_URL
@property
def celery_result_backend(self) -> str:
"""Get Celery result backend URL, defaulting to Redis."""
return self.CELERY_RESULT_BACKEND or self.REDIS_URL
# SQL debugging (disable in production) # SQL debugging (disable in production)
sql_echo: bool = False # Log SQL statements sql_echo: bool = False # Log SQL statements
sql_echo_pool: bool = False # Log connection pool events sql_echo_pool: bool = False # Log connection pool events

View File

@@ -128,8 +128,8 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
Usage: Usage:
async with async_transaction_scope() as db: async with async_transaction_scope() as db:
user = await user_crud.create(db, obj_in=user_create) user = await user_repo.create(db, obj_in=user_create)
profile = await profile_crud.create(db, obj_in=profile_create) profile = await profile_repo.create(db, obj_in=profile_create)
# Both operations committed together # Both operations committed together
""" """
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
logger.debug("Async transaction committed successfully") logger.debug("Async transaction committed successfully")
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
logger.error(f"Async transaction failed, rolling back: {e!s}") logger.error("Async transaction failed, rolling back: %s", e)
raise raise
finally: finally:
await session.close() await session.close()
@@ -155,7 +155,7 @@ async def check_async_database_health() -> bool:
await db.execute(text("SELECT 1")) await db.execute(text("SELECT 1"))
return True return True
except Exception as e: except Exception as e:
logger.error(f"Async database health check failed: {e!s}") logger.error("Async database health check failed: %s", e)
return False return False

View File

@@ -143,8 +143,11 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
Returns a standardized error response with error code and message. Returns a standardized error response with error code and message.
""" """
logger.warning( logger.warning(
f"API exception: {exc.error_code} - {exc.message} " "API exception: %s - %s (status: %s, path: %s)",
f"(status: {exc.status_code}, path: {request.url.path})" exc.error_code,
exc.message,
exc.status_code,
request.url.path,
) )
error_response = ErrorResponse( error_response = ErrorResponse(
@@ -186,7 +189,9 @@ async def validation_exception_handler(
) )
) )
logger.warning(f"Validation error: {len(errors)} errors (path: {request.url.path})") logger.warning(
"Validation error: %s errors (path: %s)", len(errors), request.url.path
)
error_response = ErrorResponse(errors=errors) error_response = ErrorResponse(errors=errors)
@@ -218,11 +223,14 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
) )
logger.warning( logger.warning(
f"HTTP exception: {exc.status_code} - {exc.detail} (path: {request.url.path})" "HTTP exception: %s - %s (path: %s)",
exc.status_code,
exc.detail,
request.url.path,
) )
error_response = ErrorResponse( error_response = ErrorResponse(
errors=[ErrorDetail(code=error_code, message=str(exc.detail))] errors=[ErrorDetail(code=error_code, message=str(exc.detail), field=None)]
) )
return JSONResponse( return JSONResponse(
@@ -239,10 +247,11 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
Logs the full exception and returns a generic error response to avoid Logs the full exception and returns a generic error response to avoid
leaking sensitive information in production. leaking sensitive information in production.
""" """
logger.error( logger.exception(
f"Unhandled exception: {type(exc).__name__} - {exc!s} " "Unhandled exception: %s - %s (path: %s)",
f"(path: {request.url.path})", type(exc).__name__,
exc_info=True, exc,
request.url.path,
) )
# In production, don't expose internal error details # In production, don't expose internal error details
@@ -254,7 +263,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
message = f"{type(exc).__name__}: {exc!s}" message = f"{type(exc).__name__}: {exc!s}"
error_response = ErrorResponse( error_response = ErrorResponse(
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)] errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message, field=None)]
) )
return JSONResponse( return JSONResponse(

View File

@@ -1,474 +0,0 @@
# app/core/redis.py
"""
Redis client configuration for caching and pub/sub.
This module provides async Redis connectivity with connection pooling
for FastAPI endpoints and background tasks.
Features:
- Connection pooling for efficient resource usage
- Cache operations (get, set, delete, expire)
- Pub/sub operations (publish, subscribe)
- Health check for monitoring
"""
import asyncio
import json
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.client import PubSub
from redis.exceptions import ConnectionError, RedisError, TimeoutError
from app.core.config import settings
# Configure logging
logger = logging.getLogger(__name__)
# Default TTL for cache entries (1 hour)
DEFAULT_CACHE_TTL = 3600
# Connection pool settings
POOL_MAX_CONNECTIONS = 50
POOL_TIMEOUT = 10 # seconds
class RedisClient:
"""
Async Redis client with connection pooling.
Provides high-level operations for caching and pub/sub
with proper error handling and connection management.
"""
def __init__(self, url: str | None = None) -> None:
"""
Initialize Redis client.
Args:
url: Redis connection URL. Defaults to settings.REDIS_URL.
"""
self._url = url or settings.REDIS_URL
self._pool: ConnectionPool | None = None
self._client: Redis | None = None
self._lock = asyncio.Lock()
async def _ensure_pool(self) -> ConnectionPool:
"""Ensure connection pool is initialized (thread-safe)."""
if self._pool is None:
async with self._lock:
# Double-check after acquiring lock
if self._pool is None:
self._pool = ConnectionPool.from_url(
self._url,
max_connections=POOL_MAX_CONNECTIONS,
socket_timeout=POOL_TIMEOUT,
socket_connect_timeout=POOL_TIMEOUT,
decode_responses=True,
health_check_interval=30,
)
logger.info("Redis connection pool initialized")
return self._pool
async def _get_client(self) -> Redis:
"""Get Redis client instance from pool."""
pool = await self._ensure_pool()
if self._client is None:
self._client = Redis(connection_pool=pool)
return self._client
# =========================================================================
# Cache Operations
# =========================================================================
async def cache_get(self, key: str) -> str | None:
"""
Get a value from cache.
Args:
key: Cache key.
Returns:
Cached value or None if not found.
"""
try:
client = await self._get_client()
value = await client.get(key)
if value is not None:
logger.debug(f"Cache hit for key: {key}")
else:
logger.debug(f"Cache miss for key: {key}")
return value
except (ConnectionError, TimeoutError) as e:
logger.error(f"Redis cache_get failed for key '{key}': {e}")
return None
except RedisError as e:
logger.error(f"Redis error in cache_get for key '{key}': {e}")
return None
async def cache_get_json(self, key: str) -> Any | None:
"""
Get a JSON-serialized value from cache.
Args:
key: Cache key.
Returns:
Deserialized value or None if not found.
"""
value = await self.cache_get(key)
if value is not None:
try:
return json.loads(value)
except json.JSONDecodeError as e:
logger.error(f"Failed to decode JSON for key '{key}': {e}")
return None
return None
async def cache_set(
self,
key: str,
value: str,
ttl: int | None = None,
) -> bool:
"""
Set a value in cache.
Args:
key: Cache key.
value: Value to cache.
ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL.
Returns:
True if successful, False otherwise.
"""
try:
client = await self._get_client()
ttl = ttl if ttl is not None else DEFAULT_CACHE_TTL
await client.set(key, value, ex=ttl)
logger.debug(f"Cache set for key: {key} (TTL: {ttl}s)")
return True
except (ConnectionError, TimeoutError) as e:
logger.error(f"Redis cache_set failed for key '{key}': {e}")
return False
except RedisError as e:
logger.error(f"Redis error in cache_set for key '{key}': {e}")
return False
async def cache_set_json(
self,
key: str,
value: Any,
ttl: int | None = None,
) -> bool:
"""
Set a JSON-serialized value in cache.
Args:
key: Cache key.
value: Value to serialize and cache.
ttl: Time-to-live in seconds.
Returns:
True if successful, False otherwise.
"""
try:
serialized = json.dumps(value)
return await self.cache_set(key, serialized, ttl)
except (TypeError, ValueError) as e:
logger.error(f"Failed to serialize value for key '{key}': {e}")
return False
async def cache_delete(self, key: str) -> bool:
"""
Delete a key from cache.
Args:
key: Cache key to delete.
Returns:
True if key was deleted, False otherwise.
"""
try:
client = await self._get_client()
result = await client.delete(key)
logger.debug(f"Cache delete for key: {key} (deleted: {result > 0})")
return result > 0
except (ConnectionError, TimeoutError) as e:
logger.error(f"Redis cache_delete failed for key '{key}': {e}")
return False
except RedisError as e:
logger.error(f"Redis error in cache_delete for key '{key}': {e}")
return False
async def cache_delete_pattern(self, pattern: str) -> int:
"""
Delete all keys matching a pattern.
Args:
pattern: Glob-style pattern (e.g., "user:*").
Returns:
Number of keys deleted.
"""
try:
client = await self._get_client()
deleted = 0
async for key in client.scan_iter(pattern):
await client.delete(key)
deleted += 1
logger.debug(f"Cache delete pattern '{pattern}': {deleted} keys deleted")
return deleted
except (ConnectionError, TimeoutError) as e:
logger.error(f"Redis cache_delete_pattern failed for '{pattern}': {e}")
return 0
except RedisError as e:
logger.error(f"Redis error in cache_delete_pattern for '{pattern}': {e}")
return 0
async def cache_expire(self, key: str, ttl: int) -> bool:
"""
Set or update TTL for a key.
Args:
key: Cache key.
ttl: New TTL in seconds.
Returns:
True if TTL was set, False if key doesn't exist.
"""
try:
client = await self._get_client()
result = await client.expire(key, ttl)
logger.debug(
f"Cache expire for key: {key} (TTL: {ttl}s, success: {result})"
)
return result
except (ConnectionError, TimeoutError) as e:
logger.error(f"Redis cache_expire failed for key '{key}': {e}")
return False
except RedisError as e:
logger.error(f"Redis error in cache_expire for key '{key}': {e}")
return False
async def cache_exists(self, key: str) -> bool:
"""
Check if a key exists in cache.
Args:
key: Cache key.
Returns:
True if key exists, False otherwise.
"""
try:
client = await self._get_client()
result = await client.exists(key)
return result > 0
except (ConnectionError, TimeoutError) as e:
logger.error(f"Redis cache_exists failed for key '{key}': {e}")
return False
except RedisError as e:
logger.error(f"Redis error in cache_exists for key '{key}': {e}")
return False
async def cache_ttl(self, key: str) -> int:
"""
Get remaining TTL for a key.
Args:
key: Cache key.
Returns:
TTL in seconds, -1 if no TTL, -2 if key doesn't exist.
"""
try:
client = await self._get_client()
return await client.ttl(key)
except (ConnectionError, TimeoutError) as e:
logger.error(f"Redis cache_ttl failed for key '{key}': {e}")
return -2
except RedisError as e:
logger.error(f"Redis error in cache_ttl for key '{key}': {e}")
return -2
# =========================================================================
# Pub/Sub Operations
# =========================================================================
async def publish(self, channel: str, message: str | dict) -> int:
"""
Publish a message to a channel.
Args:
channel: Channel name.
message: Message to publish (string or dict for JSON serialization).
Returns:
Number of subscribers that received the message.
"""
try:
client = await self._get_client()
if isinstance(message, dict):
message = json.dumps(message)
result = await client.publish(channel, message)
logger.debug(f"Published to channel '{channel}': {result} subscribers")
return result
except (ConnectionError, TimeoutError) as e:
logger.error(f"Redis publish failed for channel '{channel}': {e}")
return 0
except RedisError as e:
logger.error(f"Redis error in publish for channel '{channel}': {e}")
return 0
@asynccontextmanager
async def subscribe(self, *channels: str) -> AsyncGenerator[PubSub, None]:
"""
Subscribe to one or more channels.
Usage:
async with redis_client.subscribe("channel1", "channel2") as pubsub:
async for message in pubsub.listen():
if message["type"] == "message":
print(message["data"])
Args:
channels: Channel names to subscribe to.
Yields:
PubSub instance for receiving messages.
"""
client = await self._get_client()
pubsub = client.pubsub()
try:
await pubsub.subscribe(*channels)
logger.debug(f"Subscribed to channels: {channels}")
yield pubsub
finally:
await pubsub.unsubscribe(*channels)
await pubsub.close()
logger.debug(f"Unsubscribed from channels: {channels}")
@asynccontextmanager
async def psubscribe(self, *patterns: str) -> AsyncGenerator[PubSub, None]:
"""
Subscribe to channels matching patterns.
Usage:
async with redis_client.psubscribe("user:*") as pubsub:
async for message in pubsub.listen():
if message["type"] == "pmessage":
print(message["pattern"], message["channel"], message["data"])
Args:
patterns: Glob-style patterns to subscribe to.
Yields:
PubSub instance for receiving messages.
"""
client = await self._get_client()
pubsub = client.pubsub()
try:
await pubsub.psubscribe(*patterns)
logger.debug(f"Pattern subscribed: {patterns}")
yield pubsub
finally:
await pubsub.punsubscribe(*patterns)
await pubsub.close()
logger.debug(f"Pattern unsubscribed: {patterns}")
# =========================================================================
# Health & Connection Management
# =========================================================================
async def health_check(self) -> bool:
"""
Check if Redis connection is healthy.
Returns:
True if connection is successful, False otherwise.
"""
try:
client = await self._get_client()
result = await client.ping()
return result is True
except (ConnectionError, TimeoutError) as e:
logger.error(f"Redis health check failed: {e}")
return False
except RedisError as e:
logger.error(f"Redis health check error: {e}")
return False
async def close(self) -> None:
"""
Close Redis connections and cleanup resources.
Should be called during application shutdown.
"""
if self._client:
await self._client.close()
self._client = None
logger.debug("Redis client closed")
if self._pool:
await self._pool.disconnect()
self._pool = None
logger.info("Redis connection pool closed")
async def get_pool_info(self) -> dict[str, Any]:
"""
Get connection pool statistics.
Returns:
Dictionary with pool information.
"""
if self._pool is None:
return {"status": "not_initialized"}
return {
"status": "active",
"max_connections": POOL_MAX_CONNECTIONS,
"url": self._url.split("@")[-1] if "@" in self._url else self._url,
}
# Global Redis client instance
redis_client = RedisClient()
# FastAPI dependency for Redis client
async def get_redis() -> AsyncGenerator[RedisClient, None]:
"""
FastAPI dependency that provides the Redis client.
Usage:
@router.get("/cached-data")
async def get_data(redis: RedisClient = Depends(get_redis)):
cached = await redis.cache_get("my-key")
...
"""
yield redis_client
# Health check function for use in /health endpoint
async def check_redis_health() -> bool:
"""
Check if Redis connection is healthy.
Returns:
True if connection is successful, False otherwise.
"""
return await redis_client.health_check()
# Cleanup function for application shutdown
async def close_redis() -> None:
"""
Close Redis connections.
Should be called during application shutdown.
"""
await redis_client.close()

View File

@@ -0,0 +1,26 @@
"""
Custom exceptions for the repository layer.
These exceptions allow services and routes to handle database-level errors
with proper semantics, without leaking SQLAlchemy internals.
"""
class RepositoryError(Exception):
"""Base for all repository-layer errors."""
class DuplicateEntryError(RepositoryError):
"""Raised on unique constraint violations. Maps to HTTP 409 Conflict."""
class IntegrityConstraintError(RepositoryError):
"""Raised on FK or check constraint violations."""
class RecordNotFoundError(RepositoryError):
"""Raised when an expected record doesn't exist."""
class InvalidInputError(RepositoryError):
"""Raised on bad pagination params, invalid UUIDs, or other invalid inputs."""

View File

@@ -1,14 +0,0 @@
# app/crud/__init__.py
from .oauth import oauth_account, oauth_client, oauth_state
from .organization import organization
from .session import session as session_crud
from .user import user
__all__ = [
"oauth_account",
"oauth_client",
"oauth_state",
"organization",
"session_crud",
"user",
]

View File

@@ -1,718 +0,0 @@
"""
Async CRUD operations for OAuth models using SQLAlchemy 2.0 patterns.
Provides operations for:
- OAuthAccount: Managing linked OAuth provider accounts
- OAuthState: CSRF protection state during OAuth flows
- OAuthClient: Registered OAuth clients (provider mode skeleton)
"""
import logging
import secrets
from datetime import UTC, datetime
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase
from app.models.oauth_account import OAuthAccount
from app.models.oauth_client import OAuthClient
from app.models.oauth_state import OAuthState
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
logger = logging.getLogger(__name__)
# ============================================================================
# OAuth Account CRUD
# ============================================================================
class EmptySchema(BaseModel):
"""Placeholder schema for CRUD operations that don't need update schemas."""
class CRUDOAuthAccount(CRUDBase[OAuthAccount, OAuthAccountCreate, EmptySchema]):
"""CRUD operations for OAuth account links."""
async def get_by_provider_id(
self,
db: AsyncSession,
*,
provider: str,
provider_user_id: str,
) -> OAuthAccount | None:
"""
Get OAuth account by provider and provider user ID.
Args:
db: Database session
provider: OAuth provider name (google, github)
provider_user_id: User ID from the OAuth provider
Returns:
OAuthAccount if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == provider_user_id,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover # pragma: no cover
logger.error(
f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
)
raise
async def get_by_provider_email(
self,
db: AsyncSession,
*,
provider: str,
email: str,
) -> OAuthAccount | None:
"""
Get OAuth account by provider and email.
Used for auto-linking existing accounts by email.
Args:
db: Database session
provider: OAuth provider name
email: Email address from the OAuth provider
Returns:
OAuthAccount if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_email == email,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover # pragma: no cover
logger.error(
f"Error getting OAuth account for {provider} email {email}: {e!s}"
)
raise
async def get_user_accounts(
self,
db: AsyncSession,
*,
user_id: str | UUID,
) -> list[OAuthAccount]:
"""
Get all OAuth accounts linked to a user.
Args:
db: Database session
user_id: User ID
Returns:
List of OAuthAccount objects
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount)
.where(OAuthAccount.user_id == user_uuid)
.order_by(OAuthAccount.created_at.desc())
)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
raise
async def get_user_account_by_provider(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> OAuthAccount | None:
"""
Get a specific OAuth account for a user and provider.
Args:
db: Database session
user_id: User ID
provider: OAuth provider name
Returns:
OAuthAccount if found, None otherwise
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
)
raise
async def create_account(
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
) -> OAuthAccount:
"""
Create a new OAuth account link.
Args:
db: Database session
obj_in: OAuth account creation data
Returns:
Created OAuthAccount
Raises:
ValueError: If account already exists or creation fails
"""
try:
db_obj = OAuthAccount(
user_id=obj_in.user_id,
provider=obj_in.provider,
provider_user_id=obj_in.provider_user_id,
provider_email=obj_in.provider_email,
access_token_encrypted=obj_in.access_token_encrypted,
refresh_token_encrypted=obj_in.refresh_token_encrypted,
token_expires_at=obj_in.token_expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
)
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "uq_oauth_provider_user" in error_msg.lower():
logger.warning(
f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
)
raise ValueError(
f"This {obj_in.provider} account is already linked to another user"
)
logger.error(f"Integrity error creating OAuth account: {error_msg}")
raise ValueError(f"Failed to create OAuth account: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
raise
async def delete_account(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> bool:
"""
Delete an OAuth account link.
Args:
db: Database session
user_id: User ID
provider: OAuth provider name
Returns:
True if deleted, False if not found
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
delete(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(
f"OAuth account deleted: {provider} unlinked from user {user_id}"
)
else:
logger.warning(
f"OAuth account not found for deletion: {provider} for user {user_id}"
)
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(
f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
)
raise
async def update_tokens(
self,
db: AsyncSession,
*,
account: OAuthAccount,
access_token_encrypted: str | None = None,
refresh_token_encrypted: str | None = None,
token_expires_at: datetime | None = None,
) -> OAuthAccount:
"""
Update OAuth tokens for an account.
Args:
db: Database session
account: OAuthAccount to update
access_token_encrypted: New encrypted access token
refresh_token_encrypted: New encrypted refresh token
token_expires_at: New token expiration time
Returns:
Updated OAuthAccount
"""
try:
if access_token_encrypted is not None:
account.access_token_encrypted = access_token_encrypted
if refresh_token_encrypted is not None:
account.refresh_token_encrypted = refresh_token_encrypted
if token_expires_at is not None:
account.token_expires_at = token_expires_at
db.add(account)
await db.commit()
await db.refresh(account)
return account
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error updating OAuth tokens: {e!s}")
raise
# ============================================================================
# OAuth State CRUD
# ============================================================================
class CRUDOAuthState(CRUDBase[OAuthState, OAuthStateCreate, EmptySchema]):
"""CRUD operations for OAuth state (CSRF protection)."""
async def create_state(
self, db: AsyncSession, *, obj_in: OAuthStateCreate
) -> OAuthState:
"""
Create a new OAuth state for CSRF protection.
Args:
db: Database session
obj_in: OAuth state creation data
Returns:
Created OAuthState
"""
try:
db_obj = OAuthState(
state=obj_in.state,
code_verifier=obj_in.code_verifier,
nonce=obj_in.nonce,
provider=obj_in.provider,
redirect_uri=obj_in.redirect_uri,
user_id=obj_in.user_id,
expires_at=obj_in.expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.debug(f"OAuth state created for {obj_in.provider}")
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
# State collision (extremely rare with cryptographic random)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"OAuth state collision: {error_msg}")
raise ValueError("Failed to create OAuth state, please retry")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
raise
async def get_and_consume_state(
self, db: AsyncSession, *, state: str
) -> OAuthState | None:
"""
Get and delete OAuth state (consume it).
This ensures each state can only be used once (replay protection).
Args:
db: Database session
state: State string to look up
Returns:
OAuthState if found and valid, None otherwise
"""
try:
# Get the state
result = await db.execute(
select(OAuthState).where(OAuthState.state == state)
)
db_obj = result.scalar_one_or_none()
if db_obj is None:
logger.warning(f"OAuth state not found: {state[:8]}...")
return None
# Check expiration
# Handle both timezone-aware and timezone-naive datetimes
now = datetime.now(UTC)
expires_at = db_obj.expires_at
if expires_at.tzinfo is None:
# SQLite returns naive datetimes, assume UTC
expires_at = expires_at.replace(tzinfo=UTC)
if expires_at < now:
logger.warning(f"OAuth state expired: {state[:8]}...")
await db.delete(db_obj)
await db.commit()
return None
# Delete it (consume)
await db.delete(db_obj)
await db.commit()
logger.debug(f"OAuth state consumed: {state[:8]}...")
return db_obj
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error consuming OAuth state: {e!s}")
raise
async def cleanup_expired(self, db: AsyncSession) -> int:
"""
Clean up expired OAuth states.
Should be called periodically to remove stale states.
Args:
db: Database session
Returns:
Number of states deleted
"""
try:
now = datetime.now(UTC)
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired OAuth states")
return count
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error cleaning up expired OAuth states: {e!s}")
raise
# ============================================================================
# OAuth Client CRUD (Provider Mode - Skeleton)
# ============================================================================
class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
"""
CRUD operations for OAuth clients (provider mode).
This is a skeleton implementation for MCP client registration.
Full implementation can be expanded when needed.
"""
async def get_by_client_id(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""
Get OAuth client by client_id.
Args:
db: Database session
client_id: OAuth client ID
Returns:
OAuthClient if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(f"Error getting OAuth client {client_id}: {e!s}")
raise
async def create_client(
self,
db: AsyncSession,
*,
obj_in: OAuthClientCreate,
owner_user_id: UUID | None = None,
) -> tuple[OAuthClient, str | None]:
"""
Create a new OAuth client.
Args:
db: Database session
obj_in: OAuth client creation data
owner_user_id: Optional owner user ID
Returns:
Tuple of (created OAuthClient, client_secret or None for public clients)
"""
try:
# Generate client_id
client_id = secrets.token_urlsafe(32)
# Generate client_secret for confidential clients
client_secret = None
client_secret_hash = None
if obj_in.client_type == "confidential":
client_secret = secrets.token_urlsafe(48)
# SECURITY: Use bcrypt for secret storage (not SHA-256)
# bcrypt is computationally expensive, making brute-force attacks infeasible
from app.core.auth import get_password_hash
client_secret_hash = get_password_hash(client_secret)
db_obj = OAuthClient(
client_id=client_id,
client_secret_hash=client_secret_hash,
client_name=obj_in.client_name,
client_description=obj_in.client_description,
client_type=obj_in.client_type,
redirect_uris=obj_in.redirect_uris,
allowed_scopes=obj_in.allowed_scopes,
owner_user_id=owner_user_id,
is_active=True,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
)
return db_obj, client_secret
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Error creating OAuth client: {error_msg}")
raise ValueError(f"Failed to create OAuth client: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
raise
async def deactivate_client(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""
Deactivate an OAuth client.
Args:
db: Database session
client_id: OAuth client ID
Returns:
Deactivated OAuthClient if found, None otherwise
"""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return None
client.is_active = False
db.add(client)
await db.commit()
await db.refresh(client)
logger.info(f"OAuth client deactivated: {client.client_name}")
return client
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
raise
async def validate_redirect_uri(
self, db: AsyncSession, *, client_id: str, redirect_uri: str
) -> bool:
"""
Validate that a redirect URI is allowed for a client.
Args:
db: Database session
client_id: OAuth client ID
redirect_uri: Redirect URI to validate
Returns:
True if valid, False otherwise
"""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return False
return redirect_uri in (client.redirect_uris or [])
except Exception as e: # pragma: no cover
logger.error(f"Error validating redirect URI: {e!s}")
return False
async def verify_client_secret(
self, db: AsyncSession, *, client_id: str, client_secret: str
) -> bool:
"""
Verify client credentials.
Args:
db: Database session
client_id: OAuth client ID
client_secret: Client secret to verify
Returns:
True if valid, False otherwise
"""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
client = result.scalar_one_or_none()
if client is None or client.client_secret_hash is None:
return False
# SECURITY: Verify secret using bcrypt (not SHA-256)
# This supports both old SHA-256 hashes (for migration) and new bcrypt hashes
from app.core.auth import verify_password
stored_hash: str = str(client.client_secret_hash)
# Check if it's a bcrypt hash (starts with $2b$) or legacy SHA-256
if stored_hash.startswith("$2"):
# New bcrypt format
return verify_password(client_secret, stored_hash)
else:
# Legacy SHA-256 format - still support for migration
import hashlib
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
return secrets.compare_digest(stored_hash, secret_hash)
except Exception as e: # pragma: no cover
logger.error(f"Error verifying client secret: {e!s}")
return False
async def get_all_clients(
self, db: AsyncSession, *, include_inactive: bool = False
) -> list[OAuthClient]:
"""
Get all OAuth clients.
Args:
db: Database session
include_inactive: Whether to include inactive clients
Returns:
List of OAuthClient objects
"""
try:
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
if not include_inactive:
query = query.where(OAuthClient.is_active == True) # noqa: E712
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error(f"Error getting all OAuth clients: {e!s}")
raise
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
"""
Delete an OAuth client permanently.
Note: This will cascade delete related records (tokens, consents, etc.)
due to foreign key constraints.
Args:
db: Database session
client_id: OAuth client ID
Returns:
True if deleted, False if not found
"""
try:
result = await db.execute(
delete(OAuthClient).where(OAuthClient.client_id == client_id)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(f"OAuth client deleted: {client_id}")
else:
logger.warning(f"OAuth client not found for deletion: {client_id}")
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
raise
# ============================================================================
# Singleton instances
# ============================================================================
oauth_account = CRUDOAuthAccount(OAuthAccount)
oauth_state = CRUDOAuthState(OAuthState)
oauth_client = CRUDOAuthClient(OAuthClient)

View File

@@ -1,20 +0,0 @@
# app/crud/syndarix/__init__.py
"""
Syndarix CRUD operations.
This package contains CRUD operations for all Syndarix domain entities.
"""
from .agent_instance import agent_instance
from .agent_type import agent_type
from .issue import issue
from .project import project
from .sprint import sprint
__all__ = [
"agent_instance",
"agent_type",
"issue",
"project",
"sprint",
]

View File

@@ -1,394 +0,0 @@
# app/crud/syndarix/agent_instance.py
"""Async CRUD operations for AgentInstance model using SQLAlchemy 2.0 patterns."""
import logging
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from uuid import UUID
from sqlalchemy import func, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase
from app.models.syndarix import AgentInstance, Issue
from app.models.syndarix.enums import AgentStatus
from app.schemas.syndarix import AgentInstanceCreate, AgentInstanceUpdate
logger = logging.getLogger(__name__)
class CRUDAgentInstance(
CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstanceUpdate]
):
"""Async CRUD operations for AgentInstance model."""
async def create(
self, db: AsyncSession, *, obj_in: AgentInstanceCreate
) -> AgentInstance:
"""Create a new agent instance with error handling."""
try:
db_obj = AgentInstance(
agent_type_id=obj_in.agent_type_id,
project_id=obj_in.project_id,
name=obj_in.name,
status=obj_in.status,
current_task=obj_in.current_task,
short_term_memory=obj_in.short_term_memory,
long_term_memory_ref=obj_in.long_term_memory_ref,
session_id=obj_in.session_id,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Integrity error creating agent instance: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(
f"Unexpected error creating agent instance: {e!s}", exc_info=True
)
raise
async def get_with_details(
self,
db: AsyncSession,
*,
instance_id: UUID,
) -> dict[str, Any] | None:
"""
Get an agent instance with full details including related entities.
Returns:
Dictionary with instance and related entity details
"""
try:
# Get instance with joined relationships
result = await db.execute(
select(AgentInstance)
.options(
joinedload(AgentInstance.agent_type),
joinedload(AgentInstance.project),
)
.where(AgentInstance.id == instance_id)
)
instance = result.scalar_one_or_none()
if not instance:
return None
# Get assigned issues count
issues_count_result = await db.execute(
select(func.count(Issue.id)).where(
Issue.assigned_agent_id == instance_id
)
)
assigned_issues_count = issues_count_result.scalar_one()
return {
"instance": instance,
"agent_type_name": instance.agent_type.name
if instance.agent_type
else None,
"agent_type_slug": instance.agent_type.slug
if instance.agent_type
else None,
"project_name": instance.project.name if instance.project else None,
"project_slug": instance.project.slug if instance.project else None,
"assigned_issues_count": assigned_issues_count,
}
except Exception as e:
logger.error(
f"Error getting agent instance with details {instance_id}: {e!s}",
exc_info=True,
)
raise
async def get_by_project(
self,
db: AsyncSession,
*,
project_id: UUID,
status: AgentStatus | None = None,
skip: int = 0,
limit: int = 100,
) -> tuple[list[AgentInstance], int]:
"""Get agent instances for a specific project."""
try:
query = select(AgentInstance).where(AgentInstance.project_id == project_id)
if status is not None:
query = query.where(AgentInstance.status == status)
# Get total count
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply pagination
query = query.order_by(AgentInstance.created_at.desc())
query = query.offset(skip).limit(limit)
result = await db.execute(query)
instances = list(result.scalars().all())
return instances, total
except Exception as e:
logger.error(
f"Error getting instances by project {project_id}: {e!s}",
exc_info=True,
)
raise
async def get_by_agent_type(
self,
db: AsyncSession,
*,
agent_type_id: UUID,
status: AgentStatus | None = None,
) -> list[AgentInstance]:
"""Get all instances of a specific agent type."""
try:
query = select(AgentInstance).where(
AgentInstance.agent_type_id == agent_type_id
)
if status is not None:
query = query.where(AgentInstance.status == status)
query = query.order_by(AgentInstance.created_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(
f"Error getting instances by agent type {agent_type_id}: {e!s}",
exc_info=True,
)
raise
async def update_status(
self,
db: AsyncSession,
*,
instance_id: UUID,
status: AgentStatus,
current_task: str | None = None,
) -> AgentInstance | None:
"""Update the status of an agent instance."""
try:
result = await db.execute(
select(AgentInstance).where(AgentInstance.id == instance_id)
)
instance = result.scalar_one_or_none()
if not instance:
return None
instance.status = status
instance.last_activity_at = datetime.now(UTC)
if current_task is not None:
instance.current_task = current_task
await db.commit()
await db.refresh(instance)
return instance
except Exception as e:
await db.rollback()
logger.error(
f"Error updating instance status {instance_id}: {e!s}", exc_info=True
)
raise
async def terminate(
self,
db: AsyncSession,
*,
instance_id: UUID,
) -> AgentInstance | None:
"""Terminate an agent instance.
Also unassigns all issues from this agent to prevent orphaned assignments.
"""
try:
result = await db.execute(
select(AgentInstance).where(AgentInstance.id == instance_id)
)
instance = result.scalar_one_or_none()
if not instance:
return None
# Unassign all issues from this agent before terminating
await db.execute(
update(Issue)
.where(Issue.assigned_agent_id == instance_id)
.values(assigned_agent_id=None)
)
instance.status = AgentStatus.TERMINATED
instance.terminated_at = datetime.now(UTC)
instance.current_task = None
instance.session_id = None
await db.commit()
await db.refresh(instance)
return instance
except Exception as e:
await db.rollback()
logger.error(
f"Error terminating instance {instance_id}: {e!s}", exc_info=True
)
raise
async def record_task_completion(
self,
db: AsyncSession,
*,
instance_id: UUID,
tokens_used: int,
cost_incurred: Decimal,
) -> AgentInstance | None:
"""Record a completed task and update metrics.
Uses atomic SQL UPDATE to prevent lost updates under concurrent load.
This avoids the read-modify-write race condition that occurs when
multiple task completions happen simultaneously.
"""
try:
now = datetime.now(UTC)
# Use atomic SQL UPDATE to increment counters without race conditions
# This is safe for concurrent updates - no read-modify-write pattern
result = await db.execute(
update(AgentInstance)
.where(AgentInstance.id == instance_id)
.values(
tasks_completed=AgentInstance.tasks_completed + 1,
tokens_used=AgentInstance.tokens_used + tokens_used,
cost_incurred=AgentInstance.cost_incurred + cost_incurred,
last_activity_at=now,
updated_at=now,
)
.returning(AgentInstance)
)
instance = result.scalar_one_or_none()
if not instance:
return None
await db.commit()
return instance
except Exception as e:
await db.rollback()
logger.error(
f"Error recording task completion {instance_id}: {e!s}", exc_info=True
)
raise
async def get_project_metrics(
self,
db: AsyncSession,
*,
project_id: UUID,
) -> dict[str, Any]:
"""Get aggregated metrics for all agents in a project."""
try:
result = await db.execute(
select(
func.count(AgentInstance.id).label("total_instances"),
func.count(AgentInstance.id)
.filter(AgentInstance.status == AgentStatus.WORKING)
.label("active_instances"),
func.count(AgentInstance.id)
.filter(AgentInstance.status == AgentStatus.IDLE)
.label("idle_instances"),
func.sum(AgentInstance.tasks_completed).label("total_tasks"),
func.sum(AgentInstance.tokens_used).label("total_tokens"),
func.sum(AgentInstance.cost_incurred).label("total_cost"),
).where(AgentInstance.project_id == project_id)
)
row = result.one()
return {
"total_instances": row.total_instances or 0,
"active_instances": row.active_instances or 0,
"idle_instances": row.idle_instances or 0,
"total_tasks_completed": row.total_tasks or 0,
"total_tokens_used": row.total_tokens or 0,
"total_cost_incurred": row.total_cost or Decimal("0.0000"),
}
except Exception as e:
logger.error(
f"Error getting project metrics {project_id}: {e!s}", exc_info=True
)
raise
async def bulk_terminate_by_project(
self,
db: AsyncSession,
*,
project_id: UUID,
) -> int:
"""Terminate all active instances in a project.
Also unassigns all issues from these agents to prevent orphaned assignments.
"""
try:
# First, unassign all issues from agents in this project
# Get all agent IDs that will be terminated
agents_to_terminate = await db.execute(
select(AgentInstance.id).where(
AgentInstance.project_id == project_id,
AgentInstance.status != AgentStatus.TERMINATED,
)
)
agent_ids = [row[0] for row in agents_to_terminate.fetchall()]
# Unassign issues from these agents
if agent_ids:
await db.execute(
update(Issue)
.where(Issue.assigned_agent_id.in_(agent_ids))
.values(assigned_agent_id=None)
)
now = datetime.now(UTC)
stmt = (
update(AgentInstance)
.where(
AgentInstance.project_id == project_id,
AgentInstance.status != AgentStatus.TERMINATED,
)
.values(
status=AgentStatus.TERMINATED,
terminated_at=now,
current_task=None,
session_id=None,
updated_at=now,
)
)
result = await db.execute(stmt)
await db.commit()
terminated_count = result.rowcount
logger.info(
f"Bulk terminated {terminated_count} instances in project {project_id}"
)
return terminated_count
except Exception as e:
await db.rollback()
logger.error(
f"Error bulk terminating instances for project {project_id}: {e!s}",
exc_info=True,
)
raise
# Create a singleton instance for use across the application
agent_instance = CRUDAgentInstance(AgentInstance)

View File

@@ -1,265 +0,0 @@
# app/crud/syndarix/agent_type.py
"""Async CRUD operations for AgentType model using SQLAlchemy 2.0 patterns."""
import logging
from typing import Any
from uuid import UUID
from sqlalchemy import func, or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase
from app.models.syndarix import AgentInstance, AgentType
from app.schemas.syndarix import AgentTypeCreate, AgentTypeUpdate
logger = logging.getLogger(__name__)
class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
"""Async CRUD operations for AgentType model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> AgentType | None:
"""Get agent type by slug."""
try:
result = await db.execute(select(AgentType).where(AgentType.slug == slug))
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting agent type by slug {slug}: {e!s}")
raise
async def create(self, db: AsyncSession, *, obj_in: AgentTypeCreate) -> AgentType:
"""Create a new agent type with error handling."""
try:
db_obj = AgentType(
name=obj_in.name,
slug=obj_in.slug,
description=obj_in.description,
expertise=obj_in.expertise,
personality_prompt=obj_in.personality_prompt,
primary_model=obj_in.primary_model,
fallback_models=obj_in.fallback_models,
model_params=obj_in.model_params,
mcp_servers=obj_in.mcp_servers,
tool_permissions=obj_in.tool_permissions,
is_active=obj_in.is_active,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "slug" in error_msg.lower():
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
raise ValueError(f"Agent type with slug '{obj_in.slug}' already exists")
logger.error(f"Integrity error creating agent type: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating agent type: {e!s}", exc_info=True)
raise
async def get_multi_with_filters(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: bool | None = None,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
) -> tuple[list[AgentType], int]:
"""
Get multiple agent types with filtering, searching, and sorting.
Returns:
Tuple of (agent types list, total count)
"""
try:
query = select(AgentType)
# Apply filters
if is_active is not None:
query = query.where(AgentType.is_active == is_active)
if search:
search_filter = or_(
AgentType.name.ilike(f"%{search}%"),
AgentType.slug.ilike(f"%{search}%"),
AgentType.description.ilike(f"%{search}%"),
)
query = query.where(search_filter)
# Get total count before pagination
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
sort_column = getattr(AgentType, sort_by, AgentType.created_at)
if sort_order == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
result = await db.execute(query)
agent_types = list(result.scalars().all())
return agent_types, total
except Exception as e:
logger.error(f"Error getting agent types with filters: {e!s}")
raise
async def get_with_instance_count(
self,
db: AsyncSession,
*,
agent_type_id: UUID,
) -> dict[str, Any] | None:
"""
Get a single agent type with its instance count.
Returns:
Dictionary with agent_type and instance_count
"""
try:
result = await db.execute(
select(AgentType).where(AgentType.id == agent_type_id)
)
agent_type = result.scalar_one_or_none()
if not agent_type:
return None
# Get instance count
count_result = await db.execute(
select(func.count(AgentInstance.id)).where(
AgentInstance.agent_type_id == agent_type_id
)
)
instance_count = count_result.scalar_one()
return {
"agent_type": agent_type,
"instance_count": instance_count,
}
except Exception as e:
logger.error(
f"Error getting agent type with count {agent_type_id}: {e!s}",
exc_info=True,
)
raise
async def get_multi_with_instance_counts(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: bool | None = None,
search: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
Get agent types with instance counts in optimized queries.
Returns:
Tuple of (list of dicts with agent_type and instance_count, total count)
"""
try:
# Get filtered agent types
agent_types, total = await self.get_multi_with_filters(
db,
skip=skip,
limit=limit,
is_active=is_active,
search=search,
)
if not agent_types:
return [], 0
agent_type_ids = [at.id for at in agent_types]
# Get instance counts in bulk
counts_result = await db.execute(
select(
AgentInstance.agent_type_id,
func.count(AgentInstance.id).label("count"),
)
.where(AgentInstance.agent_type_id.in_(agent_type_ids))
.group_by(AgentInstance.agent_type_id)
)
counts = {row.agent_type_id: row.count for row in counts_result}
# Combine results
results = [
{
"agent_type": agent_type,
"instance_count": counts.get(agent_type.id, 0),
}
for agent_type in agent_types
]
return results, total
except Exception as e:
logger.error(f"Error getting agent types with counts: {e!s}", exc_info=True)
raise
async def get_by_expertise(
self,
db: AsyncSession,
*,
expertise: str,
is_active: bool = True,
) -> list[AgentType]:
"""Get agent types that have a specific expertise."""
try:
# Use PostgreSQL JSONB contains operator
query = select(AgentType).where(
AgentType.expertise.contains([expertise.lower()]),
AgentType.is_active == is_active,
)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(
f"Error getting agent types by expertise {expertise}: {e!s}",
exc_info=True,
)
raise
async def deactivate(
self,
db: AsyncSession,
*,
agent_type_id: UUID,
) -> AgentType | None:
"""Deactivate an agent type (soft delete)."""
try:
result = await db.execute(
select(AgentType).where(AgentType.id == agent_type_id)
)
agent_type = result.scalar_one_or_none()
if not agent_type:
return None
agent_type.is_active = False
await db.commit()
await db.refresh(agent_type)
return agent_type
except Exception as e:
await db.rollback()
logger.error(
f"Error deactivating agent type {agent_type_id}: {e!s}", exc_info=True
)
raise
# Create a singleton instance for use across the application
agent_type = CRUDAgentType(AgentType)

View File

@@ -1,525 +0,0 @@
# app/crud/syndarix/issue.py
"""Async CRUD operations for Issue model using SQLAlchemy 2.0 patterns."""
import logging
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy import func, or_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase
from app.models.syndarix import AgentInstance, Issue
from app.models.syndarix.enums import IssuePriority, IssueStatus, SyncStatus
from app.schemas.syndarix import IssueCreate, IssueUpdate
logger = logging.getLogger(__name__)
class CRUDIssue(CRUDBase[Issue, IssueCreate, IssueUpdate]):
"""Async CRUD operations for Issue model."""
async def create(self, db: AsyncSession, *, obj_in: IssueCreate) -> Issue:
"""Create a new issue with error handling."""
try:
db_obj = Issue(
project_id=obj_in.project_id,
title=obj_in.title,
body=obj_in.body,
status=obj_in.status,
priority=obj_in.priority,
labels=obj_in.labels,
assigned_agent_id=obj_in.assigned_agent_id,
human_assignee=obj_in.human_assignee,
sprint_id=obj_in.sprint_id,
story_points=obj_in.story_points,
external_tracker_type=obj_in.external_tracker_type,
external_issue_id=obj_in.external_issue_id,
remote_url=obj_in.remote_url,
external_issue_number=obj_in.external_issue_number,
sync_status=SyncStatus.SYNCED,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Integrity error creating issue: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating issue: {e!s}", exc_info=True)
raise
async def get_with_details(
self,
db: AsyncSession,
*,
issue_id: UUID,
) -> dict[str, Any] | None:
"""
Get an issue with full details including related entity names.
Returns:
Dictionary with issue and related entity details
"""
try:
# Get issue with joined relationships
result = await db.execute(
select(Issue)
.options(
joinedload(Issue.project),
joinedload(Issue.sprint),
joinedload(Issue.assigned_agent).joinedload(
AgentInstance.agent_type
),
)
.where(Issue.id == issue_id)
)
issue = result.scalar_one_or_none()
if not issue:
return None
return {
"issue": issue,
"project_name": issue.project.name if issue.project else None,
"project_slug": issue.project.slug if issue.project else None,
"sprint_name": issue.sprint.name if issue.sprint else None,
"assigned_agent_type_name": (
issue.assigned_agent.agent_type.name
if issue.assigned_agent and issue.assigned_agent.agent_type
else None
),
}
except Exception as e:
logger.error(
f"Error getting issue with details {issue_id}: {e!s}", exc_info=True
)
raise
async def get_by_project(
self,
db: AsyncSession,
*,
project_id: UUID,
status: IssueStatus | None = None,
priority: IssuePriority | None = None,
sprint_id: UUID | None = None,
assigned_agent_id: UUID | None = None,
labels: list[str] | None = None,
search: str | None = None,
skip: int = 0,
limit: int = 100,
sort_by: str = "created_at",
sort_order: str = "desc",
) -> tuple[list[Issue], int]:
"""Get issues for a specific project with filters."""
try:
query = select(Issue).where(Issue.project_id == project_id)
# Apply filters
if status is not None:
query = query.where(Issue.status == status)
if priority is not None:
query = query.where(Issue.priority == priority)
if sprint_id is not None:
query = query.where(Issue.sprint_id == sprint_id)
if assigned_agent_id is not None:
query = query.where(Issue.assigned_agent_id == assigned_agent_id)
if labels:
# Match any of the provided labels
for label in labels:
query = query.where(Issue.labels.contains([label.lower()]))
if search:
search_filter = or_(
Issue.title.ilike(f"%{search}%"),
Issue.body.ilike(f"%{search}%"),
)
query = query.where(search_filter)
# Get total count
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
sort_column = getattr(Issue, sort_by, Issue.created_at)
if sort_order == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
result = await db.execute(query)
issues = list(result.scalars().all())
return issues, total
except Exception as e:
logger.error(
f"Error getting issues by project {project_id}: {e!s}", exc_info=True
)
raise
async def get_by_sprint(
self,
db: AsyncSession,
*,
sprint_id: UUID,
status: IssueStatus | None = None,
) -> list[Issue]:
"""Get all issues in a sprint."""
try:
query = select(Issue).where(Issue.sprint_id == sprint_id)
if status is not None:
query = query.where(Issue.status == status)
query = query.order_by(Issue.priority.desc(), Issue.created_at.asc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(
f"Error getting issues by sprint {sprint_id}: {e!s}", exc_info=True
)
raise
async def assign_to_agent(
self,
db: AsyncSession,
*,
issue_id: UUID,
agent_id: UUID | None,
) -> Issue | None:
"""Assign an issue to an agent (or unassign if agent_id is None)."""
try:
result = await db.execute(select(Issue).where(Issue.id == issue_id))
issue = result.scalar_one_or_none()
if not issue:
return None
issue.assigned_agent_id = agent_id
issue.human_assignee = None # Clear human assignee when assigning to agent
await db.commit()
await db.refresh(issue)
return issue
except Exception as e:
await db.rollback()
logger.error(
f"Error assigning issue {issue_id} to agent {agent_id}: {e!s}",
exc_info=True,
)
raise
async def assign_to_human(
self,
db: AsyncSession,
*,
issue_id: UUID,
human_assignee: str | None,
) -> Issue | None:
"""Assign an issue to a human (or unassign if human_assignee is None)."""
try:
result = await db.execute(select(Issue).where(Issue.id == issue_id))
issue = result.scalar_one_or_none()
if not issue:
return None
issue.human_assignee = human_assignee
issue.assigned_agent_id = None # Clear agent when assigning to human
await db.commit()
await db.refresh(issue)
return issue
except Exception as e:
await db.rollback()
logger.error(
f"Error assigning issue {issue_id} to human {human_assignee}: {e!s}",
exc_info=True,
)
raise
async def close_issue(
self,
db: AsyncSession,
*,
issue_id: UUID,
) -> Issue | None:
"""Close an issue by setting status and closed_at timestamp."""
try:
result = await db.execute(select(Issue).where(Issue.id == issue_id))
issue = result.scalar_one_or_none()
if not issue:
return None
issue.status = IssueStatus.CLOSED
issue.closed_at = datetime.now(UTC)
await db.commit()
await db.refresh(issue)
return issue
except Exception as e:
await db.rollback()
logger.error(f"Error closing issue {issue_id}: {e!s}", exc_info=True)
raise
async def reopen_issue(
self,
db: AsyncSession,
*,
issue_id: UUID,
) -> Issue | None:
"""Reopen a closed issue."""
try:
result = await db.execute(select(Issue).where(Issue.id == issue_id))
issue = result.scalar_one_or_none()
if not issue:
return None
issue.status = IssueStatus.OPEN
issue.closed_at = None
await db.commit()
await db.refresh(issue)
return issue
except Exception as e:
await db.rollback()
logger.error(f"Error reopening issue {issue_id}: {e!s}", exc_info=True)
raise
async def update_sync_status(
self,
db: AsyncSession,
*,
issue_id: UUID,
sync_status: SyncStatus,
last_synced_at: datetime | None = None,
external_updated_at: datetime | None = None,
) -> Issue | None:
"""Update the sync status of an issue."""
try:
result = await db.execute(select(Issue).where(Issue.id == issue_id))
issue = result.scalar_one_or_none()
if not issue:
return None
issue.sync_status = sync_status
if last_synced_at:
issue.last_synced_at = last_synced_at
if external_updated_at:
issue.external_updated_at = external_updated_at
await db.commit()
await db.refresh(issue)
return issue
except Exception as e:
await db.rollback()
logger.error(
f"Error updating sync status for issue {issue_id}: {e!s}", exc_info=True
)
raise
async def get_project_stats(
self,
db: AsyncSession,
*,
project_id: UUID,
) -> dict[str, Any]:
"""Get issue statistics for a project."""
try:
# Get counts by status
status_counts = await db.execute(
select(Issue.status, func.count(Issue.id).label("count"))
.where(Issue.project_id == project_id)
.group_by(Issue.status)
)
by_status = {row.status.value: row.count for row in status_counts}
# Get counts by priority
priority_counts = await db.execute(
select(Issue.priority, func.count(Issue.id).label("count"))
.where(Issue.project_id == project_id)
.group_by(Issue.priority)
)
by_priority = {row.priority.value: row.count for row in priority_counts}
# Get story points
points_result = await db.execute(
select(
func.sum(Issue.story_points).label("total"),
func.sum(Issue.story_points)
.filter(Issue.status == IssueStatus.CLOSED)
.label("completed"),
).where(Issue.project_id == project_id)
)
points_row = points_result.one()
total_issues = sum(by_status.values())
return {
"total": total_issues,
"open": by_status.get("open", 0),
"in_progress": by_status.get("in_progress", 0),
"in_review": by_status.get("in_review", 0),
"blocked": by_status.get("blocked", 0),
"closed": by_status.get("closed", 0),
"by_priority": by_priority,
"total_story_points": points_row.total,
"completed_story_points": points_row.completed,
}
except Exception as e:
logger.error(
f"Error getting issue stats for project {project_id}: {e!s}",
exc_info=True,
)
raise
async def get_by_external_id(
self,
db: AsyncSession,
*,
external_tracker_type: str,
external_issue_id: str,
) -> Issue | None:
"""Get an issue by its external tracker ID."""
try:
result = await db.execute(
select(Issue).where(
Issue.external_tracker_type == external_tracker_type,
Issue.external_issue_id == external_issue_id,
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(
f"Error getting issue by external ID {external_tracker_type}:{external_issue_id}: {e!s}",
exc_info=True,
)
raise
async def get_pending_sync(
self,
db: AsyncSession,
*,
project_id: UUID | None = None,
limit: int = 100,
) -> list[Issue]:
"""Get issues that need to be synced with external tracker."""
try:
query = select(Issue).where(
Issue.external_tracker_type.isnot(None),
Issue.sync_status.in_([SyncStatus.PENDING, SyncStatus.ERROR]),
)
if project_id:
query = query.where(Issue.project_id == project_id)
query = query.order_by(Issue.updated_at.asc()).limit(limit)
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(f"Error getting pending sync issues: {e!s}", exc_info=True)
raise
async def remove_sprint_from_issues(
self,
db: AsyncSession,
*,
sprint_id: UUID,
) -> int:
"""Remove sprint assignment from all issues in a sprint.
Used when deleting a sprint to clean up references.
Returns:
Number of issues updated
"""
try:
from sqlalchemy import update
result = await db.execute(
update(Issue).where(Issue.sprint_id == sprint_id).values(sprint_id=None)
)
await db.commit()
return result.rowcount
except Exception as e:
await db.rollback()
logger.error(
f"Error removing sprint {sprint_id} from issues: {e!s}",
exc_info=True,
)
raise
async def unassign(
self,
db: AsyncSession,
*,
issue_id: UUID,
) -> Issue | None:
"""Remove agent assignment from an issue.
Returns:
Updated issue or None if not found
"""
try:
result = await db.execute(select(Issue).where(Issue.id == issue_id))
issue = result.scalar_one_or_none()
if not issue:
return None
issue.assigned_agent_id = None
await db.commit()
await db.refresh(issue)
return issue
except Exception as e:
await db.rollback()
logger.error(f"Error unassigning issue {issue_id}: {e!s}", exc_info=True)
raise
async def remove_from_sprint(
self,
db: AsyncSession,
*,
issue_id: UUID,
) -> Issue | None:
"""Remove an issue from its current sprint.
Returns:
Updated issue or None if not found
"""
try:
result = await db.execute(select(Issue).where(Issue.id == issue_id))
issue = result.scalar_one_or_none()
if not issue:
return None
issue.sprint_id = None
await db.commit()
await db.refresh(issue)
return issue
except Exception as e:
await db.rollback()
logger.error(
f"Error removing issue {issue_id} from sprint: {e!s}",
exc_info=True,
)
raise
# Create a singleton instance for use across the application
issue = CRUDIssue(Issue)

View File

@@ -1,362 +0,0 @@
# app/crud/syndarix/project.py
"""Async CRUD operations for Project model using SQLAlchemy 2.0 patterns."""
import logging
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy import func, or_, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase
from app.models.syndarix import AgentInstance, Issue, Project, Sprint
from app.models.syndarix.enums import AgentStatus, ProjectStatus, SprintStatus
from app.schemas.syndarix import ProjectCreate, ProjectUpdate
logger = logging.getLogger(__name__)
class CRUDProject(CRUDBase[Project, ProjectCreate, ProjectUpdate]):
"""Async CRUD operations for Project model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Project | None:
"""Get project by slug."""
try:
result = await db.execute(select(Project).where(Project.slug == slug))
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting project by slug {slug}: {e!s}")
raise
async def create(self, db: AsyncSession, *, obj_in: ProjectCreate) -> Project:
"""Create a new project with error handling."""
try:
db_obj = Project(
name=obj_in.name,
slug=obj_in.slug,
description=obj_in.description,
autonomy_level=obj_in.autonomy_level,
status=obj_in.status,
settings=obj_in.settings or {},
owner_id=obj_in.owner_id,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "slug" in error_msg.lower():
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
raise ValueError(f"Project with slug '{obj_in.slug}' already exists")
logger.error(f"Integrity error creating project: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating project: {e!s}", exc_info=True)
raise
async def get_multi_with_filters(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
status: ProjectStatus | None = None,
owner_id: UUID | None = None,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
) -> tuple[list[Project], int]:
"""
Get multiple projects with filtering, searching, and sorting.
Returns:
Tuple of (projects list, total count)
"""
try:
query = select(Project)
# Apply filters
if status is not None:
query = query.where(Project.status == status)
if owner_id is not None:
query = query.where(Project.owner_id == owner_id)
if search:
search_filter = or_(
Project.name.ilike(f"%{search}%"),
Project.slug.ilike(f"%{search}%"),
Project.description.ilike(f"%{search}%"),
)
query = query.where(search_filter)
# Get total count before pagination
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting
sort_column = getattr(Project, sort_by, Project.created_at)
if sort_order == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit)
result = await db.execute(query)
projects = list(result.scalars().all())
return projects, total
except Exception as e:
logger.error(f"Error getting projects with filters: {e!s}")
raise
async def get_with_counts(
self,
db: AsyncSession,
*,
project_id: UUID,
) -> dict[str, Any] | None:
"""
Get a single project with agent and issue counts.
Returns:
Dictionary with project, agent_count, issue_count, active_sprint_name
"""
try:
# Get project
result = await db.execute(select(Project).where(Project.id == project_id))
project = result.scalar_one_or_none()
if not project:
return None
# Get agent count
agent_count_result = await db.execute(
select(func.count(AgentInstance.id)).where(
AgentInstance.project_id == project_id
)
)
agent_count = agent_count_result.scalar_one()
# Get issue count
issue_count_result = await db.execute(
select(func.count(Issue.id)).where(Issue.project_id == project_id)
)
issue_count = issue_count_result.scalar_one()
# Get active sprint name
active_sprint_result = await db.execute(
select(Sprint.name).where(
Sprint.project_id == project_id,
Sprint.status == SprintStatus.ACTIVE,
)
)
active_sprint_name = active_sprint_result.scalar_one_or_none()
return {
"project": project,
"agent_count": agent_count,
"issue_count": issue_count,
"active_sprint_name": active_sprint_name,
}
except Exception as e:
logger.error(
f"Error getting project with counts {project_id}: {e!s}", exc_info=True
)
raise
async def get_multi_with_counts(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
status: ProjectStatus | None = None,
owner_id: UUID | None = None,
search: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
Get projects with agent/issue counts in optimized queries.
Returns:
Tuple of (list of dicts with project and counts, total count)
"""
try:
# Get filtered projects
projects, total = await self.get_multi_with_filters(
db,
skip=skip,
limit=limit,
status=status,
owner_id=owner_id,
search=search,
)
if not projects:
return [], 0
project_ids = [p.id for p in projects]
# Get agent counts in bulk
agent_counts_result = await db.execute(
select(
AgentInstance.project_id,
func.count(AgentInstance.id).label("count"),
)
.where(AgentInstance.project_id.in_(project_ids))
.group_by(AgentInstance.project_id)
)
agent_counts = {row.project_id: row.count for row in agent_counts_result}
# Get issue counts in bulk
issue_counts_result = await db.execute(
select(
Issue.project_id,
func.count(Issue.id).label("count"),
)
.where(Issue.project_id.in_(project_ids))
.group_by(Issue.project_id)
)
issue_counts = {row.project_id: row.count for row in issue_counts_result}
# Get active sprint names
active_sprints_result = await db.execute(
select(Sprint.project_id, Sprint.name).where(
Sprint.project_id.in_(project_ids),
Sprint.status == SprintStatus.ACTIVE,
)
)
active_sprints = {row.project_id: row.name for row in active_sprints_result}
# Combine results
results = [
{
"project": project,
"agent_count": agent_counts.get(project.id, 0),
"issue_count": issue_counts.get(project.id, 0),
"active_sprint_name": active_sprints.get(project.id),
}
for project in projects
]
return results, total
except Exception as e:
logger.error(f"Error getting projects with counts: {e!s}", exc_info=True)
raise
async def get_projects_by_owner(
self,
db: AsyncSession,
*,
owner_id: UUID,
status: ProjectStatus | None = None,
) -> list[Project]:
"""Get all projects owned by a specific user."""
try:
query = select(Project).where(Project.owner_id == owner_id)
if status is not None:
query = query.where(Project.status == status)
query = query.order_by(Project.created_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e:
logger.error(
f"Error getting projects by owner {owner_id}: {e!s}", exc_info=True
)
raise
async def archive_project(
self,
db: AsyncSession,
*,
project_id: UUID,
) -> Project | None:
"""Archive a project by setting status to ARCHIVED.
This also performs cascading cleanup:
- Terminates all active agent instances
- Cancels all planned/active sprints
- Unassigns issues from terminated agents
"""
try:
result = await db.execute(select(Project).where(Project.id == project_id))
project = result.scalar_one_or_none()
if not project:
return None
now = datetime.now(UTC)
# 1. Get all agent IDs that will be terminated
agents_to_terminate = await db.execute(
select(AgentInstance.id).where(
AgentInstance.project_id == project_id,
AgentInstance.status != AgentStatus.TERMINATED,
)
)
agent_ids = [row[0] for row in agents_to_terminate.fetchall()]
# 2. Unassign issues from these agents to prevent orphaned assignments
if agent_ids:
await db.execute(
update(Issue)
.where(Issue.assigned_agent_id.in_(agent_ids))
.values(assigned_agent_id=None)
)
# 3. Terminate all active agents
await db.execute(
update(AgentInstance)
.where(
AgentInstance.project_id == project_id,
AgentInstance.status != AgentStatus.TERMINATED,
)
.values(
status=AgentStatus.TERMINATED,
terminated_at=now,
current_task=None,
session_id=None,
updated_at=now,
)
)
# 4. Cancel all planned/active sprints
await db.execute(
update(Sprint)
.where(
Sprint.project_id == project_id,
Sprint.status.in_([SprintStatus.PLANNED, SprintStatus.ACTIVE]),
)
.values(
status=SprintStatus.CANCELLED,
updated_at=now,
)
)
# 5. Archive the project
project.status = ProjectStatus.ARCHIVED
await db.commit()
await db.refresh(project)
logger.info(
f"Archived project {project_id}: terminated agents={len(agent_ids)}"
)
return project
except Exception as e:
await db.rollback()
logger.error(f"Error archiving project {project_id}: {e!s}", exc_info=True)
raise
# Create a singleton instance for use across the application
project = CRUDProject(Project)

View File

@@ -1,439 +0,0 @@
# app/crud/syndarix/sprint.py
"""Async CRUD operations for Sprint model using SQLAlchemy 2.0 patterns."""
import logging
from datetime import date
from typing import Any
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase
from app.models.syndarix import Issue, Sprint
from app.models.syndarix.enums import IssueStatus, SprintStatus
from app.schemas.syndarix import SprintCreate, SprintUpdate
logger = logging.getLogger(__name__)
class CRUDSprint(CRUDBase[Sprint, SprintCreate, SprintUpdate]):
"""Async CRUD operations for Sprint model."""
async def create(self, db: AsyncSession, *, obj_in: SprintCreate) -> Sprint:
"""Create a new sprint with error handling."""
try:
db_obj = Sprint(
project_id=obj_in.project_id,
name=obj_in.name,
number=obj_in.number,
goal=obj_in.goal,
start_date=obj_in.start_date,
end_date=obj_in.end_date,
status=obj_in.status,
planned_points=obj_in.planned_points,
velocity=obj_in.velocity,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Integrity error creating sprint: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.error(f"Unexpected error creating sprint: {e!s}", exc_info=True)
raise
async def get_with_details(
self,
db: AsyncSession,
*,
sprint_id: UUID,
) -> dict[str, Any] | None:
"""
Get a sprint with full details including issue counts.
Returns:
Dictionary with sprint and related details
"""
try:
# Get sprint with joined project
result = await db.execute(
select(Sprint)
.options(joinedload(Sprint.project))
.where(Sprint.id == sprint_id)
)
sprint = result.scalar_one_or_none()
if not sprint:
return None
# Get issue counts
issue_counts = await db.execute(
select(
func.count(Issue.id).label("total"),
func.count(Issue.id)
.filter(Issue.status == IssueStatus.OPEN)
.label("open"),
func.count(Issue.id)
.filter(Issue.status == IssueStatus.CLOSED)
.label("completed"),
).where(Issue.sprint_id == sprint_id)
)
counts = issue_counts.one()
return {
"sprint": sprint,
"project_name": sprint.project.name if sprint.project else None,
"project_slug": sprint.project.slug if sprint.project else None,
"issue_count": counts.total,
"open_issues": counts.open,
"completed_issues": counts.completed,
}
except Exception as e:
logger.error(
f"Error getting sprint with details {sprint_id}: {e!s}", exc_info=True
)
raise
async def get_by_project(
self,
db: AsyncSession,
*,
project_id: UUID,
status: SprintStatus | None = None,
skip: int = 0,
limit: int = 100,
) -> tuple[list[Sprint], int]:
"""Get sprints for a specific project."""
try:
query = select(Sprint).where(Sprint.project_id == project_id)
if status is not None:
query = query.where(Sprint.status == status)
# Get total count
count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query)
total = count_result.scalar_one()
# Apply sorting (by number descending - newest first)
query = query.order_by(Sprint.number.desc())
query = query.offset(skip).limit(limit)
result = await db.execute(query)
sprints = list(result.scalars().all())
return sprints, total
except Exception as e:
logger.error(
f"Error getting sprints by project {project_id}: {e!s}", exc_info=True
)
raise
async def get_active_sprint(
self,
db: AsyncSession,
*,
project_id: UUID,
) -> Sprint | None:
"""Get the currently active sprint for a project."""
try:
result = await db.execute(
select(Sprint).where(
Sprint.project_id == project_id,
Sprint.status == SprintStatus.ACTIVE,
)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(
f"Error getting active sprint for project {project_id}: {e!s}",
exc_info=True,
)
raise
async def get_next_sprint_number(
self,
db: AsyncSession,
*,
project_id: UUID,
) -> int:
"""Get the next sprint number for a project."""
try:
result = await db.execute(
select(func.max(Sprint.number)).where(Sprint.project_id == project_id)
)
max_number = result.scalar_one_or_none()
return (max_number or 0) + 1
except Exception as e:
logger.error(
f"Error getting next sprint number for project {project_id}: {e!s}",
exc_info=True,
)
raise
async def start_sprint(
self,
db: AsyncSession,
*,
sprint_id: UUID,
start_date: date | None = None,
) -> Sprint | None:
"""Start a planned sprint.
Uses row-level locking (SELECT FOR UPDATE) to prevent race conditions
when multiple requests try to start sprints concurrently.
"""
try:
# Lock the sprint row to prevent concurrent modifications
result = await db.execute(
select(Sprint).where(Sprint.id == sprint_id).with_for_update()
)
sprint = result.scalar_one_or_none()
if not sprint:
return None
if sprint.status != SprintStatus.PLANNED:
raise ValueError(
f"Cannot start sprint with status {sprint.status.value}"
)
# Check for existing active sprint with lock to prevent race condition
# Lock all sprints for this project to ensure atomic check-and-update
active_check = await db.execute(
select(Sprint)
.where(
Sprint.project_id == sprint.project_id,
Sprint.status == SprintStatus.ACTIVE,
)
.with_for_update()
)
active_sprint = active_check.scalar_one_or_none()
if active_sprint:
raise ValueError(
f"Project already has an active sprint: {active_sprint.name}"
)
sprint.status = SprintStatus.ACTIVE
if start_date:
sprint.start_date = start_date
# Calculate planned points from issues
points_result = await db.execute(
select(func.sum(Issue.story_points)).where(Issue.sprint_id == sprint_id)
)
sprint.planned_points = points_result.scalar_one_or_none() or 0
await db.commit()
await db.refresh(sprint)
return sprint
except ValueError:
raise
except Exception as e:
await db.rollback()
logger.error(f"Error starting sprint {sprint_id}: {e!s}", exc_info=True)
raise
async def complete_sprint(
self,
db: AsyncSession,
*,
sprint_id: UUID,
) -> Sprint | None:
"""Complete an active sprint and calculate completed points.
Uses row-level locking (SELECT FOR UPDATE) to prevent race conditions
when velocity is being calculated and other operations might modify issues.
"""
try:
# Lock the sprint row to prevent concurrent modifications
result = await db.execute(
select(Sprint).where(Sprint.id == sprint_id).with_for_update()
)
sprint = result.scalar_one_or_none()
if not sprint:
return None
if sprint.status != SprintStatus.ACTIVE:
raise ValueError(
f"Cannot complete sprint with status {sprint.status.value}"
)
sprint.status = SprintStatus.COMPLETED
# Calculate velocity (completed points) from closed issues
# Note: Issues are not locked, but sprint lock ensures this sprint's
# completion is atomic and prevents concurrent completion attempts
points_result = await db.execute(
select(func.sum(Issue.story_points)).where(
Issue.sprint_id == sprint_id,
Issue.status == IssueStatus.CLOSED,
)
)
sprint.velocity = points_result.scalar_one_or_none() or 0
await db.commit()
await db.refresh(sprint)
return sprint
except ValueError:
raise
except Exception as e:
await db.rollback()
logger.error(f"Error completing sprint {sprint_id}: {e!s}", exc_info=True)
raise
async def cancel_sprint(
self,
db: AsyncSession,
*,
sprint_id: UUID,
) -> Sprint | None:
"""Cancel a sprint (only PLANNED or ACTIVE sprints can be cancelled).
Uses row-level locking to prevent race conditions with concurrent
sprint status modifications.
"""
try:
# Lock the sprint row to prevent concurrent modifications
result = await db.execute(
select(Sprint).where(Sprint.id == sprint_id).with_for_update()
)
sprint = result.scalar_one_or_none()
if not sprint:
return None
if sprint.status not in [SprintStatus.PLANNED, SprintStatus.ACTIVE]:
raise ValueError(
f"Cannot cancel sprint with status {sprint.status.value}"
)
sprint.status = SprintStatus.CANCELLED
await db.commit()
await db.refresh(sprint)
return sprint
except ValueError:
raise
except Exception as e:
await db.rollback()
logger.error(f"Error cancelling sprint {sprint_id}: {e!s}", exc_info=True)
raise
async def get_velocity(
self,
db: AsyncSession,
*,
project_id: UUID,
limit: int = 5,
) -> list[dict[str, Any]]:
"""Get velocity data for completed sprints."""
try:
result = await db.execute(
select(Sprint)
.where(
Sprint.project_id == project_id,
Sprint.status == SprintStatus.COMPLETED,
)
.order_by(Sprint.number.desc())
.limit(limit)
)
sprints = list(result.scalars().all())
velocity_data = []
for sprint in reversed(sprints): # Return in chronological order
velocity_ratio = None
if sprint.planned_points and sprint.planned_points > 0:
velocity_ratio = (sprint.velocity or 0) / sprint.planned_points
velocity_data.append(
{
"sprint_number": sprint.number,
"sprint_name": sprint.name,
"planned_points": sprint.planned_points,
"velocity": sprint.velocity,
"velocity_ratio": velocity_ratio,
}
)
return velocity_data
except Exception as e:
logger.error(
f"Error getting velocity for project {project_id}: {e!s}",
exc_info=True,
)
raise
async def get_sprints_with_issue_counts(
self,
db: AsyncSession,
*,
project_id: UUID,
skip: int = 0,
limit: int = 100,
) -> tuple[list[dict[str, Any]], int]:
"""Get sprints with issue counts in optimized queries."""
try:
# Get sprints
sprints, total = await self.get_by_project(
db, project_id=project_id, skip=skip, limit=limit
)
if not sprints:
return [], 0
sprint_ids = [s.id for s in sprints]
# Get issue counts in bulk
issue_counts = await db.execute(
select(
Issue.sprint_id,
func.count(Issue.id).label("total"),
func.count(Issue.id)
.filter(Issue.status == IssueStatus.OPEN)
.label("open"),
func.count(Issue.id)
.filter(Issue.status == IssueStatus.CLOSED)
.label("completed"),
)
.where(Issue.sprint_id.in_(sprint_ids))
.group_by(Issue.sprint_id)
)
counts_map = {
row.sprint_id: {
"issue_count": row.total,
"open_issues": row.open,
"completed_issues": row.completed,
}
for row in issue_counts
}
# Combine results
results = [
{
"sprint": sprint,
**counts_map.get(
sprint.id,
{"issue_count": 0, "open_issues": 0, "completed_issues": 0},
),
}
for sprint in sprints
]
return results, total
except Exception as e:
logger.error(
f"Error getting sprints with counts for project {project_id}: {e!s}",
exc_info=True,
)
raise
# Create a singleton instance for use across the application
sprint = CRUDSprint(Sprint)

View File

@@ -16,10 +16,10 @@ from sqlalchemy import select, text
from app.core.config import settings from app.core.config import settings
from app.core.database import SessionLocal, engine from app.core.database import SessionLocal, engine
from app.crud.user import user as user_crud
from app.models.organization import Organization from app.models.organization import Organization
from app.models.user import User from app.models.user import User
from app.models.user_organization import UserOrganization from app.models.user_organization import UserOrganization
from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate from app.schemas.users import UserCreate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -44,16 +44,17 @@ async def init_db() -> User | None:
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD: if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
logger.warning( logger.warning(
"First superuser credentials not configured in settings. " "First superuser credentials not configured in settings. "
f"Using defaults: {superuser_email}" "Using defaults: %s",
superuser_email,
) )
async with SessionLocal() as session: async with SessionLocal() as session:
try: try:
# Check if superuser already exists # Check if superuser already exists
existing_user = await user_crud.get_by_email(session, email=superuser_email) existing_user = await user_repo.get_by_email(session, email=superuser_email)
if existing_user: if existing_user:
logger.info(f"Superuser already exists: {existing_user.email}") logger.info("Superuser already exists: %s", existing_user.email)
return existing_user return existing_user
# Create superuser if doesn't exist # Create superuser if doesn't exist
@@ -65,11 +66,11 @@ async def init_db() -> User | None:
is_superuser=True, is_superuser=True,
) )
user = await user_crud.create(session, obj_in=user_in) user = await user_repo.create(session, obj_in=user_in)
await session.commit() await session.commit()
await session.refresh(user) await session.refresh(user)
logger.info(f"Created first superuser: {user.email}") logger.info("Created first superuser: %s", user.email)
# Create demo data if in demo mode # Create demo data if in demo mode
if settings.DEMO_MODE: if settings.DEMO_MODE:
@@ -79,7 +80,7 @@ async def init_db() -> User | None:
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
logger.error(f"Error initializing database: {e}") logger.error("Error initializing database: %s", e)
raise raise
@@ -92,7 +93,7 @@ async def load_demo_data(session):
"""Load demo data from JSON file.""" """Load demo data from JSON file."""
demo_data_path = Path(__file__).parent / "core" / "demo_data.json" demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
if not demo_data_path.exists(): if not demo_data_path.exists():
logger.warning(f"Demo data file not found: {demo_data_path}") logger.warning("Demo data file not found: %s", demo_data_path)
return return
try: try:
@@ -119,7 +120,7 @@ async def load_demo_data(session):
session.add(org) session.add(org)
await session.flush() # Flush to get ID await session.flush() # Flush to get ID
org_map[org.slug] = org org_map[org.slug] = org
logger.info(f"Created demo organization: {org.name}") logger.info("Created demo organization: %s", org.name)
else: else:
# We can't easily get the ORM object from raw SQL result for map without querying again or mapping # We can't easily get the ORM object from raw SQL result for map without querying again or mapping
# So let's just query it properly if we need it for relationships # So let's just query it properly if we need it for relationships
@@ -135,7 +136,7 @@ async def load_demo_data(session):
# Create Users # Create Users
for user_data in data.get("users", []): for user_data in data.get("users", []):
existing_user = await user_crud.get_by_email( existing_user = await user_repo.get_by_email(
session, email=user_data["email"] session, email=user_data["email"]
) )
if not existing_user: if not existing_user:
@@ -148,7 +149,7 @@ async def load_demo_data(session):
is_superuser=user_data["is_superuser"], is_superuser=user_data["is_superuser"],
is_active=user_data.get("is_active", True), is_active=user_data.get("is_active", True),
) )
user = await user_crud.create(session, obj_in=user_in) user = await user_repo.create(session, obj_in=user_in)
# Randomize created_at for demo data (last 30 days) # Randomize created_at for demo data (last 30 days)
# This makes the charts look more realistic # This makes the charts look more realistic
@@ -174,7 +175,10 @@ async def load_demo_data(session):
) )
logger.info( logger.info(
f"Created demo user: {user.email} (created {days_ago} days ago, active={user_data.get('is_active', True)})" "Created demo user: %s (created %s days ago, active=%s)",
user.email,
days_ago,
user_data.get("is_active", True),
) )
# Add to organization if specified # Add to organization if specified
@@ -187,15 +191,15 @@ async def load_demo_data(session):
user_id=user.id, organization_id=org.id, role=role user_id=user.id, organization_id=org.id, role=role
) )
session.add(member) session.add(member)
logger.info(f"Added {user.email} to {org.name} as {role}") logger.info("Added %s to %s as %s", user.email, org.name, role)
else: else:
logger.info(f"Demo user already exists: {existing_user.email}") logger.info("Demo user already exists: %s", existing_user.email)
await session.commit() await session.commit()
logger.info("Demo data loaded successfully") logger.info("Demo data loaded successfully")
except Exception as e: except Exception as e:
logger.error(f"Error loading demo data: {e}") logger.error("Error loading demo data: %s", e)
raise raise

View File

@@ -1,7 +1,7 @@
import logging import logging
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime from datetime import UTC, datetime
from typing import Any from typing import Any
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -16,7 +16,7 @@ from slowapi.util import get_remote_address
from app.api.main import api_router from app.api.main import api_router
from app.api.routes.oauth_provider import wellknown_router as oauth_wellknown_router from app.api.routes.oauth_provider import wellknown_router as oauth_wellknown_router
from app.core.config import settings from app.core.config import settings
from app.core.database import check_database_health from app.core.database import check_database_health, close_async_db
from app.core.exceptions import ( from app.core.exceptions import (
APIException, APIException,
api_exception_handler, api_exception_handler,
@@ -72,6 +72,7 @@ async def lifespan(app: FastAPI):
if os.getenv("IS_TEST", "False") != "True": if os.getenv("IS_TEST", "False") != "True":
scheduler.shutdown() scheduler.shutdown()
logger.info("Scheduled jobs stopped") logger.info("Scheduled jobs stopped")
await close_async_db()
logger.info("Starting app!!!") logger.info("Starting app!!!")
@@ -294,7 +295,7 @@ async def health_check() -> JSONResponse:
""" """
health_status: dict[str, Any] = { health_status: dict[str, Any] = {
"status": "healthy", "status": "healthy",
"timestamp": datetime.utcnow().isoformat() + "Z", "timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
"version": settings.VERSION, "version": settings.VERSION,
"environment": settings.ENVIRONMENT, "environment": settings.ENVIRONMENT,
"checks": {}, "checks": {},
@@ -319,7 +320,7 @@ async def health_check() -> JSONResponse:
"message": f"Database connection failed: {e!s}", "message": f"Database connection failed: {e!s}",
} }
response_status = status.HTTP_503_SERVICE_UNAVAILABLE response_status = status.HTTP_503_SERVICE_UNAVAILABLE
logger.error(f"Health check failed - database error: {e}") logger.error("Health check failed - database error: %s", e)
return JSONResponse(status_code=response_status, content=health_status) return JSONResponse(status_code=response_status, content=health_status)

View File

@@ -18,26 +18,13 @@ from .oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
from .oauth_state import OAuthState from .oauth_state import OAuthState
from .organization import Organization from .organization import Organization
# Syndarix domain models
from .syndarix import (
AgentInstance,
AgentType,
Issue,
Project,
Sprint,
)
# Import models # Import models
from .user import User from .user import User
from .user_organization import OrganizationRole, UserOrganization from .user_organization import OrganizationRole, UserOrganization
from .user_session import UserSession from .user_session import UserSession
__all__ = [ __all__ = [
# Syndarix models
"AgentInstance",
"AgentType",
"Base", "Base",
"Issue",
"OAuthAccount", "OAuthAccount",
"OAuthAuthorizationCode", "OAuthAuthorizationCode",
"OAuthClient", "OAuthClient",
@@ -46,8 +33,6 @@ __all__ = [
"OAuthState", "OAuthState",
"Organization", "Organization",
"OrganizationRole", "OrganizationRole",
"Project",
"Sprint",
"TimestampMixin", "TimestampMixin",
"UUIDMixin", "UUIDMixin",
"User", "User",

View File

@@ -36,9 +36,9 @@ class OAuthAccount(Base, UUIDMixin, TimestampMixin):
) # Email from provider (for reference) ) # Email from provider (for reference)
# Optional: store provider tokens for API access # Optional: store provider tokens for API access
# These should be encrypted at rest in production # TODO: Encrypt these at rest in production (requires key management infrastructure)
access_token_encrypted = Column(String(2048), nullable=True) access_token = Column(String(2048), nullable=True)
refresh_token_encrypted = Column(String(2048), nullable=True) refresh_token = Column(String(2048), nullable=True)
token_expires_at = Column(DateTime(timezone=True), nullable=True) token_expires_at = Column(DateTime(timezone=True), nullable=True)
# Relationship # Relationship

View File

@@ -92,7 +92,7 @@ class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
# Handle both timezone-aware and naive datetimes from DB # Handle both timezone-aware and naive datetimes from DB
if expires_at.tzinfo is None: if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC) expires_at = expires_at.replace(tzinfo=UTC)
return now > expires_at return bool(now > expires_at)
@property @property
def is_valid(self) -> bool: def is_valid(self) -> bool:

View File

@@ -99,7 +99,7 @@ class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
# Handle both timezone-aware and naive datetimes from DB # Handle both timezone-aware and naive datetimes from DB
if expires_at.tzinfo is None: if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC) expires_at = expires_at.replace(tzinfo=UTC)
return now > expires_at return bool(now > expires_at)
@property @property
def is_valid(self) -> bool: def is_valid(self) -> bool:

View File

@@ -1,47 +0,0 @@
# app/models/syndarix/__init__.py
"""
Syndarix domain models.
This package contains all the core entities for the Syndarix AI consulting platform:
- Project: Client engagements with autonomy settings
- AgentType: Templates for AI agent capabilities
- AgentInstance: Spawned agents working on projects
- Issue: Units of work with external tracker sync
- Sprint: Time-boxed iterations for organizing work
"""
from .agent_instance import AgentInstance
from .agent_type import AgentType
from .enums import (
AgentStatus,
AutonomyLevel,
ClientMode,
IssuePriority,
IssueStatus,
IssueType,
ProjectComplexity,
ProjectStatus,
SprintStatus,
SyncStatus,
)
from .issue import Issue
from .project import Project
from .sprint import Sprint
__all__ = [
"AgentInstance",
"AgentStatus",
"AgentType",
"AutonomyLevel",
"ClientMode",
"Issue",
"IssuePriority",
"IssueStatus",
"IssueType",
"Project",
"ProjectComplexity",
"ProjectStatus",
"Sprint",
"SprintStatus",
"SyncStatus",
]

View File

@@ -1,111 +0,0 @@
# app/models/syndarix/agent_instance.py
"""
AgentInstance model for Syndarix AI consulting platform.
An AgentInstance is a spawned instance of an AgentType, assigned to a
specific project to perform work.
"""
from sqlalchemy import (
BigInteger,
Column,
DateTime,
Enum,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
)
from sqlalchemy.dialects.postgresql import (
JSONB,
UUID as PGUUID,
)
from sqlalchemy.orm import relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
from .enums import AgentStatus
class AgentInstance(Base, UUIDMixin, TimestampMixin):
"""
AgentInstance model representing a spawned agent working on a project.
Tracks:
- Current status and task
- Memory (short-term in DB, long-term reference to vector store)
- Session information for MCP connections
- Usage metrics (tasks completed, tokens, cost)
"""
__tablename__ = "agent_instances"
# Foreign keys
agent_type_id = Column(
PGUUID(as_uuid=True),
ForeignKey("agent_types.id", ondelete="RESTRICT"),
nullable=False,
index=True,
)
project_id = Column(
PGUUID(as_uuid=True),
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# Agent instance name (e.g., "Dave", "Eve") for personality
name = Column(String(100), nullable=False, index=True)
# Status tracking
status: Column[AgentStatus] = Column(
Enum(AgentStatus),
default=AgentStatus.IDLE,
nullable=False,
index=True,
)
# Current task description (brief summary of what agent is doing)
current_task = Column(Text, nullable=True)
# Short-term memory stored in database (conversation context, recent decisions)
short_term_memory = Column(JSONB, default=dict, nullable=False)
# Reference to long-term memory in vector store (e.g., "project-123/agent-456")
long_term_memory_ref = Column(String(500), nullable=True)
# Session ID for active MCP connections
session_id = Column(String(255), nullable=True, index=True)
# Activity tracking
last_activity_at = Column(DateTime(timezone=True), nullable=True, index=True)
terminated_at = Column(DateTime(timezone=True), nullable=True, index=True)
# Usage metrics
tasks_completed = Column(Integer, default=0, nullable=False)
tokens_used = Column(BigInteger, default=0, nullable=False)
cost_incurred = Column(Numeric(precision=10, scale=4), default=0, nullable=False)
# Relationships
agent_type = relationship("AgentType", back_populates="instances")
project = relationship("Project", back_populates="agent_instances")
assigned_issues = relationship(
"Issue",
back_populates="assigned_agent",
foreign_keys="Issue.assigned_agent_id",
)
__table_args__ = (
Index("ix_agent_instances_project_status", "project_id", "status"),
Index("ix_agent_instances_type_status", "agent_type_id", "status"),
Index("ix_agent_instances_project_type", "project_id", "agent_type_id"),
)
def __repr__(self) -> str:
return (
f"<AgentInstance {self.name} ({self.id}) type={self.agent_type_id} "
f"project={self.project_id} status={self.status.value}>"
)

View File

@@ -1,72 +0,0 @@
# app/models/syndarix/agent_type.py
"""
AgentType model for Syndarix AI consulting platform.
An AgentType is a template that defines the capabilities, personality,
and model configuration for agent instances.
"""
from sqlalchemy import Boolean, Column, Index, String, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
class AgentType(Base, UUIDMixin, TimestampMixin):
"""
AgentType model representing a template for agent instances.
Each agent type defines:
- Expertise areas and personality prompt
- Model configuration (primary, fallback, parameters)
- MCP server access and tool permissions
Examples: ProductOwner, Architect, BackendEngineer, QAEngineer
"""
__tablename__ = "agent_types"
name = Column(String(255), nullable=False, index=True)
slug = Column(String(255), unique=True, nullable=False, index=True)
description = Column(Text, nullable=True)
# Areas of expertise for this agent type (e.g., ["python", "fastapi", "databases"])
expertise = Column(JSONB, default=list, nullable=False)
# System prompt defining the agent's personality and behavior
personality_prompt = Column(Text, nullable=False)
# Primary LLM model to use (e.g., "claude-opus-4-5-20251101")
primary_model = Column(String(100), nullable=False)
# Fallback models in order of preference
fallback_models = Column(JSONB, default=list, nullable=False)
# Model parameters (temperature, max_tokens, etc.)
model_params = Column(JSONB, default=dict, nullable=False)
# List of MCP servers this agent can connect to
mcp_servers = Column(JSONB, default=list, nullable=False)
# Tool permissions configuration
# Structure: {"allowed": ["*"], "denied": [], "require_approval": ["gitea:create_pr"]}
tool_permissions = Column(JSONB, default=dict, nullable=False)
# Whether this agent type is available for new instances
is_active = Column(Boolean, default=True, nullable=False, index=True)
# Relationships
instances = relationship(
"AgentInstance",
back_populates="agent_type",
cascade="all, delete-orphan",
)
__table_args__ = (
Index("ix_agent_types_slug_active", "slug", "is_active"),
Index("ix_agent_types_name_active", "name", "is_active"),
)
def __repr__(self) -> str:
return f"<AgentType {self.name} ({self.slug}) active={self.is_active}>"

View File

@@ -1,169 +0,0 @@
# app/models/syndarix/enums.py
"""
Enums for Syndarix domain models.
These enums represent the core state machines and categorizations
used throughout the Syndarix AI consulting platform.
"""
from enum import Enum as PyEnum
class AutonomyLevel(str, PyEnum):
"""
Defines how much control the human has over agent actions.
FULL_CONTROL: Human must approve every agent action
MILESTONE: Human approves at sprint boundaries and major decisions
AUTONOMOUS: Agents work independently, only escalating critical issues
"""
FULL_CONTROL = "full_control"
MILESTONE = "milestone"
AUTONOMOUS = "autonomous"
class ProjectComplexity(str, PyEnum):
"""
Project complexity level for estimation and planning.
SCRIPT: Simple automation or script-level work
SIMPLE: Straightforward feature or fix
MEDIUM: Standard complexity with some architectural considerations
COMPLEX: Large-scale feature requiring significant design work
"""
SCRIPT = "script"
SIMPLE = "simple"
MEDIUM = "medium"
COMPLEX = "complex"
class ClientMode(str, PyEnum):
"""
How the client prefers to interact with agents.
TECHNICAL: Client is technical and prefers detailed updates
AUTO: Agents automatically determine communication level
"""
TECHNICAL = "technical"
AUTO = "auto"
class ProjectStatus(str, PyEnum):
"""
Project lifecycle status.
ACTIVE: Project is actively being worked on
PAUSED: Project is temporarily on hold
COMPLETED: Project has been delivered successfully
ARCHIVED: Project is no longer accessible for work
"""
ACTIVE = "active"
PAUSED = "paused"
COMPLETED = "completed"
ARCHIVED = "archived"
class AgentStatus(str, PyEnum):
"""
Current operational status of an agent instance.
IDLE: Agent is available but not currently working
WORKING: Agent is actively processing a task
WAITING: Agent is waiting for external input or approval
PAUSED: Agent has been manually paused
TERMINATED: Agent instance has been shut down
"""
IDLE = "idle"
WORKING = "working"
WAITING = "waiting"
PAUSED = "paused"
TERMINATED = "terminated"
class IssueType(str, PyEnum):
"""
Issue type for categorization and hierarchy.
EPIC: Large feature or body of work containing stories
STORY: User-facing feature or requirement
TASK: Technical work item
BUG: Defect or issue to be fixed
"""
EPIC = "epic"
STORY = "story"
TASK = "task"
BUG = "bug"
class IssueStatus(str, PyEnum):
"""
Issue workflow status.
OPEN: Issue is ready to be worked on
IN_PROGRESS: Agent or human is actively working on the issue
IN_REVIEW: Work is complete, awaiting review
BLOCKED: Issue cannot proceed due to dependencies or blockers
CLOSED: Issue has been completed or cancelled
"""
OPEN = "open"
IN_PROGRESS = "in_progress"
IN_REVIEW = "in_review"
BLOCKED = "blocked"
CLOSED = "closed"
class IssuePriority(str, PyEnum):
"""
Issue priority levels.
LOW: Nice to have, can be deferred
MEDIUM: Standard priority, should be done
HIGH: Important, should be prioritized
CRITICAL: Must be done immediately, blocking other work
"""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class SyncStatus(str, PyEnum):
"""
External issue tracker synchronization status.
SYNCED: Local and remote are in sync
PENDING: Local changes waiting to be pushed
CONFLICT: Merge conflict between local and remote
ERROR: Synchronization failed due to an error
"""
SYNCED = "synced"
PENDING = "pending"
CONFLICT = "conflict"
ERROR = "error"
class SprintStatus(str, PyEnum):
"""
Sprint lifecycle status.
PLANNED: Sprint has been created but not started
ACTIVE: Sprint is currently in progress
IN_REVIEW: Sprint work is done, demo/review pending
COMPLETED: Sprint has been finished successfully
CANCELLED: Sprint was cancelled before completion
"""
PLANNED = "planned"
ACTIVE = "active"
IN_REVIEW = "in_review"
COMPLETED = "completed"
CANCELLED = "cancelled"

View File

@@ -1,176 +0,0 @@
# app/models/syndarix/issue.py
"""
Issue model for Syndarix AI consulting platform.
An Issue represents a unit of work that can be assigned to agents or humans,
with optional synchronization to external issue trackers (Gitea, GitHub, GitLab).
"""
from sqlalchemy import (
Column,
Date,
DateTime,
Enum,
ForeignKey,
Index,
Integer,
String,
Text,
)
from sqlalchemy.dialects.postgresql import (
JSONB,
UUID as PGUUID,
)
from sqlalchemy.orm import relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
from .enums import IssuePriority, IssueStatus, IssueType, SyncStatus
class Issue(Base, UUIDMixin, TimestampMixin):
"""
Issue model representing a unit of work in a project.
Features:
- Standard issue fields (title, body, status, priority)
- Assignment to agent instances or human assignees
- Sprint association for backlog management
- External tracker synchronization (Gitea, GitHub, GitLab)
"""
__tablename__ = "issues"
# Foreign key to project
project_id = Column(
PGUUID(as_uuid=True),
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# Parent issue for hierarchy (Epic -> Story -> Task)
parent_id = Column(
PGUUID(as_uuid=True),
ForeignKey("issues.id", ondelete="CASCADE"),
nullable=True,
index=True,
)
# Issue type (Epic, Story, Task, Bug)
type: Column[IssueType] = Column(
Enum(IssueType),
default=IssueType.TASK,
nullable=False,
index=True,
)
# Reporter (who created this issue - can be user or agent)
reporter_id = Column(
PGUUID(as_uuid=True),
nullable=True, # System-generated issues may have no reporter
index=True,
)
# Issue content
title = Column(String(500), nullable=False)
body = Column(Text, nullable=False, default="")
# Status and priority
status: Column[IssueStatus] = Column(
Enum(IssueStatus),
default=IssueStatus.OPEN,
nullable=False,
index=True,
)
priority: Column[IssuePriority] = Column(
Enum(IssuePriority),
default=IssuePriority.MEDIUM,
nullable=False,
index=True,
)
# Labels for categorization (e.g., ["bug", "frontend", "urgent"])
labels = Column(JSONB, default=list, nullable=False)
# Assignment - either to an agent or a human (mutually exclusive)
assigned_agent_id = Column(
PGUUID(as_uuid=True),
ForeignKey("agent_instances.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
# Human assignee (username or email, not a FK to allow external users)
human_assignee = Column(String(255), nullable=True, index=True)
# Sprint association
sprint_id = Column(
PGUUID(as_uuid=True),
ForeignKey("sprints.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
# Story points for estimation
story_points = Column(Integer, nullable=True)
# Due date for the issue
due_date = Column(Date, nullable=True, index=True)
# External tracker integration
external_tracker_type = Column(
String(50),
nullable=True,
index=True,
) # 'gitea', 'github', 'gitlab'
external_issue_id = Column(String(255), nullable=True) # External system's ID
remote_url = Column(String(1000), nullable=True) # Link to external issue
external_issue_number = Column(Integer, nullable=True) # Issue number (e.g., #123)
# Sync status with external tracker
sync_status: Column[SyncStatus] = Column(
Enum(SyncStatus),
default=SyncStatus.SYNCED,
nullable=False,
# Note: Index defined in __table_args__ as ix_issues_sync_status
)
last_synced_at = Column(DateTime(timezone=True), nullable=True)
external_updated_at = Column(DateTime(timezone=True), nullable=True)
# Lifecycle timestamp
closed_at = Column(DateTime(timezone=True), nullable=True, index=True)
# Relationships
project = relationship("Project", back_populates="issues")
assigned_agent = relationship(
"AgentInstance",
back_populates="assigned_issues",
foreign_keys=[assigned_agent_id],
)
sprint = relationship("Sprint", back_populates="issues")
parent = relationship("Issue", remote_side="Issue.id", backref="children")
__table_args__ = (
Index("ix_issues_project_status", "project_id", "status"),
Index("ix_issues_project_priority", "project_id", "priority"),
Index("ix_issues_project_sprint", "project_id", "sprint_id"),
Index(
"ix_issues_external_tracker_id",
"external_tracker_type",
"external_issue_id",
),
Index("ix_issues_sync_status", "sync_status"),
Index("ix_issues_project_agent", "project_id", "assigned_agent_id"),
Index("ix_issues_project_type", "project_id", "type"),
Index("ix_issues_project_status_priority", "project_id", "status", "priority"),
)
def __repr__(self) -> str:
return (
f"<Issue {self.id} title='{self.title[:30]}...' "
f"status={self.status.value} priority={self.priority.value}>"
)

View File

@@ -1,103 +0,0 @@
# app/models/syndarix/project.py
"""
Project model for Syndarix AI consulting platform.
A Project represents a client engagement where AI agents collaborate
to deliver software solutions.
"""
from sqlalchemy import Column, Enum, ForeignKey, Index, String, Text
from sqlalchemy.dialects.postgresql import (
JSONB,
UUID as PGUUID,
)
from sqlalchemy.orm import relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
from .enums import AutonomyLevel, ClientMode, ProjectComplexity, ProjectStatus
class Project(Base, UUIDMixin, TimestampMixin):
"""
Project model representing a client engagement.
A project contains:
- Configuration for how autonomous agents should operate
- Settings for MCP server integrations
- Relationship to assigned agents, issues, and sprints
"""
__tablename__ = "projects"
name = Column(String(255), nullable=False, index=True)
slug = Column(String(255), unique=True, nullable=False, index=True)
description = Column(Text, nullable=True)
autonomy_level: Column[AutonomyLevel] = Column(
Enum(AutonomyLevel),
default=AutonomyLevel.MILESTONE,
nullable=False,
index=True,
)
status: Column[ProjectStatus] = Column(
Enum(ProjectStatus),
default=ProjectStatus.ACTIVE,
nullable=False,
index=True,
)
complexity: Column[ProjectComplexity] = Column(
Enum(ProjectComplexity),
default=ProjectComplexity.MEDIUM,
nullable=False,
index=True,
)
client_mode: Column[ClientMode] = Column(
Enum(ClientMode),
default=ClientMode.AUTO,
nullable=False,
index=True,
)
# JSON field for flexible project configuration
# Can include: mcp_servers, webhook_urls, notification_settings, etc.
settings = Column(JSONB, default=dict, nullable=False)
# Foreign key to the User who owns this project
owner_id = Column(
PGUUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
# Relationships
owner = relationship("User", foreign_keys=[owner_id])
agent_instances = relationship(
"AgentInstance",
back_populates="project",
cascade="all, delete-orphan",
)
issues = relationship(
"Issue",
back_populates="project",
cascade="all, delete-orphan",
)
sprints = relationship(
"Sprint",
back_populates="project",
cascade="all, delete-orphan",
)
__table_args__ = (
Index("ix_projects_slug_status", "slug", "status"),
Index("ix_projects_owner_status", "owner_id", "status"),
Index("ix_projects_autonomy_status", "autonomy_level", "status"),
Index("ix_projects_complexity_status", "complexity", "status"),
)
def __repr__(self) -> str:
return f"<Project {self.name} ({self.slug}) status={self.status.value}>"

View File

@@ -1,86 +0,0 @@
# app/models/syndarix/sprint.py
"""
Sprint model for Syndarix AI consulting platform.
A Sprint represents a time-boxed iteration for organizing and delivering work.
"""
from sqlalchemy import (
Column,
Date,
Enum,
ForeignKey,
Index,
Integer,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import UUID as PGUUID
from sqlalchemy.orm import relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
from .enums import SprintStatus
class Sprint(Base, UUIDMixin, TimestampMixin):
"""
Sprint model representing a time-boxed iteration.
Tracks:
- Sprint metadata (name, number, goal)
- Date range (start/end)
- Progress metrics (planned vs completed points)
"""
__tablename__ = "sprints"
# Foreign key to project
project_id = Column(
PGUUID(as_uuid=True),
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# Sprint identification
name = Column(String(255), nullable=False)
number = Column(Integer, nullable=False) # Sprint number within project
# Sprint goal (what we aim to achieve)
goal = Column(Text, nullable=True)
# Date range
start_date = Column(Date, nullable=False, index=True)
end_date = Column(Date, nullable=False, index=True)
# Status
status: Column[SprintStatus] = Column(
Enum(SprintStatus),
default=SprintStatus.PLANNED,
nullable=False,
index=True,
)
# Progress metrics
planned_points = Column(Integer, nullable=True) # Sum of story points at start
velocity = Column(Integer, nullable=True) # Sum of completed story points
# Relationships
project = relationship("Project", back_populates="sprints")
issues = relationship("Issue", back_populates="sprint")
__table_args__ = (
Index("ix_sprints_project_status", "project_id", "status"),
Index("ix_sprints_project_number", "project_id", "number"),
Index("ix_sprints_date_range", "start_date", "end_date"),
# Ensure sprint numbers are unique within a project
UniqueConstraint("project_id", "number", name="uq_sprint_project_number"),
)
def __repr__(self) -> str:
return (
f"<Sprint {self.name} (#{self.number}) "
f"project={self.project_id} status={self.status.value}>"
)

View File

@@ -76,7 +76,11 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
"""Check if session has expired.""" """Check if session has expired."""
from datetime import datetime from datetime import datetime
return self.expires_at < datetime.now(UTC) now = datetime.now(UTC)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
return bool(expires_at < now)
def to_dict(self): def to_dict(self):
"""Convert session to dictionary for serialization.""" """Convert session to dictionary for serialization."""

View File

@@ -0,0 +1,39 @@
# app/repositories/__init__.py
"""Repository layer — all database access goes through these classes."""
from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo
from app.repositories.oauth_authorization_code import (
OAuthAuthorizationCodeRepository,
oauth_authorization_code_repo,
)
from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo
from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo
from app.repositories.oauth_provider_token import (
OAuthProviderTokenRepository,
oauth_provider_token_repo,
)
from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo
from app.repositories.organization import OrganizationRepository, organization_repo
from app.repositories.session import SessionRepository, session_repo
from app.repositories.user import UserRepository, user_repo
__all__ = [
"OAuthAccountRepository",
"OAuthAuthorizationCodeRepository",
"OAuthClientRepository",
"OAuthConsentRepository",
"OAuthProviderTokenRepository",
"OAuthStateRepository",
"OrganizationRepository",
"SessionRepository",
"UserRepository",
"oauth_account_repo",
"oauth_authorization_code_repo",
"oauth_client_repo",
"oauth_consent_repo",
"oauth_provider_token_repo",
"oauth_state_repo",
"organization_repo",
"session_repo",
"user_repo",
]

View File

@@ -1,6 +1,6 @@
# app/crud/base_async.py # app/repositories/base.py
""" """
Async CRUD operations base class using SQLAlchemy 2.0 async patterns. Base repository class for async database operations using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models. Provides reusable create, read, update, and delete operations for all models.
""" """
@@ -18,6 +18,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load from sqlalchemy.orm import Load
from app.core.database import Base from app.core.database import Base
from app.core.repository_exceptions import (
DuplicateEntryError,
IntegrityConstraintError,
InvalidInputError,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,16 +31,16 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase[ class BaseRepository[
ModelType: Base, ModelType: Base,
CreateSchemaType: BaseModel, CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel, UpdateSchemaType: BaseModel,
]: ]:
"""Async CRUD operations for a model.""" """Async repository operations for a model."""
def __init__(self, model: type[ModelType]): def __init__(self, model: type[ModelType]):
""" """
CRUD object with default async methods to Create, Read, Update, Delete. Repository object with default async methods to Create, Read, Update, Delete.
Parameters: Parameters:
model: A SQLAlchemy model class model: A SQLAlchemy model class
@@ -56,26 +61,19 @@ class CRUDBase[
Returns: Returns:
Model instance or None if not found Model instance or None if not found
Example:
# Eager load user relationship
from sqlalchemy.orm import joinedload
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
""" """
# Validate UUID format and convert to UUID object if string
try: try:
if isinstance(id, uuid.UUID): if isinstance(id, uuid.UUID):
uuid_obj = id uuid_obj = id
else: else:
uuid_obj = uuid.UUID(str(id)) uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {e!s}") logger.warning("Invalid UUID format: %s - %s", id, e)
return None return None
try: try:
query = select(self.model).where(self.model.id == uuid_obj) query = select(self.model).where(self.model.id == uuid_obj)
# Apply eager loading options if provided
if options: if options:
for option in options: for option in options:
query = query.options(option) query = query.options(option)
@@ -83,7 +81,9 @@ class CRUDBase[
result = await db.execute(query) result = await db.execute(query)
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}") logger.error(
"Error retrieving %s with id %s: %s", self.model.__name__, id, e
)
raise raise
async def get_multi( async def get_multi(
@@ -96,28 +96,17 @@ class CRUDBase[
) -> list[ModelType]: ) -> list[ModelType]:
""" """
Get multiple records with pagination validation and optional eager loading. Get multiple records with pagination validation and optional eager loading.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
options: Optional list of SQLAlchemy load options for eager loading
Returns:
List of model instances
""" """
# Validate pagination parameters
if skip < 0: if skip < 0:
raise ValueError("skip must be non-negative") raise InvalidInputError("skip must be non-negative")
if limit < 0: if limit < 0:
raise ValueError("limit must be non-negative") raise InvalidInputError("limit must be non-negative")
if limit > 1000: if limit > 1000:
raise ValueError("Maximum limit is 1000") raise InvalidInputError("Maximum limit is 1000")
try: try:
query = select(self.model).offset(skip).limit(limit) query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)
# Apply eager loading options if provided
if options: if options:
for option in options: for option in options:
query = query.options(option) query = query.options(option)
@@ -126,7 +115,7 @@ class CRUDBase[
return list(result.scalars().all()) return list(result.scalars().all())
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error retrieving multiple {self.model.__name__} records: {e!s}" "Error retrieving multiple %s records: %s", self.model.__name__, e
) )
raise raise
@@ -136,9 +125,8 @@ class CRUDBase[
"""Create a new record with error handling. """Create a new record with error handling.
NOTE: This method is defensive code that's never called in practice. NOTE: This method is defensive code that's never called in practice.
All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method All repository subclasses override this method with their own implementations.
with their own implementations, so the base implementation and its exception handlers Marked as pragma: no cover to avoid false coverage gaps.
are never executed. Marked as pragma: no cover to avoid false coverage gaps.
""" """
try: # pragma: no cover try: # pragma: no cover
obj_in_data = jsonable_encoder(obj_in) obj_in_data = jsonable_encoder(obj_in)
@@ -152,22 +140,24 @@ class CRUDBase[
error_msg = str(e.orig) if hasattr(e, "orig") else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning( logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}" "Duplicate entry attempted for %s: %s",
self.model.__name__,
error_msg,
) )
raise ValueError( raise DuplicateEntryError(
f"A {self.model.__name__} with this data already exists" f"A {self.model.__name__} with this data already exists"
) )
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}") logger.error(
raise ValueError(f"Database integrity error: {error_msg}") "Integrity error creating %s: %s", self.model.__name__, error_msg
)
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: # pragma: no cover except (OperationalError, DataError) as e: # pragma: no cover
await db.rollback() await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {e!s}") logger.error("Database error creating %s: %s", self.model.__name__, e)
raise ValueError(f"Database operation failed: {e!s}") raise IntegrityConstraintError(f"Database operation failed: {e!s}")
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
await db.rollback() await db.rollback()
logger.error( logger.exception("Unexpected error creating %s: %s", self.model.__name__, e)
f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
)
raise raise
async def update( async def update(
@@ -198,34 +188,35 @@ class CRUDBase[
error_msg = str(e.orig) if hasattr(e, "orig") else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning( logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}" "Duplicate entry attempted for %s: %s",
self.model.__name__,
error_msg,
) )
raise ValueError( raise DuplicateEntryError(
f"A {self.model.__name__} with this data already exists" f"A {self.model.__name__} with this data already exists"
) )
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}") logger.error(
raise ValueError(f"Database integrity error: {error_msg}") "Integrity error updating %s: %s", self.model.__name__, error_msg
)
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: except (OperationalError, DataError) as e:
await db.rollback() await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {e!s}") logger.error("Database error updating %s: %s", self.model.__name__, e)
raise ValueError(f"Database operation failed: {e!s}") raise IntegrityConstraintError(f"Database operation failed: {e!s}")
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.exception("Unexpected error updating %s: %s", self.model.__name__, e)
f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
)
raise raise
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None: async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""Delete a record with error handling and null check.""" """Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string
try: try:
if isinstance(id, uuid.UUID): if isinstance(id, uuid.UUID):
uuid_obj = id uuid_obj = id
else: else:
uuid_obj = uuid.UUID(str(id)) uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}") logger.warning("Invalid UUID format for deletion: %s - %s", id, e)
return None return None
try: try:
@@ -236,7 +227,7 @@ class CRUDBase[
if obj is None: if obj is None:
logger.warning( logger.warning(
f"{self.model.__name__} with id {id} not found for deletion" "%s with id %s not found for deletion", self.model.__name__, id
) )
return None return None
@@ -246,15 +237,16 @@ class CRUDBase[
except IntegrityError as e: except IntegrityError as e:
await db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}") logger.error(
raise ValueError( "Integrity error deleting %s: %s", self.model.__name__, error_msg
)
raise IntegrityConstraintError(
f"Cannot delete {self.model.__name__}: referenced by other records" f"Cannot delete {self.model.__name__}: referenced by other records"
) )
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.exception(
f"Error deleting {self.model.__name__} with id {id}: {e!s}", "Error deleting %s with id %s: %s", self.model.__name__, id, e
exc_info=True,
) )
raise raise
@@ -272,57 +264,40 @@ class CRUDBase[
Get multiple records with total count, filtering, and sorting. Get multiple records with total count, filtering, and sorting.
NOTE: This method is defensive code that's never called in practice. NOTE: This method is defensive code that's never called in practice.
All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method All repository subclasses override this method with their own implementations.
with their own implementations that include additional parameters like search.
Marked as pragma: no cover to avoid false coverage gaps. Marked as pragma: no cover to avoid false coverage gaps.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by (must be a valid model attribute)
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
Returns:
Tuple of (items, total_count)
""" """
# Validate pagination parameters
if skip < 0: if skip < 0:
raise ValueError("skip must be non-negative") raise InvalidInputError("skip must be non-negative")
if limit < 0: if limit < 0:
raise ValueError("limit must be non-negative") raise InvalidInputError("limit must be non-negative")
if limit > 1000: if limit > 1000:
raise ValueError("Maximum limit is 1000") raise InvalidInputError("Maximum limit is 1000")
try: try:
# Build base query
query = select(self.model) query = select(self.model)
# Exclude soft-deleted records by default
if hasattr(self.model, "deleted_at"): if hasattr(self.model, "deleted_at"):
query = query.where(self.model.deleted_at.is_(None)) query = query.where(self.model.deleted_at.is_(None))
# Apply filters
if filters: if filters:
for field, value in filters.items(): for field, value in filters.items():
if hasattr(self.model, field) and value is not None: if hasattr(self.model, field) and value is not None:
query = query.where(getattr(self.model, field) == value) query = query.where(getattr(self.model, field) == value)
# Get total count (before pagination)
count_query = select(func.count()).select_from(query.alias()) count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(self.model, sort_by): if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by) sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc": if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc()) query = query.order_by(sort_column.desc())
else: else:
query = query.order_by(sort_column.asc()) query = query.order_by(sort_column.asc())
else:
query = query.order_by(self.model.id)
# Apply pagination
query = query.offset(skip).limit(limit) query = query.offset(skip).limit(limit)
items_result = await db.execute(query) items_result = await db.execute(query)
items = list(items_result.scalars().all()) items = list(items_result.scalars().all())
@@ -330,7 +305,7 @@ class CRUDBase[
return items, total return items, total
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
logger.error( logger.error(
f"Error retrieving paginated {self.model.__name__} records: {e!s}" "Error retrieving paginated %s records: %s", self.model.__name__, e
) )
raise raise
@@ -340,7 +315,7 @@ class CRUDBase[
result = await db.execute(select(func.count(self.model.id))) result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one() return result.scalar_one()
except Exception as e: except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {e!s}") logger.error("Error counting %s records: %s", self.model.__name__, e)
raise raise
async def exists(self, db: AsyncSession, id: str) -> bool: async def exists(self, db: AsyncSession, id: str) -> bool:
@@ -356,14 +331,13 @@ class CRUDBase[
""" """
from datetime import datetime from datetime import datetime
# Validate UUID format and convert to UUID object if string
try: try:
if isinstance(id, uuid.UUID): if isinstance(id, uuid.UUID):
uuid_obj = id uuid_obj = id
else: else:
uuid_obj = uuid.UUID(str(id)) uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}") logger.warning("Invalid UUID format for soft deletion: %s - %s", id, e)
return None return None
try: try:
@@ -374,18 +348,16 @@ class CRUDBase[
if obj is None: if obj is None:
logger.warning( logger.warning(
f"{self.model.__name__} with id {id} not found for soft deletion" "%s with id %s not found for soft deletion", self.model.__name__, id
) )
return None return None
# Check if model supports soft deletes
if not hasattr(self.model, "deleted_at"): if not hasattr(self.model, "deleted_at"):
logger.error(f"{self.model.__name__} does not support soft deletes") logger.error("%s does not support soft deletes", self.model.__name__)
raise ValueError( raise InvalidInputError(
f"{self.model.__name__} does not have a deleted_at column" f"{self.model.__name__} does not have a deleted_at column"
) )
# Set deleted_at timestamp
obj.deleted_at = datetime.now(UTC) obj.deleted_at = datetime.now(UTC)
db.add(obj) db.add(obj)
await db.commit() await db.commit()
@@ -393,9 +365,8 @@ class CRUDBase[
return obj return obj
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.exception(
f"Error soft deleting {self.model.__name__} with id {id}: {e!s}", "Error soft deleting %s with id %s: %s", self.model.__name__, id, e
exc_info=True,
) )
raise raise
@@ -405,18 +376,16 @@ class CRUDBase[
Only works if the model has a 'deleted_at' column. Only works if the model has a 'deleted_at' column.
""" """
# Validate UUID format
try: try:
if isinstance(id, uuid.UUID): if isinstance(id, uuid.UUID):
uuid_obj = id uuid_obj = id
else: else:
uuid_obj = uuid.UUID(str(id)) uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}") logger.warning("Invalid UUID format for restoration: %s - %s", id, e)
return None return None
try: try:
# Find the soft-deleted record
if hasattr(self.model, "deleted_at"): if hasattr(self.model, "deleted_at"):
result = await db.execute( result = await db.execute(
select(self.model).where( select(self.model).where(
@@ -425,18 +394,19 @@ class CRUDBase[
) )
obj = result.scalar_one_or_none() obj = result.scalar_one_or_none()
else: else:
logger.error(f"{self.model.__name__} does not support soft deletes") logger.error("%s does not support soft deletes", self.model.__name__)
raise ValueError( raise InvalidInputError(
f"{self.model.__name__} does not have a deleted_at column" f"{self.model.__name__} does not have a deleted_at column"
) )
if obj is None: if obj is None:
logger.warning( logger.warning(
f"Soft-deleted {self.model.__name__} with id {id} not found for restoration" "Soft-deleted %s with id %s not found for restoration",
self.model.__name__,
id,
) )
return None return None
# Clear deleted_at timestamp
obj.deleted_at = None obj.deleted_at = None
db.add(obj) db.add(obj)
await db.commit() await db.commit()
@@ -444,8 +414,7 @@ class CRUDBase[
return obj return obj
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.exception(
f"Error restoring {self.model.__name__} with id {id}: {e!s}", "Error restoring %s with id %s: %s", self.model.__name__, id, e
exc_info=True,
) )
raise raise

View File

@@ -0,0 +1,249 @@
# app/repositories/oauth_account.py
"""Repository for OAuthAccount model async database operations."""
import logging
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_account import OAuthAccount
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthAccountCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthAccountRepository(
BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]
):
"""Repository for OAuth account links."""
async def get_by_provider_id(
self,
db: AsyncSession,
*,
provider: str,
provider_user_id: str,
) -> OAuthAccount | None:
"""Get OAuth account by provider and provider user ID."""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == provider_user_id,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
"Error getting OAuth account for %s:%s: %s",
provider,
provider_user_id,
e,
)
raise
async def get_by_provider_email(
self,
db: AsyncSession,
*,
provider: str,
email: str,
) -> OAuthAccount | None:
"""Get OAuth account by provider and email."""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_email == email,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
"Error getting OAuth account for %s email %s: %s", provider, email, e
)
raise
async def get_user_accounts(
self,
db: AsyncSession,
*,
user_id: str | UUID,
) -> list[OAuthAccount]:
"""Get all OAuth accounts linked to a user."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount)
.where(OAuthAccount.user_id == user_uuid)
.order_by(OAuthAccount.created_at.desc())
)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error("Error getting OAuth accounts for user %s: %s", user_id, e)
raise
async def get_user_account_by_provider(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> OAuthAccount | None:
"""Get a specific OAuth account for a user and provider."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
"Error getting OAuth account for user %s, provider %s: %s",
user_id,
provider,
e,
)
raise
async def create_account(
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
) -> OAuthAccount:
"""Create a new OAuth account link."""
try:
db_obj = OAuthAccount(
user_id=obj_in.user_id,
provider=obj_in.provider,
provider_user_id=obj_in.provider_user_id,
provider_email=obj_in.provider_email,
access_token=obj_in.access_token,
refresh_token=obj_in.refresh_token,
token_expires_at=obj_in.token_expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
"OAuth account created: %s linked to user %s",
obj_in.provider,
obj_in.user_id,
)
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "uq_oauth_provider_user" in error_msg.lower():
logger.warning(
"OAuth account already exists: %s:%s",
obj_in.provider,
obj_in.provider_user_id,
)
raise DuplicateEntryError(
f"This {obj_in.provider} account is already linked to another user"
)
logger.error("Integrity error creating OAuth account: %s", error_msg)
raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.exception("Error creating OAuth account: %s", e)
raise
async def delete_account(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> bool:
"""Delete an OAuth account link."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
delete(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(
"OAuth account deleted: %s unlinked from user %s", provider, user_id
)
else:
logger.warning(
"OAuth account not found for deletion: %s for user %s",
provider,
user_id,
)
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(
"Error deleting OAuth account %s for user %s: %s", provider, user_id, e
)
raise
async def update_tokens(
self,
db: AsyncSession,
*,
account: OAuthAccount,
access_token: str | None = None,
refresh_token: str | None = None,
token_expires_at: datetime | None = None,
) -> OAuthAccount:
"""Update OAuth tokens for an account."""
try:
if access_token is not None:
account.access_token = access_token
if refresh_token is not None:
account.refresh_token = refresh_token
if token_expires_at is not None:
account.token_expires_at = token_expires_at
db.add(account)
await db.commit()
await db.refresh(account)
return account
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error updating OAuth tokens: %s", e)
raise
# Singleton instance
oauth_account_repo = OAuthAccountRepository(OAuthAccount)

View File

@@ -0,0 +1,108 @@
# app/repositories/oauth_authorization_code.py
"""Repository for OAuthAuthorizationCode model."""
import logging
from datetime import UTC, datetime
from uuid import UUID
from sqlalchemy import and_, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_authorization_code import OAuthAuthorizationCode
logger = logging.getLogger(__name__)
class OAuthAuthorizationCodeRepository:
"""Repository for OAuth 2.0 authorization codes."""
async def create_code(
self,
db: AsyncSession,
*,
code: str,
client_id: str,
user_id: UUID,
redirect_uri: str,
scope: str,
expires_at: datetime,
code_challenge: str | None = None,
code_challenge_method: str | None = None,
state: str | None = None,
nonce: str | None = None,
) -> OAuthAuthorizationCode:
"""Create and persist a new authorization code."""
auth_code = OAuthAuthorizationCode(
code=code,
client_id=client_id,
user_id=user_id,
redirect_uri=redirect_uri,
scope=scope,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
expires_at=expires_at,
used=False,
)
db.add(auth_code)
await db.commit()
return auth_code
async def consume_code_atomically(
self, db: AsyncSession, *, code: str
) -> UUID | None:
"""
Atomically mark a code as used and return its UUID.
Returns the UUID if the code was found and not yet used, None otherwise.
This prevents race conditions per RFC 6749 Section 4.1.2.
"""
stmt = (
update(OAuthAuthorizationCode)
.where(
and_(
OAuthAuthorizationCode.code == code,
OAuthAuthorizationCode.used == False, # noqa: E712
)
)
.values(used=True)
.returning(OAuthAuthorizationCode.id)
)
result = await db.execute(stmt)
row_id = result.scalar_one_or_none()
if row_id is not None:
await db.commit()
return row_id
async def get_by_id(
self, db: AsyncSession, *, code_id: UUID
) -> OAuthAuthorizationCode | None:
"""Get authorization code by its UUID primary key."""
result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id)
)
return result.scalar_one_or_none()
async def get_by_code(
self, db: AsyncSession, *, code: str
) -> OAuthAuthorizationCode | None:
"""Get authorization code by the code string value."""
result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
)
return result.scalar_one_or_none()
async def cleanup_expired(self, db: AsyncSession) -> int:
"""Delete all expired authorization codes. Returns count deleted."""
result = await db.execute(
delete(OAuthAuthorizationCode).where(
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
# Singleton instance
oauth_authorization_code_repo = OAuthAuthorizationCodeRepository()

View File

@@ -0,0 +1,201 @@
# app/repositories/oauth_client.py
"""Repository for OAuthClient model async database operations."""
import logging
import secrets
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_client import OAuthClient
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthClientCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthClientRepository(
BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]
):
"""Repository for OAuth clients (provider mode)."""
async def get_by_client_id(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""Get OAuth client by client_id."""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error("Error getting OAuth client %s: %s", client_id, e)
raise
async def create_client(
self,
db: AsyncSession,
*,
obj_in: OAuthClientCreate,
owner_user_id: UUID | None = None,
) -> tuple[OAuthClient, str | None]:
"""Create a new OAuth client."""
try:
client_id = secrets.token_urlsafe(32)
client_secret = None
client_secret_hash = None
if obj_in.client_type == "confidential":
client_secret = secrets.token_urlsafe(48)
from app.core.auth import get_password_hash
client_secret_hash = get_password_hash(client_secret)
db_obj = OAuthClient(
client_id=client_id,
client_secret_hash=client_secret_hash,
client_name=obj_in.client_name,
client_description=obj_in.client_description,
client_type=obj_in.client_type,
redirect_uris=obj_in.redirect_uris,
allowed_scopes=obj_in.allowed_scopes,
owner_user_id=owner_user_id,
is_active=True,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
"OAuth client created: %s (%s...)", obj_in.client_name, client_id[:8]
)
return db_obj, client_secret
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error("Error creating OAuth client: %s", error_msg)
raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.exception("Error creating OAuth client: %s", e)
raise
async def deactivate_client(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""Deactivate an OAuth client."""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return None
client.is_active = False
db.add(client)
await db.commit()
await db.refresh(client)
logger.info("OAuth client deactivated: %s", client.client_name)
return client
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error deactivating OAuth client %s: %s", client_id, e)
raise
async def validate_redirect_uri(
self, db: AsyncSession, *, client_id: str, redirect_uri: str
) -> bool:
"""Validate that a redirect URI is allowed for a client."""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return False
return redirect_uri in (client.redirect_uris or [])
except Exception as e: # pragma: no cover
logger.error("Error validating redirect URI: %s", e)
return False
async def verify_client_secret(
self, db: AsyncSession, *, client_id: str, client_secret: str
) -> bool:
"""Verify client credentials."""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
client = result.scalar_one_or_none()
if client is None or client.client_secret_hash is None:
return False
from app.core.auth import verify_password
stored_hash: str = str(client.client_secret_hash)
if stored_hash.startswith("$2"):
return verify_password(client_secret, stored_hash)
else:
import hashlib
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
return secrets.compare_digest(stored_hash, secret_hash)
except Exception as e: # pragma: no cover
logger.error("Error verifying client secret: %s", e)
return False
async def get_all_clients(
self, db: AsyncSession, *, include_inactive: bool = False
) -> list[OAuthClient]:
"""Get all OAuth clients."""
try:
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
if not include_inactive:
query = query.where(OAuthClient.is_active == True) # noqa: E712
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error("Error getting all OAuth clients: %s", e)
raise
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
"""Delete an OAuth client permanently."""
try:
result = await db.execute(
delete(OAuthClient).where(OAuthClient.client_id == client_id)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info("OAuth client deleted: %s", client_id)
else:
logger.warning("OAuth client not found for deletion: %s", client_id)
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error deleting OAuth client %s: %s", client_id, e)
raise
# Singleton instance
oauth_client_repo = OAuthClientRepository(OAuthClient)

View File

@@ -0,0 +1,113 @@
# app/repositories/oauth_consent.py
"""Repository for OAuthConsent model."""
import logging
from typing import Any
from uuid import UUID
from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent
logger = logging.getLogger(__name__)
class OAuthConsentRepository:
"""Repository for OAuth consent records (user grants to clients)."""
async def get_consent(
self, db: AsyncSession, *, user_id: UUID, client_id: str
) -> OAuthConsent | None:
"""Get the consent record for a user-client pair, or None if not found."""
result = await db.execute(
select(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)
return result.scalar_one_or_none()
async def grant_consent(
self,
db: AsyncSession,
*,
user_id: UUID,
client_id: str,
scopes: list[str],
) -> OAuthConsent:
"""
Create or update consent for a user-client pair.
If consent already exists, the new scopes are merged with existing ones.
Returns the created or updated consent record.
"""
consent = await self.get_consent(db, user_id=user_id, client_id=client_id)
if consent:
existing = (
set(consent.granted_scopes.split()) if consent.granted_scopes else set()
)
merged = existing | set(scopes)
consent.granted_scopes = " ".join(sorted(merged)) # type: ignore[assignment]
else:
consent = OAuthConsent(
user_id=user_id,
client_id=client_id,
granted_scopes=" ".join(sorted(set(scopes))),
)
db.add(consent)
await db.commit()
await db.refresh(consent)
return consent
async def get_user_consents_with_clients(
self, db: AsyncSession, *, user_id: UUID
) -> list[dict[str, Any]]:
"""Get all consent records for a user joined with client details."""
result = await db.execute(
select(OAuthConsent, OAuthClient)
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
.where(OAuthConsent.user_id == user_id)
)
rows = result.all()
return [
{
"client_id": consent.client_id,
"client_name": client.client_name,
"client_description": client.client_description,
"granted_scopes": consent.granted_scopes.split()
if consent.granted_scopes
else [],
"granted_at": consent.created_at.isoformat(),
}
for consent, client in rows
]
async def revoke_consent(
self, db: AsyncSession, *, user_id: UUID, client_id: str
) -> bool:
"""
Delete the consent record for a user-client pair.
Returns True if a record was found and deleted.
Note: Callers are responsible for also revoking associated tokens.
"""
result = await db.execute(
delete(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)
await db.commit()
return result.rowcount > 0 # type: ignore[attr-defined]
# Singleton instance
oauth_consent_repo = OAuthConsentRepository()

View File

@@ -0,0 +1,142 @@
# app/repositories/oauth_provider_token.py
"""Repository for OAuthProviderRefreshToken model."""
import logging
from datetime import UTC, datetime, timedelta
from uuid import UUID
from sqlalchemy import and_, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_provider_token import OAuthProviderRefreshToken
logger = logging.getLogger(__name__)
class OAuthProviderTokenRepository:
"""Repository for OAuth provider refresh tokens."""
async def create_token(
self,
db: AsyncSession,
*,
token_hash: str,
jti: str,
client_id: str,
user_id: UUID,
scope: str,
expires_at: datetime,
device_info: str | None = None,
ip_address: str | None = None,
) -> OAuthProviderRefreshToken:
"""Create and persist a new refresh token record."""
token = OAuthProviderRefreshToken(
token_hash=token_hash,
jti=jti,
client_id=client_id,
user_id=user_id,
scope=scope,
expires_at=expires_at,
device_info=device_info,
ip_address=ip_address,
)
db.add(token)
await db.commit()
return token
async def get_by_token_hash(
self, db: AsyncSession, *, token_hash: str
) -> OAuthProviderRefreshToken | None:
"""Get refresh token record by SHA-256 token hash."""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
)
return result.scalar_one_or_none()
async def get_by_jti(
self, db: AsyncSession, *, jti: str
) -> OAuthProviderRefreshToken | None:
"""Get refresh token record by JWT ID (JTI)."""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.jti == jti
)
)
return result.scalar_one_or_none()
async def revoke(
self, db: AsyncSession, *, token: OAuthProviderRefreshToken
) -> None:
"""Mark a specific token record as revoked."""
token.revoked = True # type: ignore[assignment]
token.last_used_at = datetime.now(UTC) # type: ignore[assignment]
await db.commit()
async def revoke_all_for_user_client(
self, db: AsyncSession, *, user_id: UUID, client_id: str
) -> int:
"""
Revoke all active tokens for a specific user-client pair.
Used when security incidents are detected (e.g., authorization code reuse).
Returns the number of tokens revoked.
"""
result = await db.execute(
update(OAuthProviderRefreshToken)
.where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.client_id == client_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
.values(revoked=True)
)
count = result.rowcount # type: ignore[attr-defined]
if count > 0:
await db.commit()
return count
async def revoke_all_for_user(self, db: AsyncSession, *, user_id: UUID) -> int:
"""
Revoke all active tokens for a user across all clients.
Used when user changes password or logs out everywhere.
Returns the number of tokens revoked.
"""
result = await db.execute(
update(OAuthProviderRefreshToken)
.where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
.values(revoked=True)
)
count = result.rowcount # type: ignore[attr-defined]
if count > 0:
await db.commit()
return count
async def cleanup_expired(self, db: AsyncSession, *, cutoff_days: int = 7) -> int:
"""
Delete expired refresh tokens older than cutoff_days.
Should be called periodically (e.g., daily).
Returns the number of tokens deleted.
"""
cutoff = datetime.now(UTC) - timedelta(days=cutoff_days)
result = await db.execute(
delete(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.expires_at < cutoff
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
# Singleton instance
oauth_provider_token_repo = OAuthProviderTokenRepository()

View File

@@ -0,0 +1,113 @@
# app/repositories/oauth_state.py
"""Repository for OAuthState model async database operations."""
import logging
from datetime import UTC, datetime
from pydantic import BaseModel
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_state import OAuthState
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthStateCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
"""Repository for OAuth state (CSRF protection)."""
async def create_state(
self, db: AsyncSession, *, obj_in: OAuthStateCreate
) -> OAuthState:
"""Create a new OAuth state for CSRF protection."""
try:
db_obj = OAuthState(
state=obj_in.state,
code_verifier=obj_in.code_verifier,
nonce=obj_in.nonce,
provider=obj_in.provider,
redirect_uri=obj_in.redirect_uri,
user_id=obj_in.user_id,
expires_at=obj_in.expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.debug("OAuth state created for %s", obj_in.provider)
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error("OAuth state collision: %s", error_msg)
raise DuplicateEntryError("Failed to create OAuth state, please retry")
except Exception as e: # pragma: no cover
await db.rollback()
logger.exception("Error creating OAuth state: %s", e)
raise
async def get_and_consume_state(
self, db: AsyncSession, *, state: str
) -> OAuthState | None:
"""Get and delete OAuth state (consume it)."""
try:
result = await db.execute(
select(OAuthState).where(OAuthState.state == state)
)
db_obj = result.scalar_one_or_none()
if db_obj is None:
logger.warning("OAuth state not found: %s...", state[:8])
return None
now = datetime.now(UTC)
expires_at = db_obj.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
if expires_at < now:
logger.warning("OAuth state expired: %s...", state[:8])
await db.delete(db_obj)
await db.commit()
return None
await db.delete(db_obj)
await db.commit()
logger.debug("OAuth state consumed: %s...", state[:8])
return db_obj
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error consuming OAuth state: %s", e)
raise
async def cleanup_expired(self, db: AsyncSession) -> int:
"""Clean up expired OAuth states."""
try:
now = datetime.now(UTC)
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info("Cleaned up %s expired OAuth states", count)
return count
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error cleaning up expired OAuth states: %s", e)
raise
# Singleton instance
oauth_state_repo = OAuthStateRepository(OAuthState)

View File

@@ -1,5 +1,5 @@
# app/crud/organization_async.py # app/repositories/organization.py
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns.""" """Repository for Organization model async database operations using SQLAlchemy 2.0 patterns."""
import logging import logging
from typing import Any from typing import Any
@@ -9,10 +9,11 @@ from sqlalchemy import and_, case, func, or_, select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
from app.models.organization import Organization from app.models.organization import Organization
from app.models.user import User from app.models.user import User
from app.models.user_organization import OrganizationRole, UserOrganization from app.models.user_organization import OrganizationRole, UserOrganization
from app.repositories.base import BaseRepository
from app.schemas.organizations import ( from app.schemas.organizations import (
OrganizationCreate, OrganizationCreate,
OrganizationUpdate, OrganizationUpdate,
@@ -21,8 +22,10 @@ from app.schemas.organizations import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]): class OrganizationRepository(
"""Async CRUD operations for Organization model.""" BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]
):
"""Repository for Organization model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None: async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
"""Get organization by slug.""" """Get organization by slug."""
@@ -32,7 +35,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting organization by slug {slug}: {e!s}") logger.error("Error getting organization by slug %s: %s", slug, e)
raise raise
async def create( async def create(
@@ -54,18 +57,20 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
except IntegrityError as e: except IntegrityError as e:
await db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "slug" in error_msg.lower(): if (
logger.warning(f"Duplicate slug attempted: {obj_in.slug}") "slug" in error_msg.lower()
raise ValueError( or "unique" in error_msg.lower()
or "duplicate" in error_msg.lower()
):
logger.warning("Duplicate slug attempted: %s", obj_in.slug)
raise DuplicateEntryError(
f"Organization with slug '{obj_in.slug}' already exists" f"Organization with slug '{obj_in.slug}' already exists"
) )
logger.error(f"Integrity error creating organization: {error_msg}") logger.error("Integrity error creating organization: %s", error_msg)
raise ValueError(f"Database integrity error: {error_msg}") raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.exception("Unexpected error creating organization: %s", e)
f"Unexpected error creating organization: {e!s}", exc_info=True
)
raise raise
async def get_multi_with_filters( async def get_multi_with_filters(
@@ -79,16 +84,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
sort_by: str = "created_at", sort_by: str = "created_at",
sort_order: str = "desc", sort_order: str = "desc",
) -> tuple[list[Organization], int]: ) -> tuple[list[Organization], int]:
""" """Get multiple organizations with filtering, searching, and sorting."""
Get multiple organizations with filtering, searching, and sorting.
Returns:
Tuple of (organizations list, total count)
"""
try: try:
query = select(Organization) query = select(Organization)
# Apply filters
if is_active is not None: if is_active is not None:
query = query.where(Organization.is_active == is_active) query = query.where(Organization.is_active == is_active)
@@ -100,26 +99,23 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) )
query = query.where(search_filter) query = query.where(search_filter)
# Get total count before pagination
count_query = select(func.count()).select_from(query.alias()) count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply sorting
sort_column = getattr(Organization, sort_by, Organization.created_at) sort_column = getattr(Organization, sort_by, Organization.created_at)
if sort_order == "desc": if sort_order == "desc":
query = query.order_by(sort_column.desc()) query = query.order_by(sort_column.desc())
else: else:
query = query.order_by(sort_column.asc()) query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit) query = query.offset(skip).limit(limit)
result = await db.execute(query) result = await db.execute(query)
organizations = list(result.scalars().all()) organizations = list(result.scalars().all())
return organizations, total return organizations, total
except Exception as e: except Exception as e:
logger.error(f"Error getting organizations with filters: {e!s}") logger.error("Error getting organizations with filters: %s", e)
raise raise
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int: async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
@@ -136,7 +132,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return result.scalar_one() or 0 return result.scalar_one() or 0
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error getting member count for organization {organization_id}: {e!s}" "Error getting member count for organization %s: %s", organization_id, e
) )
raise raise
@@ -149,16 +145,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
is_active: bool | None = None, is_active: bool | None = None,
search: str | None = None, search: str | None = None,
) -> tuple[list[dict[str, Any]], int]: ) -> tuple[list[dict[str, Any]], int]:
""" """Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY."""
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
This eliminates the N+1 query problem.
Returns:
Tuple of (list of dicts with org and member_count, total count)
"""
try: try:
# Build base query with LEFT JOIN and GROUP BY
# Use CASE statement to count only active members
query = ( query = (
select( select(
Organization, Organization,
@@ -181,10 +169,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
.group_by(Organization.id) .group_by(Organization.id)
) )
# Apply filters
if is_active is not None: if is_active is not None:
query = query.where(Organization.is_active == is_active) query = query.where(Organization.is_active == is_active)
search_filter = None
if search: if search:
search_filter = or_( search_filter = or_(
Organization.name.ilike(f"%{search}%"), Organization.name.ilike(f"%{search}%"),
@@ -193,17 +181,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) )
query = query.where(search_filter) query = query.where(search_filter)
# Get total count
count_query = select(func.count(Organization.id)) count_query = select(func.count(Organization.id))
if is_active is not None: if is_active is not None:
count_query = count_query.where(Organization.is_active == is_active) count_query = count_query.where(Organization.is_active == is_active)
if search: if search_filter is not None:
count_query = count_query.where(search_filter) count_query = count_query.where(search_filter)
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply pagination and ordering
query = ( query = (
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit) query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
) )
@@ -211,7 +197,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
result = await db.execute(query) result = await db.execute(query)
rows = result.all() rows = result.all()
# Convert to list of dicts
orgs_with_counts = [ orgs_with_counts = [
{"organization": org, "member_count": member_count} {"organization": org, "member_count": member_count}
for org, member_count in rows for org, member_count in rows
@@ -220,9 +205,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return orgs_with_counts, total return orgs_with_counts, total
except Exception as e: except Exception as e:
logger.error( logger.exception("Error getting organizations with member counts: %s", e)
f"Error getting organizations with member counts: {e!s}", exc_info=True
)
raise raise
async def add_user( async def add_user(
@@ -236,7 +219,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) -> UserOrganization: ) -> UserOrganization:
"""Add a user to an organization with a specific role.""" """Add a user to an organization with a specific role."""
try: try:
# Check if relationship already exists
result = await db.execute( result = await db.execute(
select(UserOrganization).where( select(UserOrganization).where(
and_( and_(
@@ -248,7 +230,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
existing = result.scalar_one_or_none() existing = result.scalar_one_or_none()
if existing: if existing:
# Reactivate if inactive, or raise error if already active
if not existing.is_active: if not existing.is_active:
existing.is_active = True existing.is_active = True
existing.role = role existing.role = role
@@ -257,9 +238,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
await db.refresh(existing) await db.refresh(existing)
return existing return existing
else: else:
raise ValueError("User is already a member of this organization") raise DuplicateEntryError(
"User is already a member of this organization"
)
# Create new relationship
user_org = UserOrganization( user_org = UserOrganization(
user_id=user_id, user_id=user_id,
organization_id=organization_id, organization_id=organization_id,
@@ -273,11 +255,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org return user_org
except IntegrityError as e: except IntegrityError as e:
await db.rollback() await db.rollback()
logger.error(f"Integrity error adding user to organization: {e!s}") logger.error("Integrity error adding user to organization: %s", e)
raise ValueError("Failed to add user to organization") raise IntegrityConstraintError("Failed to add user to organization")
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error adding user to organization: {e!s}", exc_info=True) logger.exception("Error adding user to organization: %s", e)
raise raise
async def remove_user( async def remove_user(
@@ -303,7 +285,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return True return True
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error removing user from organization: {e!s}", exc_info=True) logger.exception("Error removing user from organization: %s", e)
raise raise
async def update_user_role( async def update_user_role(
@@ -338,7 +320,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org return user_org
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error updating user role: {e!s}", exc_info=True) logger.exception("Error updating user role: %s", e)
raise raise
async def get_organization_members( async def get_organization_members(
@@ -348,16 +330,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID, organization_id: UUID,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
is_active: bool = True, is_active: bool | None = True,
) -> tuple[list[dict[str, Any]], int]: ) -> tuple[list[dict[str, Any]], int]:
""" """Get members of an organization with user details."""
Get members of an organization with user details.
Returns:
Tuple of (members list with user details, total count)
"""
try: try:
# Build query with join
query = ( query = (
select(UserOrganization, User) select(UserOrganization, User)
.join(User, UserOrganization.user_id == User.id) .join(User, UserOrganization.user_id == User.id)
@@ -367,7 +343,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
if is_active is not None: if is_active is not None:
query = query.where(UserOrganization.is_active == is_active) query = query.where(UserOrganization.is_active == is_active)
# Get total count
count_query = select(func.count()).select_from( count_query = select(func.count()).select_from(
select(UserOrganization) select(UserOrganization)
.where(UserOrganization.organization_id == organization_id) .where(UserOrganization.organization_id == organization_id)
@@ -381,7 +356,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply ordering and pagination
query = ( query = (
query.order_by(UserOrganization.created_at.desc()) query.order_by(UserOrganization.created_at.desc())
.offset(skip) .offset(skip)
@@ -406,11 +380,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return members, total return members, total
except Exception as e: except Exception as e:
logger.error(f"Error getting organization members: {e!s}") logger.error("Error getting organization members: %s", e)
raise raise
async def get_user_organizations( async def get_user_organizations(
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
) -> list[Organization]: ) -> list[Organization]:
"""Get all organizations a user belongs to.""" """Get all organizations a user belongs to."""
try: try:
@@ -429,21 +403,14 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
result = await db.execute(query) result = await db.execute(query)
return list(result.scalars().all()) return list(result.scalars().all())
except Exception as e: except Exception as e:
logger.error(f"Error getting user organizations: {e!s}") logger.error("Error getting user organizations: %s", e)
raise raise
async def get_user_organizations_with_details( async def get_user_organizations_with_details(
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """Get user's organizations with role and member count in SINGLE QUERY."""
Get user's organizations with role and member count in SINGLE QUERY.
Eliminates N+1 problem by using subquery for member counts.
Returns:
List of dicts with organization, role, and member_count
"""
try: try:
# Subquery to get member counts for each organization
member_count_subq = ( member_count_subq = (
select( select(
UserOrganization.organization_id, UserOrganization.organization_id,
@@ -454,7 +421,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
.subquery() .subquery()
) )
# Main query with JOIN to get org, role, and member count
query = ( query = (
select( select(
Organization, Organization,
@@ -486,9 +452,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
] ]
except Exception as e: except Exception as e:
logger.error( logger.exception("Error getting user organizations with details: %s", e)
f"Error getting user organizations with details: {e!s}", exc_info=True
)
raise raise
async def get_user_role_in_org( async def get_user_role_in_org(
@@ -507,9 +471,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) )
user_org = result.scalar_one_or_none() user_org = result.scalar_one_or_none()
return user_org.role if user_org else None return user_org.role if user_org else None # pyright: ignore[reportReturnType]
except Exception as e: except Exception as e:
logger.error(f"Error getting user role in org: {e!s}") logger.error("Error getting user role in org: %s", e)
raise raise
async def is_user_org_owner( async def is_user_org_owner(
@@ -531,5 +495,5 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN] return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
# Create a singleton instance for use across the application # Singleton instance
organization = CRUDOrganization(Organization) organization_repo = OrganizationRepository(Organization)

View File

@@ -1,6 +1,5 @@
""" # app/repositories/session.py
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns. """Repository for UserSession model async database operations using SQLAlchemy 2.0 patterns."""
"""
import logging import logging
import uuid import uuid
@@ -11,49 +10,32 @@ from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase from app.core.repository_exceptions import IntegrityConstraintError, InvalidInputError
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.repositories.base import BaseRepository
from app.schemas.sessions import SessionCreate, SessionUpdate from app.schemas.sessions import SessionCreate, SessionUpdate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
"""Async CRUD operations for user sessions.""" """Repository for UserSession model."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None: async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
""" """Get session by refresh token JTI."""
Get session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
UserSession if found, None otherwise
"""
try: try:
result = await db.execute( result = await db.execute(
select(UserSession).where(UserSession.refresh_token_jti == jti) select(UserSession).where(UserSession.refresh_token_jti == jti)
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {e!s}") logger.error("Error getting session by JTI %s: %s", jti, e)
raise raise
async def get_active_by_jti( async def get_active_by_jti(
self, db: AsyncSession, *, jti: str self, db: AsyncSession, *, jti: str
) -> UserSession | None: ) -> UserSession | None:
""" """Get active session by refresh token JTI."""
Get active session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
Active UserSession if found, None otherwise
"""
try: try:
result = await db.execute( result = await db.execute(
select(UserSession).where( select(UserSession).where(
@@ -65,7 +47,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {e!s}") logger.error("Error getting active session by JTI %s: %s", jti, e)
raise raise
async def get_user_sessions( async def get_user_sessions(
@@ -76,25 +58,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
active_only: bool = True, active_only: bool = True,
with_user: bool = False, with_user: bool = False,
) -> list[UserSession]: ) -> list[UserSession]:
""" """Get all sessions for a user with optional eager loading."""
Get all sessions for a user with optional eager loading.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns:
List of UserSession objects
"""
try: try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = select(UserSession).where(UserSession.user_id == user_uuid) query = select(UserSession).where(UserSession.user_id == user_uuid)
# Add eager loading if requested to prevent N+1 queries
if with_user: if with_user:
query = query.options(joinedload(UserSession.user)) query = query.options(joinedload(UserSession.user))
@@ -105,25 +74,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
result = await db.execute(query) result = await db.execute(query)
return list(result.scalars().all()) return list(result.scalars().all())
except Exception as e: except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {e!s}") logger.error("Error getting sessions for user %s: %s", user_id, e)
raise raise
async def create_session( async def create_session(
self, db: AsyncSession, *, obj_in: SessionCreate self, db: AsyncSession, *, obj_in: SessionCreate
) -> UserSession: ) -> UserSession:
""" """Create a new user session."""
Create a new user session.
Args:
db: Database session
obj_in: SessionCreate schema with session data
Returns:
Created UserSession
Raises:
ValueError: If session creation fails
"""
try: try:
db_obj = UserSession( db_obj = UserSession(
user_id=obj_in.user_id, user_id=obj_in.user_id,
@@ -143,33 +100,26 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
await db.refresh(db_obj) await db.refresh(db_obj)
logger.info( logger.info(
f"Session created for user {obj_in.user_id} from {obj_in.device_name} " "Session created for user %s from %s (IP: %s)",
f"(IP: {obj_in.ip_address})" obj_in.user_id,
obj_in.device_name,
obj_in.ip_address,
) )
return db_obj return db_obj
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error creating session: {e!s}", exc_info=True) logger.exception("Error creating session: %s", e)
raise ValueError(f"Failed to create session: {e!s}") raise IntegrityConstraintError(f"Failed to create session: {e!s}")
async def deactivate( async def deactivate(
self, db: AsyncSession, *, session_id: str self, db: AsyncSession, *, session_id: str
) -> UserSession | None: ) -> UserSession | None:
""" """Deactivate a session (logout from device)."""
Deactivate a session (logout from device).
Args:
db: Database session
session_id: Session UUID
Returns:
Deactivated UserSession if found, None otherwise
"""
try: try:
session = await self.get(db, id=session_id) session = await self.get(db, id=session_id)
if not session: if not session:
logger.warning(f"Session {session_id} not found for deactivation") logger.warning("Session %s not found for deactivation", session_id)
return None return None
session.is_active = False session.is_active = False
@@ -178,31 +128,23 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
await db.refresh(session) await db.refresh(session)
logger.info( logger.info(
f"Session {session_id} deactivated for user {session.user_id} " "Session %s deactivated for user %s (%s)",
f"({session.device_name})" session_id,
session.user_id,
session.device_name,
) )
return session return session
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error deactivating session {session_id}: {e!s}") logger.error("Error deactivating session %s: %s", session_id, e)
raise raise
async def deactivate_all_user_sessions( async def deactivate_all_user_sessions(
self, db: AsyncSession, *, user_id: str self, db: AsyncSession, *, user_id: str
) -> int: ) -> int:
""" """Deactivate all active sessions for a user (logout from all devices)."""
Deactivate all active sessions for a user (logout from all devices).
Args:
db: Database session
user_id: User ID
Returns:
Number of sessions deactivated
"""
try: try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
stmt = ( stmt = (
@@ -216,27 +158,18 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
count = result.rowcount count = result.rowcount
logger.info(f"Deactivated {count} sessions for user {user_id}") logger.info("Deactivated %s sessions for user %s", count, user_id)
return count return count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}") logger.error("Error deactivating all sessions for user %s: %s", user_id, e)
raise raise
async def update_last_used( async def update_last_used(
self, db: AsyncSession, *, session: UserSession self, db: AsyncSession, *, session: UserSession
) -> UserSession: ) -> UserSession:
""" """Update the last_used_at timestamp for a session."""
Update the last_used_at timestamp for a session.
Args:
db: Database session
session: UserSession object
Returns:
Updated UserSession
"""
try: try:
session.last_used_at = datetime.now(UTC) session.last_used_at = datetime.now(UTC)
db.add(session) db.add(session)
@@ -245,7 +178,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return session return session
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {e!s}") logger.error("Error updating last_used for session %s: %s", session.id, e)
raise raise
async def update_refresh_token( async def update_refresh_token(
@@ -256,20 +189,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
new_jti: str, new_jti: str,
new_expires_at: datetime, new_expires_at: datetime,
) -> UserSession: ) -> UserSession:
""" """Update session with new refresh token JTI and expiration."""
Update session with new refresh token JTI and expiration.
Called during token refresh.
Args:
db: Database session
session: UserSession object
new_jti: New refresh token JTI
new_expires_at: New expiration datetime
Returns:
Updated UserSession
"""
try: try:
session.refresh_token_jti = new_jti session.refresh_token_jti = new_jti
session.expires_at = new_expires_at session.expires_at = new_expires_at
@@ -281,32 +201,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.error(
f"Error updating refresh token for session {session.id}: {e!s}" "Error updating refresh token for session %s: %s", session.id, e
) )
raise raise
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int: async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
""" """Clean up expired sessions using optimized bulk DELETE."""
Clean up expired sessions using optimized bulk DELETE.
Deletes sessions that are:
- Expired AND inactive
- Older than keep_days
Uses single DELETE query instead of N individual deletes for efficiency.
Args:
db: Database session
keep_days: Keep inactive sessions for this many days (for audit)
Returns:
Number of sessions deleted
"""
try: try:
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days) cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
now = datetime.now(UTC) now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where( stmt = delete(UserSession).where(
and_( and_(
UserSession.is_active == False, # noqa: E712 UserSession.is_active == False, # noqa: E712
@@ -321,38 +225,25 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
count = result.rowcount count = result.rowcount
if count > 0: if count > 0:
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE") logger.info("Cleaned up %s expired sessions using bulk DELETE", count)
return count return count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error cleaning up expired sessions: {e!s}") logger.error("Error cleaning up expired sessions: %s", e)
raise raise
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int: async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
""" """Clean up expired and inactive sessions for a specific user."""
Clean up expired and inactive sessions for a specific user.
Uses single bulk DELETE query for efficiency instead of N individual deletes.
Args:
db: Database session
user_id: User ID to cleanup sessions for
Returns:
Number of sessions deleted
"""
try: try:
# Validate UUID
try: try:
uuid_obj = uuid.UUID(user_id) uuid_obj = uuid.UUID(user_id)
except (ValueError, AttributeError): except (ValueError, AttributeError):
logger.error(f"Invalid UUID format: {user_id}") logger.error("Invalid UUID format: %s", user_id)
raise ValueError(f"Invalid user ID format: {user_id}") raise InvalidInputError(f"Invalid user ID format: {user_id}")
now = datetime.now(UTC) now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where( stmt = delete(UserSession).where(
and_( and_(
UserSession.user_id == uuid_obj, UserSession.user_id == uuid_obj,
@@ -368,30 +259,22 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
if count > 0: if count > 0:
logger.info( logger.info(
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE" "Cleaned up %s expired sessions for user %s using bulk DELETE",
count,
user_id,
) )
return count return count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.error(
f"Error cleaning up expired sessions for user {user_id}: {e!s}" "Error cleaning up expired sessions for user %s: %s", user_id, e
) )
raise raise
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int: async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
""" """Get count of active sessions for a user."""
Get count of active sessions for a user.
Args:
db: Database session
user_id: User ID
Returns:
Number of active sessions
"""
try: try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
result = await db.execute( result = await db.execute(
@@ -401,7 +284,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
) )
return result.scalar_one() return result.scalar_one()
except Exception as e: except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {e!s}") logger.error("Error counting sessions for user %s: %s", user_id, e)
raise raise
async def get_all_sessions( async def get_all_sessions(
@@ -413,31 +296,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
active_only: bool = True, active_only: bool = True,
with_user: bool = True, with_user: bool = True,
) -> tuple[list[UserSession], int]: ) -> tuple[list[UserSession], int]:
""" """Get all sessions across all users with pagination (admin only)."""
Get all sessions across all users with pagination (admin only).
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns:
Tuple of (list of UserSession objects, total count)
"""
try: try:
# Build query
query = select(UserSession) query = select(UserSession)
# Add eager loading if requested to prevent N+1 queries
if with_user: if with_user:
query = query.options(joinedload(UserSession.user)) query = query.options(joinedload(UserSession.user))
if active_only: if active_only:
query = query.where(UserSession.is_active) query = query.where(UserSession.is_active)
# Get total count
count_query = select(func.count(UserSession.id)) count_query = select(func.count(UserSession.id))
if active_only: if active_only:
count_query = count_query.where(UserSession.is_active) count_query = count_query.where(UserSession.is_active)
@@ -445,7 +313,6 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply pagination and ordering
query = ( query = (
query.order_by(UserSession.last_used_at.desc()) query.order_by(UserSession.last_used_at.desc())
.offset(skip) .offset(skip)
@@ -458,9 +325,9 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return sessions, total return sessions, total
except Exception as e: except Exception as e:
logger.error(f"Error getting all sessions: {e!s}", exc_info=True) logger.exception("Error getting all sessions: %s", e)
raise raise
# Create singleton instance # Singleton instance
session = CRUDSession(UserSession) session_repo = SessionRepository(UserSession)

View File

@@ -1,5 +1,5 @@
# app/crud/user_async.py # app/repositories/user.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns.""" """Repository for User model async database operations using SQLAlchemy 2.0 patterns."""
import logging import logging
from datetime import UTC, datetime from datetime import UTC, datetime
@@ -11,15 +11,16 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_password_hash_async from app.core.auth import get_password_hash_async
from app.crud.base import CRUDBase from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
from app.models.user import User from app.models.user import User
from app.repositories.base import BaseRepository
from app.schemas.users import UserCreate, UserUpdate from app.schemas.users import UserCreate, UserUpdate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
"""Async CRUD operations for User model.""" """Repository for User model."""
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None: async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
"""Get user by email address.""" """Get user by email address."""
@@ -27,13 +28,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
result = await db.execute(select(User).where(User.email == email)) result = await db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting user by email {email}: {e!s}") logger.error("Error getting user by email %s: %s", email, e)
raise raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User: async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
"""Create a new user with async password hashing and error handling.""" """Create a new user with async password hashing and error handling."""
try: try:
# Hash password asynchronously to avoid blocking event loop
password_hash = await get_password_hash_async(obj_in.password) password_hash = await get_password_hash_async(obj_in.password)
db_obj = User( db_obj = User(
@@ -57,13 +57,49 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
await db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "email" in error_msg.lower(): if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}") logger.warning("Duplicate email attempted: %s", obj_in.email)
raise ValueError(f"User with email {obj_in.email} already exists") raise DuplicateEntryError(
logger.error(f"Integrity error creating user: {error_msg}") f"User with email {obj_in.email} already exists"
raise ValueError(f"Database integrity error: {error_msg}") )
logger.error("Integrity error creating user: %s", error_msg)
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True) logger.exception("Unexpected error creating user: %s", e)
raise
async def create_oauth_user(
self,
db: AsyncSession,
*,
email: str,
first_name: str = "User",
last_name: str | None = None,
) -> User:
"""Create a new passwordless user for OAuth sign-in."""
try:
db_obj = User(
email=email,
password_hash=None, # OAuth-only user
first_name=first_name,
last_name=last_name,
is_active=True,
is_superuser=False,
)
db.add(db_obj)
await db.flush() # Get user.id without committing
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "email" in error_msg.lower():
logger.warning("Duplicate email attempted: %s", email)
raise DuplicateEntryError(f"User with email {email} already exists")
logger.error("Integrity error creating OAuth user: %s", error_msg)
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.exception("Unexpected error creating OAuth user: %s", e)
raise raise
async def update( async def update(
@@ -75,8 +111,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
else: else:
update_data = obj_in.model_dump(exclude_unset=True) update_data = obj_in.model_dump(exclude_unset=True)
# Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop
if "password" in update_data: if "password" in update_data:
update_data["password_hash"] = await get_password_hash_async( update_data["password_hash"] = await get_password_hash_async(
update_data["password"] update_data["password"]
@@ -85,6 +119,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return await super().update(db, db_obj=db_obj, obj_in=update_data) return await super().update(db, db_obj=db_obj, obj_in=update_data)
async def update_password(
self, db: AsyncSession, *, user: User, password_hash: str
) -> User:
"""Set a new password hash on a user and commit."""
user.password_hash = password_hash
await db.commit()
await db.refresh(user)
return user
async def get_multi_with_total( async def get_multi_with_total(
self, self,
db: AsyncSession, db: AsyncSession,
@@ -96,43 +139,23 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
filters: dict[str, Any] | None = None, filters: dict[str, Any] | None = None,
search: str | None = None, search: str | None = None,
) -> tuple[list[User], int]: ) -> tuple[list[User], int]:
""" """Get multiple users with total count, filtering, sorting, and search."""
Get multiple users with total count, filtering, sorting, and search.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
search: Search term to match against email, first_name, last_name
Returns:
Tuple of (users list, total count)
"""
# Validate pagination
if skip < 0: if skip < 0:
raise ValueError("skip must be non-negative") raise InvalidInputError("skip must be non-negative")
if limit < 0: if limit < 0:
raise ValueError("limit must be non-negative") raise InvalidInputError("limit must be non-negative")
if limit > 1000: if limit > 1000:
raise ValueError("Maximum limit is 1000") raise InvalidInputError("Maximum limit is 1000")
try: try:
# Build base query
query = select(User) query = select(User)
# Exclude soft-deleted users
query = query.where(User.deleted_at.is_(None)) query = query.where(User.deleted_at.is_(None))
# Apply filters
if filters: if filters:
for field, value in filters.items(): for field, value in filters.items():
if hasattr(User, field) and value is not None: if hasattr(User, field) and value is not None:
query = query.where(getattr(User, field) == value) query = query.where(getattr(User, field) == value)
# Apply search
if search: if search:
search_filter = or_( search_filter = or_(
User.email.ilike(f"%{search}%"), User.email.ilike(f"%{search}%"),
@@ -141,14 +164,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
) )
query = query.where(search_filter) query = query.where(search_filter)
# Get total count
from sqlalchemy import func from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias()) count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(User, sort_by): if sort_by and hasattr(User, sort_by):
sort_column = getattr(User, sort_by) sort_column = getattr(User, sort_by)
if sort_order.lower() == "desc": if sort_order.lower() == "desc":
@@ -156,7 +177,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
else: else:
query = query.order_by(sort_column.asc()) query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit) query = query.offset(skip).limit(limit)
result = await db.execute(query) result = await db.execute(query)
users = list(result.scalars().all()) users = list(result.scalars().all())
@@ -164,32 +184,21 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return users, total return users, total
except Exception as e: except Exception as e:
logger.error(f"Error retrieving paginated users: {e!s}") logger.error("Error retrieving paginated users: %s", e)
raise raise
async def bulk_update_status( async def bulk_update_status(
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
) -> int: ) -> int:
""" """Bulk update is_active status for multiple users."""
Bulk update is_active status for multiple users.
Args:
db: Database session
user_ids: List of user IDs to update
is_active: New active status
Returns:
Number of users updated
"""
try: try:
if not user_ids: if not user_ids:
return 0 return 0
# Use UPDATE with WHERE IN for efficiency
stmt = ( stmt = (
update(User) update(User)
.where(User.id.in_(user_ids)) .where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users .where(User.deleted_at.is_(None))
.values(is_active=is_active, updated_at=datetime.now(UTC)) .values(is_active=is_active, updated_at=datetime.now(UTC))
) )
@@ -197,12 +206,14 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
await db.commit() await db.commit()
updated_count = result.rowcount updated_count = result.rowcount
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}") logger.info(
"Bulk updated %s users to is_active=%s", updated_count, is_active
)
return updated_count return updated_count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error bulk updating user status: {e!s}", exc_info=True) logger.exception("Error bulk updating user status: %s", e)
raise raise
async def bulk_soft_delete( async def bulk_soft_delete(
@@ -212,34 +223,20 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
user_ids: list[UUID], user_ids: list[UUID],
exclude_user_id: UUID | None = None, exclude_user_id: UUID | None = None,
) -> int: ) -> int:
""" """Bulk soft delete multiple users."""
Bulk soft delete multiple users.
Args:
db: Database session
user_ids: List of user IDs to delete
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
Returns:
Number of users deleted
"""
try: try:
if not user_ids: if not user_ids:
return 0 return 0
# Remove excluded user from list
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id] filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
if not filtered_ids: if not filtered_ids:
return 0 return 0
# Use UPDATE with WHERE IN for efficiency
stmt = ( stmt = (
update(User) update(User)
.where(User.id.in_(filtered_ids)) .where(User.id.in_(filtered_ids))
.where( .where(User.deleted_at.is_(None))
User.deleted_at.is_(None)
) # Don't re-delete already deleted users
.values( .values(
deleted_at=datetime.now(UTC), deleted_at=datetime.now(UTC),
is_active=False, is_active=False,
@@ -251,22 +248,22 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
await db.commit() await db.commit()
deleted_count = result.rowcount deleted_count = result.rowcount
logger.info(f"Bulk soft deleted {deleted_count} users") logger.info("Bulk soft deleted %s users", deleted_count)
return deleted_count return deleted_count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error bulk deleting users: {e!s}", exc_info=True) logger.exception("Error bulk deleting users: %s", e)
raise raise
def is_active(self, user: User) -> bool: def is_active(self, user: User) -> bool:
"""Check if user is active.""" """Check if user is active."""
return user.is_active return bool(user.is_active)
def is_superuser(self, user: User) -> bool: def is_superuser(self, user: User) -> bool:
"""Check if user is a superuser.""" """Check if user is a superuser."""
return user.is_superuser return bool(user.is_superuser)
# Create a singleton instance for use across the application # Singleton instance
user = CRUDUser(User) user_repo = UserRepository(User)

View File

@@ -1,273 +0,0 @@
"""
Event schemas for the Syndarix EventBus (Redis Pub/Sub).
This module defines event types and payload schemas for real-time communication
between services, agents, and the frontend.
"""
from datetime import datetime
from enum import Enum
from typing import Literal
from uuid import UUID
from pydantic import BaseModel, Field
class EventType(str, Enum):
"""
Event types for the EventBus.
Naming convention: {domain}.{action}
"""
# Agent Events
AGENT_SPAWNED = "agent.spawned"
AGENT_STATUS_CHANGED = "agent.status_changed"
AGENT_MESSAGE = "agent.message"
AGENT_TERMINATED = "agent.terminated"
# Issue Events
ISSUE_CREATED = "issue.created"
ISSUE_UPDATED = "issue.updated"
ISSUE_ASSIGNED = "issue.assigned"
ISSUE_CLOSED = "issue.closed"
# Sprint Events
SPRINT_STARTED = "sprint.started"
SPRINT_COMPLETED = "sprint.completed"
# Approval Events
APPROVAL_REQUESTED = "approval.requested"
APPROVAL_GRANTED = "approval.granted"
APPROVAL_DENIED = "approval.denied"
# Project Events
PROJECT_CREATED = "project.created"
PROJECT_UPDATED = "project.updated"
PROJECT_ARCHIVED = "project.archived"
# Workflow Events
WORKFLOW_STARTED = "workflow.started"
WORKFLOW_STEP_COMPLETED = "workflow.step_completed"
WORKFLOW_COMPLETED = "workflow.completed"
WORKFLOW_FAILED = "workflow.failed"
ActorType = Literal["agent", "user", "system"]
class Event(BaseModel):
"""
Base event schema for the EventBus.
All events published to the EventBus must conform to this schema.
"""
id: str = Field(
...,
description="Unique event identifier (UUID string)",
examples=["550e8400-e29b-41d4-a716-446655440000"],
)
type: EventType = Field(
...,
description="Event type enum value",
examples=[EventType.AGENT_MESSAGE],
)
timestamp: datetime = Field(
...,
description="When the event occurred (UTC)",
examples=["2024-01-15T10:30:00Z"],
)
project_id: UUID = Field(
...,
description="Project this event belongs to",
examples=["550e8400-e29b-41d4-a716-446655440001"],
)
actor_id: UUID | None = Field(
default=None,
description="ID of the agent or user who triggered the event",
examples=["550e8400-e29b-41d4-a716-446655440002"],
)
actor_type: ActorType = Field(
...,
description="Type of actor: 'agent', 'user', or 'system'",
examples=["agent"],
)
payload: dict = Field(
default_factory=dict,
description="Event-specific payload data",
)
model_config = {
"json_schema_extra": {
"example": {
"id": "550e8400-e29b-41d4-a716-446655440000",
"type": "agent.message",
"timestamp": "2024-01-15T10:30:00Z",
"project_id": "550e8400-e29b-41d4-a716-446655440001",
"actor_id": "550e8400-e29b-41d4-a716-446655440002",
"actor_type": "agent",
"payload": {"message": "Processing task...", "progress": 50},
}
}
}
# Specific payload schemas for type safety
class AgentSpawnedPayload(BaseModel):
"""Payload for AGENT_SPAWNED events."""
agent_instance_id: UUID = Field(..., description="ID of the spawned agent instance")
agent_type_id: UUID = Field(..., description="ID of the agent type")
agent_name: str = Field(..., description="Human-readable name of the agent")
role: str = Field(..., description="Agent role (e.g., 'product_owner', 'engineer')")
class AgentStatusChangedPayload(BaseModel):
"""Payload for AGENT_STATUS_CHANGED events."""
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
previous_status: str = Field(..., description="Previous status")
new_status: str = Field(..., description="New status")
reason: str | None = Field(default=None, description="Reason for status change")
class AgentMessagePayload(BaseModel):
"""Payload for AGENT_MESSAGE events."""
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
message: str = Field(..., description="Message content")
message_type: str = Field(
default="info",
description="Message type: 'info', 'warning', 'error', 'debug'",
)
metadata: dict = Field(
default_factory=dict,
description="Additional metadata (e.g., token usage, model info)",
)
class AgentTerminatedPayload(BaseModel):
"""Payload for AGENT_TERMINATED events."""
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
termination_reason: str = Field(..., description="Reason for termination")
final_status: str = Field(..., description="Final status at termination")
class IssueCreatedPayload(BaseModel):
"""Payload for ISSUE_CREATED events."""
issue_id: str = Field(..., description="Issue ID (from external tracker)")
title: str = Field(..., description="Issue title")
priority: str | None = Field(default=None, description="Issue priority")
labels: list[str] = Field(default_factory=list, description="Issue labels")
class IssueUpdatedPayload(BaseModel):
"""Payload for ISSUE_UPDATED events."""
issue_id: str = Field(..., description="Issue ID (from external tracker)")
changes: dict = Field(..., description="Dictionary of field changes")
class IssueAssignedPayload(BaseModel):
"""Payload for ISSUE_ASSIGNED events."""
issue_id: str = Field(..., description="Issue ID (from external tracker)")
assignee_id: UUID | None = Field(
default=None, description="Agent or user assigned to"
)
assignee_name: str | None = Field(default=None, description="Assignee name")
class IssueClosedPayload(BaseModel):
"""Payload for ISSUE_CLOSED events."""
issue_id: str = Field(..., description="Issue ID (from external tracker)")
resolution: str = Field(..., description="Resolution status")
class SprintStartedPayload(BaseModel):
"""Payload for SPRINT_STARTED events."""
sprint_id: UUID = Field(..., description="Sprint ID")
sprint_name: str = Field(..., description="Sprint name")
goal: str | None = Field(default=None, description="Sprint goal")
issue_count: int = Field(default=0, description="Number of issues in sprint")
class SprintCompletedPayload(BaseModel):
"""Payload for SPRINT_COMPLETED events."""
sprint_id: UUID = Field(..., description="Sprint ID")
sprint_name: str = Field(..., description="Sprint name")
completed_issues: int = Field(default=0, description="Number of completed issues")
incomplete_issues: int = Field(default=0, description="Number of incomplete issues")
class ApprovalRequestedPayload(BaseModel):
"""Payload for APPROVAL_REQUESTED events."""
approval_id: UUID = Field(..., description="Approval request ID")
approval_type: str = Field(..., description="Type of approval needed")
description: str = Field(..., description="Description of what needs approval")
requested_by: UUID | None = Field(
default=None, description="Agent/user requesting approval"
)
timeout_minutes: int | None = Field(
default=None, description="Minutes before auto-escalation"
)
class ApprovalGrantedPayload(BaseModel):
"""Payload for APPROVAL_GRANTED events."""
approval_id: UUID = Field(..., description="Approval request ID")
approved_by: UUID = Field(..., description="User who granted approval")
comments: str | None = Field(default=None, description="Approval comments")
class ApprovalDeniedPayload(BaseModel):
"""Payload for APPROVAL_DENIED events."""
approval_id: UUID = Field(..., description="Approval request ID")
denied_by: UUID = Field(..., description="User who denied approval")
reason: str = Field(..., description="Reason for denial")
class WorkflowStartedPayload(BaseModel):
"""Payload for WORKFLOW_STARTED events."""
workflow_id: UUID = Field(..., description="Workflow execution ID")
workflow_type: str = Field(..., description="Type of workflow")
total_steps: int = Field(default=0, description="Total number of steps")
class WorkflowStepCompletedPayload(BaseModel):
"""Payload for WORKFLOW_STEP_COMPLETED events."""
workflow_id: UUID = Field(..., description="Workflow execution ID")
step_name: str = Field(..., description="Name of completed step")
step_number: int = Field(..., description="Step number (1-indexed)")
total_steps: int = Field(..., description="Total number of steps")
result: dict = Field(default_factory=dict, description="Step result data")
class WorkflowCompletedPayload(BaseModel):
"""Payload for WORKFLOW_COMPLETED events."""
workflow_id: UUID = Field(..., description="Workflow execution ID")
duration_seconds: float = Field(..., description="Total execution duration")
result: dict = Field(default_factory=dict, description="Workflow result data")
class WorkflowFailedPayload(BaseModel):
"""Payload for WORKFLOW_FAILED events."""
workflow_id: UUID = Field(..., description="Workflow execution ID")
error_message: str = Field(..., description="Error message")
failed_step: str | None = Field(default=None, description="Step that failed")
recoverable: bool = Field(default=False, description="Whether error is recoverable")

View File

@@ -60,8 +60,8 @@ class OAuthAccountCreate(OAuthAccountBase):
user_id: UUID user_id: UUID
provider_user_id: str = Field(..., max_length=255) provider_user_id: str = Field(..., max_length=255)
access_token_encrypted: str | None = None access_token: str | None = None
refresh_token_encrypted: str | None = None refresh_token: str | None = None
token_expires_at: datetime | None = None token_expires_at: datetime | None = None

View File

@@ -48,7 +48,7 @@ class OrganizationCreate(OrganizationBase):
"""Schema for creating a new organization.""" """Schema for creating a new organization."""
name: str = Field(..., min_length=1, max_length=255) name: str = Field(..., min_length=1, max_length=255)
slug: str = Field(..., min_length=1, max_length=255) slug: str = Field(..., min_length=1, max_length=255) # pyright: ignore[reportIncompatibleVariableOverride]
class OrganizationUpdate(BaseModel): class OrganizationUpdate(BaseModel):

View File

@@ -1,113 +0,0 @@
# app/schemas/syndarix/__init__.py
"""
Syndarix domain schemas.
This package contains Pydantic schemas for validating and serializing
Syndarix domain entities.
"""
from .agent_instance import (
AgentInstanceCreate,
AgentInstanceInDB,
AgentInstanceListResponse,
AgentInstanceMetrics,
AgentInstanceResponse,
AgentInstanceTerminate,
AgentInstanceUpdate,
)
from .agent_type import (
AgentTypeCreate,
AgentTypeInDB,
AgentTypeListResponse,
AgentTypeResponse,
AgentTypeUpdate,
)
from .enums import (
AgentStatus,
AutonomyLevel,
IssuePriority,
IssueStatus,
ProjectStatus,
SprintStatus,
SyncStatus,
)
from .issue import (
IssueAssign,
IssueClose,
IssueCreate,
IssueInDB,
IssueListResponse,
IssueResponse,
IssueStats,
IssueSyncUpdate,
IssueUpdate,
)
from .project import (
ProjectCreate,
ProjectInDB,
ProjectListResponse,
ProjectResponse,
ProjectUpdate,
)
from .sprint import (
SprintBurndown,
SprintComplete,
SprintCreate,
SprintInDB,
SprintListResponse,
SprintResponse,
SprintStart,
SprintUpdate,
SprintVelocity,
)
__all__ = [
# AgentInstance schemas
"AgentInstanceCreate",
"AgentInstanceInDB",
"AgentInstanceListResponse",
"AgentInstanceMetrics",
"AgentInstanceResponse",
"AgentInstanceTerminate",
"AgentInstanceUpdate",
# Enums
"AgentStatus",
# AgentType schemas
"AgentTypeCreate",
"AgentTypeInDB",
"AgentTypeListResponse",
"AgentTypeResponse",
"AgentTypeUpdate",
"AutonomyLevel",
# Issue schemas
"IssueAssign",
"IssueClose",
"IssueCreate",
"IssueInDB",
"IssueListResponse",
"IssuePriority",
"IssueResponse",
"IssueStats",
"IssueStatus",
"IssueSyncUpdate",
"IssueUpdate",
# Project schemas
"ProjectCreate",
"ProjectInDB",
"ProjectListResponse",
"ProjectResponse",
"ProjectStatus",
"ProjectUpdate",
# Sprint schemas
"SprintBurndown",
"SprintComplete",
"SprintCreate",
"SprintInDB",
"SprintListResponse",
"SprintResponse",
"SprintStart",
"SprintStatus",
"SprintUpdate",
"SprintVelocity",
"SyncStatus",
]

View File

@@ -1,124 +0,0 @@
# app/schemas/syndarix/agent_instance.py
"""
Pydantic schemas for AgentInstance entity.
"""
from datetime import datetime
from decimal import Decimal
from typing import Any
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
from .enums import AgentStatus
class AgentInstanceBase(BaseModel):
"""Base agent instance schema with common fields."""
agent_type_id: UUID
project_id: UUID
status: AgentStatus = AgentStatus.IDLE
current_task: str | None = None
short_term_memory: dict[str, Any] = Field(default_factory=dict)
long_term_memory_ref: str | None = Field(None, max_length=500)
session_id: str | None = Field(None, max_length=255)
class AgentInstanceCreate(BaseModel):
"""Schema for creating a new agent instance."""
agent_type_id: UUID
project_id: UUID
name: str = Field(..., min_length=1, max_length=100)
status: AgentStatus = AgentStatus.IDLE
current_task: str | None = None
short_term_memory: dict[str, Any] = Field(default_factory=dict)
long_term_memory_ref: str | None = Field(None, max_length=500)
session_id: str | None = Field(None, max_length=255)
class AgentInstanceUpdate(BaseModel):
"""Schema for updating an agent instance."""
status: AgentStatus | None = None
current_task: str | None = None
short_term_memory: dict[str, Any] | None = None
long_term_memory_ref: str | None = None
session_id: str | None = None
last_activity_at: datetime | None = None
tasks_completed: int | None = Field(None, ge=0)
tokens_used: int | None = Field(None, ge=0)
cost_incurred: Decimal | None = Field(None, ge=0)
class AgentInstanceTerminate(BaseModel):
"""Schema for terminating an agent instance."""
reason: str | None = None
class AgentInstanceInDB(AgentInstanceBase):
"""Schema for agent instance in database."""
id: UUID
last_activity_at: datetime | None = None
terminated_at: datetime | None = None
tasks_completed: int = 0
tokens_used: int = 0
cost_incurred: Decimal = Decimal("0.0000")
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class AgentInstanceResponse(BaseModel):
"""Schema for agent instance API responses."""
id: UUID
agent_type_id: UUID
project_id: UUID
name: str
status: AgentStatus
current_task: str | None = None
short_term_memory: dict[str, Any] = Field(default_factory=dict)
long_term_memory_ref: str | None = None
session_id: str | None = None
last_activity_at: datetime | None = None
terminated_at: datetime | None = None
tasks_completed: int = 0
tokens_used: int = 0
cost_incurred: Decimal = Decimal("0.0000")
created_at: datetime
updated_at: datetime
# Expanded fields from relationships
agent_type_name: str | None = None
agent_type_slug: str | None = None
project_name: str | None = None
project_slug: str | None = None
assigned_issues_count: int | None = 0
model_config = ConfigDict(from_attributes=True)
class AgentInstanceListResponse(BaseModel):
"""Schema for paginated agent instance list responses."""
agent_instances: list[AgentInstanceResponse]
total: int
page: int
page_size: int
pages: int
class AgentInstanceMetrics(BaseModel):
"""Schema for agent instance metrics summary."""
total_instances: int
active_instances: int
idle_instances: int
total_tasks_completed: int
total_tokens_used: int
total_cost_incurred: Decimal

View File

@@ -1,151 +0,0 @@
# app/schemas/syndarix/agent_type.py
"""
Pydantic schemas for AgentType entity.
"""
import re
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, field_validator
class AgentTypeBase(BaseModel):
"""Base agent type schema with common fields."""
name: str = Field(..., min_length=1, max_length=255)
slug: str | None = Field(None, min_length=1, max_length=255)
description: str | None = None
expertise: list[str] = Field(default_factory=list)
personality_prompt: str = Field(..., min_length=1)
primary_model: str = Field(..., min_length=1, max_length=100)
fallback_models: list[str] = Field(default_factory=list)
model_params: dict[str, Any] = Field(default_factory=dict)
mcp_servers: list[str] = Field(default_factory=list)
tool_permissions: dict[str, Any] = Field(default_factory=dict)
is_active: bool = True
@field_validator("slug")
@classmethod
def validate_slug(cls, v: str | None) -> str | None:
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
if v is None:
return v
if not re.match(r"^[a-z0-9-]+$", v):
raise ValueError(
"Slug must contain only lowercase letters, numbers, and hyphens"
)
if v.startswith("-") or v.endswith("-"):
raise ValueError("Slug cannot start or end with a hyphen")
if "--" in v:
raise ValueError("Slug cannot contain consecutive hyphens")
return v
@field_validator("name")
@classmethod
def validate_name(cls, v: str) -> str:
"""Validate agent type name."""
if not v or v.strip() == "":
raise ValueError("Agent type name cannot be empty")
return v.strip()
@field_validator("expertise")
@classmethod
def validate_expertise(cls, v: list[str]) -> list[str]:
"""Validate and normalize expertise list."""
return [e.strip().lower() for e in v if e.strip()]
@field_validator("mcp_servers")
@classmethod
def validate_mcp_servers(cls, v: list[str]) -> list[str]:
"""Validate MCP server list."""
return [s.strip() for s in v if s.strip()]
class AgentTypeCreate(AgentTypeBase):
"""Schema for creating a new agent type."""
name: str = Field(..., min_length=1, max_length=255)
slug: str = Field(..., min_length=1, max_length=255)
personality_prompt: str = Field(..., min_length=1)
primary_model: str = Field(..., min_length=1, max_length=100)
class AgentTypeUpdate(BaseModel):
"""Schema for updating an agent type."""
name: str | None = Field(None, min_length=1, max_length=255)
slug: str | None = Field(None, min_length=1, max_length=255)
description: str | None = None
expertise: list[str] | None = None
personality_prompt: str | None = None
primary_model: str | None = Field(None, min_length=1, max_length=100)
fallback_models: list[str] | None = None
model_params: dict[str, Any] | None = None
mcp_servers: list[str] | None = None
tool_permissions: dict[str, Any] | None = None
is_active: bool | None = None
@field_validator("slug")
@classmethod
def validate_slug(cls, v: str | None) -> str | None:
"""Validate slug format."""
if v is None:
return v
if not re.match(r"^[a-z0-9-]+$", v):
raise ValueError(
"Slug must contain only lowercase letters, numbers, and hyphens"
)
if v.startswith("-") or v.endswith("-"):
raise ValueError("Slug cannot start or end with a hyphen")
if "--" in v:
raise ValueError("Slug cannot contain consecutive hyphens")
return v
@field_validator("name")
@classmethod
def validate_name(cls, v: str | None) -> str | None:
"""Validate agent type name."""
if v is not None and (not v or v.strip() == ""):
raise ValueError("Agent type name cannot be empty")
return v.strip() if v else v
@field_validator("expertise")
@classmethod
def validate_expertise(cls, v: list[str] | None) -> list[str] | None:
"""Validate and normalize expertise list."""
if v is None:
return v
return [e.strip().lower() for e in v if e.strip()]
class AgentTypeInDB(AgentTypeBase):
"""Schema for agent type in database."""
id: UUID
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class AgentTypeResponse(AgentTypeBase):
"""Schema for agent type API responses."""
id: UUID
created_at: datetime
updated_at: datetime
instance_count: int | None = 0
model_config = ConfigDict(from_attributes=True)
class AgentTypeListResponse(BaseModel):
"""Schema for paginated agent type list responses."""
agent_types: list[AgentTypeResponse]
total: int
page: int
page_size: int
pages: int

View File

@@ -1,26 +0,0 @@
# app/schemas/syndarix/enums.py
"""
Re-export enums from models for use in schemas.
This allows schemas to import enums without depending on SQLAlchemy models directly.
"""
from app.models.syndarix.enums import (
AgentStatus,
AutonomyLevel,
IssuePriority,
IssueStatus,
ProjectStatus,
SprintStatus,
SyncStatus,
)
__all__ = [
"AgentStatus",
"AutonomyLevel",
"IssuePriority",
"IssueStatus",
"ProjectStatus",
"SprintStatus",
"SyncStatus",
]

View File

@@ -1,191 +0,0 @@
# app/schemas/syndarix/issue.py
"""
Pydantic schemas for Issue entity.
"""
from datetime import datetime
from typing import Literal
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from .enums import IssuePriority, IssueStatus, SyncStatus
class IssueBase(BaseModel):
"""Base issue schema with common fields."""
title: str = Field(..., min_length=1, max_length=500)
body: str = ""
status: IssueStatus = IssueStatus.OPEN
priority: IssuePriority = IssuePriority.MEDIUM
labels: list[str] = Field(default_factory=list)
story_points: int | None = Field(None, ge=0, le=100)
@field_validator("title")
@classmethod
def validate_title(cls, v: str) -> str:
"""Validate issue title."""
if not v or v.strip() == "":
raise ValueError("Issue title cannot be empty")
return v.strip()
@field_validator("labels")
@classmethod
def validate_labels(cls, v: list[str]) -> list[str]:
"""Validate and normalize labels."""
return [label.strip().lower() for label in v if label.strip()]
class IssueCreate(IssueBase):
"""Schema for creating a new issue."""
project_id: UUID
assigned_agent_id: UUID | None = None
human_assignee: str | None = Field(None, max_length=255)
sprint_id: UUID | None = None
# External tracker fields (optional, for importing from external systems)
external_tracker_type: Literal["gitea", "github", "gitlab"] | None = None
external_issue_id: str | None = Field(None, max_length=255)
remote_url: str | None = Field(None, max_length=1000)
external_issue_number: int | None = None
class IssueUpdate(BaseModel):
"""Schema for updating an issue."""
title: str | None = Field(None, min_length=1, max_length=500)
body: str | None = None
status: IssueStatus | None = None
priority: IssuePriority | None = None
labels: list[str] | None = None
assigned_agent_id: UUID | None = None
human_assignee: str | None = Field(None, max_length=255)
sprint_id: UUID | None = None
story_points: int | None = Field(None, ge=0, le=100)
sync_status: SyncStatus | None = None
@field_validator("title")
@classmethod
def validate_title(cls, v: str | None) -> str | None:
"""Validate issue title."""
if v is not None and (not v or v.strip() == ""):
raise ValueError("Issue title cannot be empty")
return v.strip() if v else v
@field_validator("labels")
@classmethod
def validate_labels(cls, v: list[str] | None) -> list[str] | None:
"""Validate and normalize labels."""
if v is None:
return v
return [label.strip().lower() for label in v if label.strip()]
class IssueClose(BaseModel):
"""Schema for closing an issue."""
resolution: str | None = None # Optional resolution note
class IssueAssign(BaseModel):
"""Schema for assigning an issue."""
assigned_agent_id: UUID | None = None
human_assignee: str | None = Field(None, max_length=255)
@model_validator(mode="after")
def validate_assignment(self) -> "IssueAssign":
"""Ensure only one type of assignee is set."""
if self.assigned_agent_id and self.human_assignee:
raise ValueError("Cannot assign to both an agent and a human. Choose one.")
return self
class IssueSyncUpdate(BaseModel):
"""Schema for updating sync-related fields."""
sync_status: SyncStatus
last_synced_at: datetime | None = None
external_updated_at: datetime | None = None
class IssueInDB(IssueBase):
"""Schema for issue in database."""
id: UUID
project_id: UUID
assigned_agent_id: UUID | None = None
human_assignee: str | None = None
sprint_id: UUID | None = None
external_tracker_type: str | None = None
external_issue_id: str | None = None
remote_url: str | None = None
external_issue_number: int | None = None
sync_status: SyncStatus = SyncStatus.SYNCED
last_synced_at: datetime | None = None
external_updated_at: datetime | None = None
closed_at: datetime | None = None
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class IssueResponse(BaseModel):
"""Schema for issue API responses."""
id: UUID
project_id: UUID
title: str
body: str
status: IssueStatus
priority: IssuePriority
labels: list[str] = Field(default_factory=list)
assigned_agent_id: UUID | None = None
human_assignee: str | None = None
sprint_id: UUID | None = None
story_points: int | None = None
external_tracker_type: str | None = None
external_issue_id: str | None = None
remote_url: str | None = None
external_issue_number: int | None = None
sync_status: SyncStatus = SyncStatus.SYNCED
last_synced_at: datetime | None = None
external_updated_at: datetime | None = None
closed_at: datetime | None = None
created_at: datetime
updated_at: datetime
# Expanded fields from relationships
project_name: str | None = None
project_slug: str | None = None
sprint_name: str | None = None
assigned_agent_type_name: str | None = None
model_config = ConfigDict(from_attributes=True)
class IssueListResponse(BaseModel):
"""Schema for paginated issue list responses."""
issues: list[IssueResponse]
total: int
page: int
page_size: int
pages: int
class IssueStats(BaseModel):
"""Schema for issue statistics."""
total: int
open: int
in_progress: int
in_review: int
blocked: int
closed: int
by_priority: dict[str, int]
total_story_points: int | None = None
completed_story_points: int | None = None

View File

@@ -1,131 +0,0 @@
# app/schemas/syndarix/project.py
"""
Pydantic schemas for Project entity.
"""
import re
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, field_validator
from .enums import AutonomyLevel, ProjectStatus
class ProjectBase(BaseModel):
"""Base project schema with common fields."""
name: str = Field(..., min_length=1, max_length=255)
slug: str | None = Field(None, min_length=1, max_length=255)
description: str | None = None
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE
status: ProjectStatus = ProjectStatus.ACTIVE
settings: dict[str, Any] = Field(default_factory=dict)
@field_validator("slug")
@classmethod
def validate_slug(cls, v: str | None) -> str | None:
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
if v is None:
return v
if not re.match(r"^[a-z0-9-]+$", v):
raise ValueError(
"Slug must contain only lowercase letters, numbers, and hyphens"
)
if v.startswith("-") or v.endswith("-"):
raise ValueError("Slug cannot start or end with a hyphen")
if "--" in v:
raise ValueError("Slug cannot contain consecutive hyphens")
return v
@field_validator("name")
@classmethod
def validate_name(cls, v: str) -> str:
"""Validate project name."""
if not v or v.strip() == "":
raise ValueError("Project name cannot be empty")
return v.strip()
class ProjectCreate(ProjectBase):
"""Schema for creating a new project."""
name: str = Field(..., min_length=1, max_length=255)
slug: str = Field(..., min_length=1, max_length=255)
owner_id: UUID | None = None
class ProjectUpdate(BaseModel):
"""Schema for updating a project.
Note: owner_id is intentionally excluded to prevent IDOR vulnerabilities.
Project ownership transfer should be done via a dedicated endpoint with
proper authorization checks.
"""
name: str | None = Field(None, min_length=1, max_length=255)
slug: str | None = Field(None, min_length=1, max_length=255)
description: str | None = None
autonomy_level: AutonomyLevel | None = None
status: ProjectStatus | None = None
settings: dict[str, Any] | None = None
@field_validator("slug")
@classmethod
def validate_slug(cls, v: str | None) -> str | None:
"""Validate slug format."""
if v is None:
return v
if not re.match(r"^[a-z0-9-]+$", v):
raise ValueError(
"Slug must contain only lowercase letters, numbers, and hyphens"
)
if v.startswith("-") or v.endswith("-"):
raise ValueError("Slug cannot start or end with a hyphen")
if "--" in v:
raise ValueError("Slug cannot contain consecutive hyphens")
return v
@field_validator("name")
@classmethod
def validate_name(cls, v: str | None) -> str | None:
"""Validate project name."""
if v is not None and (not v or v.strip() == ""):
raise ValueError("Project name cannot be empty")
return v.strip() if v else v
class ProjectInDB(ProjectBase):
"""Schema for project in database."""
id: UUID
owner_id: UUID | None = None
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class ProjectResponse(ProjectBase):
"""Schema for project API responses."""
id: UUID
owner_id: UUID | None = None
created_at: datetime
updated_at: datetime
agent_count: int | None = 0
issue_count: int | None = 0
active_sprint_name: str | None = None
model_config = ConfigDict(from_attributes=True)
class ProjectListResponse(BaseModel):
"""Schema for paginated project list responses."""
projects: list[ProjectResponse]
total: int
page: int
page_size: int
pages: int

View File

@@ -1,135 +0,0 @@
# app/schemas/syndarix/sprint.py
"""
Pydantic schemas for Sprint entity.
"""
from datetime import date, datetime
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from .enums import SprintStatus
class SprintBase(BaseModel):
"""Base sprint schema with common fields."""
name: str = Field(..., min_length=1, max_length=255)
number: int = Field(..., ge=1)
goal: str | None = None
start_date: date
end_date: date
status: SprintStatus = SprintStatus.PLANNED
planned_points: int | None = Field(None, ge=0)
velocity: int | None = Field(None, ge=0)
@field_validator("name")
@classmethod
def validate_name(cls, v: str) -> str:
"""Validate sprint name."""
if not v or v.strip() == "":
raise ValueError("Sprint name cannot be empty")
return v.strip()
@model_validator(mode="after")
def validate_dates(self) -> "SprintBase":
"""Validate that end_date is after start_date."""
if self.end_date < self.start_date:
raise ValueError("End date must be after or equal to start date")
return self
class SprintCreate(SprintBase):
"""Schema for creating a new sprint."""
project_id: UUID
class SprintUpdate(BaseModel):
"""Schema for updating a sprint."""
name: str | None = Field(None, min_length=1, max_length=255)
goal: str | None = None
start_date: date | None = None
end_date: date | None = None
status: SprintStatus | None = None
planned_points: int | None = Field(None, ge=0)
velocity: int | None = Field(None, ge=0)
@field_validator("name")
@classmethod
def validate_name(cls, v: str | None) -> str | None:
"""Validate sprint name."""
if v is not None and (not v or v.strip() == ""):
raise ValueError("Sprint name cannot be empty")
return v.strip() if v else v
class SprintStart(BaseModel):
"""Schema for starting a sprint."""
start_date: date | None = None # Optionally override start date
class SprintComplete(BaseModel):
"""Schema for completing a sprint."""
velocity: int | None = Field(None, ge=0)
notes: str | None = None
class SprintInDB(SprintBase):
"""Schema for sprint in database."""
id: UUID
project_id: UUID
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class SprintResponse(SprintBase):
"""Schema for sprint API responses."""
id: UUID
project_id: UUID
created_at: datetime
updated_at: datetime
# Expanded fields from relationships
project_name: str | None = None
project_slug: str | None = None
issue_count: int | None = 0
open_issues: int | None = 0
completed_issues: int | None = 0
model_config = ConfigDict(from_attributes=True)
class SprintListResponse(BaseModel):
"""Schema for paginated sprint list responses."""
sprints: list[SprintResponse]
total: int
page: int
page_size: int
pages: int
class SprintVelocity(BaseModel):
"""Schema for sprint velocity metrics."""
sprint_number: int
sprint_name: str
planned_points: int | None
velocity: int | None # Sum of completed story points
velocity_ratio: float | None # velocity/planned ratio
class SprintBurndown(BaseModel):
"""Schema for sprint burndown data point."""
date: date
remaining_points: int
ideal_remaining: float

View File

@@ -1,5 +1,19 @@
# app/services/__init__.py # app/services/__init__.py
from . import oauth_provider_service
from .auth_service import AuthService from .auth_service import AuthService
from .oauth_service import OAuthService from .oauth_service import OAuthService
from .organization_service import OrganizationService, organization_service
from .session_service import SessionService, session_service
from .user_service import UserService, user_service
__all__ = ["AuthService", "OAuthService"] __all__ = [
"AuthService",
"OAuthService",
"OrganizationService",
"SessionService",
"UserService",
"oauth_provider_service",
"organization_service",
"session_service",
"user_service",
]

View File

@@ -2,7 +2,6 @@
import logging import logging
from uuid import UUID from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import ( from app.core.auth import (
@@ -14,12 +13,18 @@ from app.core.auth import (
verify_password_async, verify_password_async,
) )
from app.core.config import settings from app.core.config import settings
from app.core.exceptions import AuthenticationError from app.core.exceptions import AuthenticationError, DuplicateError
from app.core.repository_exceptions import DuplicateEntryError
from app.models.user import User from app.models.user import User
from app.repositories.user import user_repo
from app.schemas.users import Token, UserCreate, UserResponse from app.schemas.users import Token, UserCreate, UserResponse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Pre-computed bcrypt hash used for constant-time comparison when user is not found,
# preventing timing attacks that could enumerate valid email addresses.
_DUMMY_HASH = "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36zLFbnJHfxPSEFBzXKiHia"
class AuthService: class AuthService:
"""Service for handling authentication operations""" """Service for handling authentication operations"""
@@ -39,10 +44,12 @@ class AuthService:
Returns: Returns:
User if authenticated, None otherwise User if authenticated, None otherwise
""" """
result = await db.execute(select(User).where(User.email == email)) user = await user_repo.get_by_email(db, email=email)
user = result.scalar_one_or_none()
if not user: if not user:
# Perform a dummy verification to match timing of a real bcrypt check,
# preventing email enumeration via response-time differences.
await verify_password_async(password, _DUMMY_HASH)
return None return None
# Verify password asynchronously to avoid blocking event loop # Verify password asynchronously to avoid blocking event loop
@@ -71,40 +78,23 @@ class AuthService:
""" """
try: try:
# Check if user already exists # Check if user already exists
result = await db.execute(select(User).where(User.email == user_data.email)) existing_user = await user_repo.get_by_email(db, email=user_data.email)
existing_user = result.scalar_one_or_none()
if existing_user: if existing_user:
raise AuthenticationError("User with this email already exists") raise DuplicateError("User with this email already exists")
# Create new user with async password hashing # Delegate creation (hashing + commit) to the repository
# Hash password asynchronously to avoid blocking event loop user = await user_repo.create(db, obj_in=user_data)
hashed_password = await get_password_hash_async(user_data.password)
# Create user object from model logger.info("User created successfully: %s", user.email)
user = User(
email=user_data.email,
password_hash=hashed_password,
first_name=user_data.first_name,
last_name=user_data.last_name,
phone_number=user_data.phone_number,
is_active=True,
is_superuser=False,
)
db.add(user)
await db.commit()
await db.refresh(user)
logger.info(f"User created successfully: {user.email}")
return user return user
except AuthenticationError: except (AuthenticationError, DuplicateError):
# Re-raise authentication errors without rollback # Re-raise API exceptions without rollback
raise raise
except DuplicateEntryError as e:
raise DuplicateError(str(e))
except Exception as e: except Exception as e:
# Rollback on any database errors logger.exception("Error creating user: %s", e)
await db.rollback()
logger.error(f"Error creating user: {e!s}", exc_info=True)
raise AuthenticationError(f"Failed to create user: {e!s}") raise AuthenticationError(f"Failed to create user: {e!s}")
@staticmethod @staticmethod
@@ -168,8 +158,7 @@ class AuthService:
user_id = token_data.user_id user_id = token_data.user_id
# Get user from database # Get user from database
result = await db.execute(select(User).where(User.id == user_id)) user = await user_repo.get(db, id=str(user_id))
user = result.scalar_one_or_none()
if not user or not user.is_active: if not user or not user.is_active:
raise TokenInvalidError("Invalid user or inactive account") raise TokenInvalidError("Invalid user or inactive account")
@@ -177,7 +166,7 @@ class AuthService:
return AuthService.create_tokens(user) return AuthService.create_tokens(user)
except (TokenExpiredError, TokenInvalidError) as e: except (TokenExpiredError, TokenInvalidError) as e:
logger.warning(f"Token refresh failed: {e!s}") logger.warning("Token refresh failed: %s", e)
raise raise
@staticmethod @staticmethod
@@ -200,8 +189,7 @@ class AuthService:
AuthenticationError: If current password is incorrect or update fails AuthenticationError: If current password is incorrect or update fails
""" """
try: try:
result = await db.execute(select(User).where(User.id == user_id)) user = await user_repo.get(db, id=str(user_id))
user = result.scalar_one_or_none()
if not user: if not user:
raise AuthenticationError("User not found") raise AuthenticationError("User not found")
@@ -210,10 +198,10 @@ class AuthService:
raise AuthenticationError("Current password is incorrect") raise AuthenticationError("Current password is incorrect")
# Hash new password asynchronously to avoid blocking event loop # Hash new password asynchronously to avoid blocking event loop
user.password_hash = await get_password_hash_async(new_password) new_hash = await get_password_hash_async(new_password)
await db.commit() await user_repo.update_password(db, user=user, password_hash=new_hash)
logger.info(f"Password changed successfully for user {user_id}") logger.info("Password changed successfully for user %s", user_id)
return True return True
except AuthenticationError: except AuthenticationError:
@@ -222,7 +210,34 @@ class AuthService:
except Exception as e: except Exception as e:
# Rollback on any database errors # Rollback on any database errors
await db.rollback() await db.rollback()
logger.error( logger.exception("Error changing password for user %s: %s", user_id, e)
f"Error changing password for user {user_id}: {e!s}", exc_info=True
)
raise AuthenticationError(f"Failed to change password: {e!s}") raise AuthenticationError(f"Failed to change password: {e!s}")
@staticmethod
async def reset_password(
db: AsyncSession, *, email: str, new_password: str
) -> User:
"""
Reset a user's password without requiring the current password.
Args:
db: Database session
email: User email address
new_password: New password to set
Returns:
Updated user
Raises:
AuthenticationError: If user not found or inactive
"""
user = await user_repo.get_by_email(db, email=email)
if not user:
raise AuthenticationError("User not found")
if not user.is_active:
raise AuthenticationError("User account is inactive")
new_hash = await get_password_hash_async(new_password)
user = await user_repo.update_password(db, user=user, password_hash=new_hash)
logger.info("Password reset successfully for %s", email)
return user

View File

@@ -1,178 +0,0 @@
"""
Context Management Engine
Sophisticated context assembly and optimization for LLM requests.
Provides intelligent context selection, token budget management,
and model-specific formatting.
Usage:
from app.services.context import (
ContextSettings,
get_context_settings,
SystemContext,
KnowledgeContext,
ConversationContext,
TaskContext,
ToolContext,
TokenBudget,
BudgetAllocator,
TokenCalculator,
)
# Get settings
settings = get_context_settings()
# Create budget for a model
allocator = BudgetAllocator(settings)
budget = allocator.create_budget_for_model("claude-3-sonnet")
# Create context instances
system_ctx = SystemContext.create_persona(
name="Code Assistant",
description="You are a helpful code assistant.",
capabilities=["Write code", "Debug issues"],
)
"""
# Budget Management
# Adapters
from .adapters import (
ClaudeAdapter,
DefaultAdapter,
ModelAdapter,
OpenAIAdapter,
get_adapter,
)
# Assembly
from .assembly import (
ContextPipeline,
PipelineMetrics,
)
from .budget import (
BudgetAllocator,
TokenBudget,
TokenCalculator,
)
# Cache
from .cache import ContextCache
# Compression
from .compression import (
ContextCompressor,
TruncationResult,
TruncationStrategy,
)
# Configuration
from .config import (
ContextSettings,
get_context_settings,
get_default_settings,
reset_context_settings,
)
# Engine
from .engine import ContextEngine, create_context_engine
# Exceptions
from .exceptions import (
AssemblyTimeoutError,
BudgetExceededError,
CacheError,
CompressionError,
ContextError,
ContextNotFoundError,
FormattingError,
InvalidContextError,
ScoringError,
TokenCountError,
)
# Prioritization
from .prioritization import (
ContextRanker,
RankingResult,
)
# Scoring
from .scoring import (
BaseScorer,
CompositeScorer,
PriorityScorer,
RecencyScorer,
RelevanceScorer,
ScoredContext,
)
# Types
from .types import (
AssembledContext,
BaseContext,
ContextPriority,
ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskComplexity,
TaskContext,
TaskStatus,
ToolContext,
ToolResultStatus,
)
__all__ = [
"AssembledContext",
"AssemblyTimeoutError",
"BaseContext",
"BaseScorer",
"BudgetAllocator",
"BudgetExceededError",
"CacheError",
"ClaudeAdapter",
"CompositeScorer",
"CompressionError",
"ContextCache",
"ContextCompressor",
"ContextEngine",
"ContextError",
"ContextNotFoundError",
"ContextPipeline",
"ContextPriority",
"ContextRanker",
"ContextSettings",
"ContextType",
"ConversationContext",
"DefaultAdapter",
"FormattingError",
"InvalidContextError",
"KnowledgeContext",
"MessageRole",
"ModelAdapter",
"OpenAIAdapter",
"PipelineMetrics",
"PriorityScorer",
"RankingResult",
"RecencyScorer",
"RelevanceScorer",
"ScoredContext",
"ScoringError",
"SystemContext",
"TaskComplexity",
"TaskContext",
"TaskStatus",
"TokenBudget",
"TokenCalculator",
"TokenCountError",
"ToolContext",
"ToolResultStatus",
"TruncationResult",
"TruncationStrategy",
"create_context_engine",
"get_adapter",
"get_context_settings",
"get_default_settings",
"reset_context_settings",
]

View File

@@ -1,35 +0,0 @@
"""
Model Adapters Module.
Provides model-specific context formatting adapters.
"""
from .base import DefaultAdapter, ModelAdapter
from .claude import ClaudeAdapter
from .openai import OpenAIAdapter
def get_adapter(model: str) -> ModelAdapter:
"""
Get the appropriate adapter for a model.
Args:
model: Model name
Returns:
Adapter instance for the model
"""
if ClaudeAdapter.matches_model(model):
return ClaudeAdapter()
elif OpenAIAdapter.matches_model(model):
return OpenAIAdapter()
return DefaultAdapter()
__all__ = [
"ClaudeAdapter",
"DefaultAdapter",
"ModelAdapter",
"OpenAIAdapter",
"get_adapter",
]

View File

@@ -1,178 +0,0 @@
"""
Base Model Adapter.
Abstract base class for model-specific context formatting.
"""
from abc import ABC, abstractmethod
from typing import Any, ClassVar
from ..types import BaseContext, ContextType
class ModelAdapter(ABC):
"""
Abstract base adapter for model-specific context formatting.
Each adapter knows how to format contexts for optimal
understanding by a specific LLM family (Claude, OpenAI, etc.).
"""
# Model name patterns this adapter handles
MODEL_PATTERNS: ClassVar[list[str]] = []
@classmethod
def matches_model(cls, model: str) -> bool:
"""
Check if this adapter handles the given model.
Args:
model: Model name to check
Returns:
True if this adapter handles the model
"""
model_lower = model.lower()
return any(pattern in model_lower for pattern in cls.MODEL_PATTERNS)
@abstractmethod
def format(
self,
contexts: list[BaseContext],
**kwargs: Any,
) -> str:
"""
Format contexts for the target model.
Args:
contexts: List of contexts to format
**kwargs: Additional formatting options
Returns:
Formatted context string
"""
...
@abstractmethod
def format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
**kwargs: Any,
) -> str:
"""
Format contexts of a specific type.
Args:
contexts: List of contexts of the same type
context_type: The type of contexts
**kwargs: Additional formatting options
Returns:
Formatted string for this context type
"""
...
def get_type_order(self) -> list[ContextType]:
"""
Get the preferred order of context types.
Returns:
List of context types in preferred order
"""
return [
ContextType.SYSTEM,
ContextType.TASK,
ContextType.KNOWLEDGE,
ContextType.CONVERSATION,
ContextType.TOOL,
]
def group_by_type(
self, contexts: list[BaseContext]
) -> dict[ContextType, list[BaseContext]]:
"""
Group contexts by their type.
Args:
contexts: List of contexts to group
Returns:
Dictionary mapping context type to list of contexts
"""
by_type: dict[ContextType, list[BaseContext]] = {}
for context in contexts:
ct = context.get_type()
if ct not in by_type:
by_type[ct] = []
by_type[ct].append(context)
return by_type
def get_separator(self) -> str:
"""
Get the separator between context sections.
Returns:
Separator string
"""
return "\n\n"
class DefaultAdapter(ModelAdapter):
"""
Default adapter for unknown models.
Uses simple plain-text formatting with minimal structure.
"""
MODEL_PATTERNS: ClassVar[list[str]] = [] # Fallback adapter
@classmethod
def matches_model(cls, model: str) -> bool:
"""Always returns True as fallback."""
return True
def format(
self,
contexts: list[BaseContext],
**kwargs: Any,
) -> str:
"""Format contexts as plain text."""
if not contexts:
return ""
by_type = self.group_by_type(contexts)
parts: list[str] = []
for ct in self.get_type_order():
if ct in by_type:
formatted = self.format_type(by_type[ct], ct, **kwargs)
if formatted:
parts.append(formatted)
return self.get_separator().join(parts)
def format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
**kwargs: Any,
) -> str:
"""Format contexts of a type as plain text."""
if not contexts:
return ""
content = "\n\n".join(c.content for c in contexts)
if context_type == ContextType.SYSTEM:
return content
elif context_type == ContextType.TASK:
return f"Task:\n{content}"
elif context_type == ContextType.KNOWLEDGE:
return f"Reference Information:\n{content}"
elif context_type == ContextType.CONVERSATION:
return f"Previous Conversation:\n{content}"
elif context_type == ContextType.TOOL:
return f"Tool Results:\n{content}"
return content

View File

@@ -1,212 +0,0 @@
"""
Claude Model Adapter.
Provides Claude-specific context formatting using XML tags
which Claude models understand natively.
"""
from typing import Any, ClassVar
from ..types import BaseContext, ContextType
from .base import ModelAdapter
class ClaudeAdapter(ModelAdapter):
"""
Claude-specific context formatting adapter.
Claude models have native understanding of XML structure,
so we use XML tags for clear delineation of context types.
Features:
- XML tags for each context type
- Document structure for knowledge contexts
- Role-based message formatting for conversations
- Tool result wrapping with tool names
"""
MODEL_PATTERNS: ClassVar[list[str]] = ["claude", "anthropic"]
def format(
self,
contexts: list[BaseContext],
**kwargs: Any,
) -> str:
"""
Format contexts for Claude models.
Uses XML tags for structured content that Claude
understands natively.
Args:
contexts: List of contexts to format
**kwargs: Additional formatting options
Returns:
XML-structured context string
"""
if not contexts:
return ""
by_type = self.group_by_type(contexts)
parts: list[str] = []
for ct in self.get_type_order():
if ct in by_type:
formatted = self.format_type(by_type[ct], ct, **kwargs)
if formatted:
parts.append(formatted)
return self.get_separator().join(parts)
def format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
**kwargs: Any,
) -> str:
"""
Format contexts of a specific type for Claude.
Args:
contexts: List of contexts of the same type
context_type: The type of contexts
**kwargs: Additional formatting options
Returns:
XML-formatted string for this context type
"""
if not contexts:
return ""
if context_type == ContextType.SYSTEM:
return self._format_system(contexts)
elif context_type == ContextType.TASK:
return self._format_task(contexts)
elif context_type == ContextType.KNOWLEDGE:
return self._format_knowledge(contexts)
elif context_type == ContextType.CONVERSATION:
return self._format_conversation(contexts)
elif context_type == ContextType.TOOL:
return self._format_tool(contexts)
# Fallback for any unhandled context types - still escape content
# to prevent XML injection if new types are added without updating adapter
return "\n".join(self._escape_xml_content(c.content) for c in contexts)
def _format_system(self, contexts: list[BaseContext]) -> str:
"""Format system contexts."""
# System prompts are typically admin-controlled, but escape for safety
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
return f"<system_instructions>\n{content}\n</system_instructions>"
def _format_task(self, contexts: list[BaseContext]) -> str:
"""Format task contexts."""
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
return f"<current_task>\n{content}\n</current_task>"
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
"""
Format knowledge contexts as structured documents.
Each knowledge context becomes a document with source attribution.
All content is XML-escaped to prevent injection attacks.
"""
parts = ["<reference_documents>"]
for ctx in contexts:
source = self._escape_xml(ctx.source)
# Escape content to prevent XML injection
content = self._escape_xml_content(ctx.content)
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
if score:
# Escape score to prevent XML injection via metadata
escaped_score = self._escape_xml(str(score))
parts.append(
f'<document source="{source}" relevance="{escaped_score}">'
)
else:
parts.append(f'<document source="{source}">')
parts.append(content)
parts.append("</document>")
parts.append("</reference_documents>")
return "\n".join(parts)
def _format_conversation(self, contexts: list[BaseContext]) -> str:
"""
Format conversation contexts as message history.
Uses role-based message tags for clear turn delineation.
All content is XML-escaped to prevent prompt injection.
"""
parts = ["<conversation_history>"]
for ctx in contexts:
role = self._escape_xml(ctx.metadata.get("role", "user"))
# Escape content to prevent prompt injection via fake XML tags
content = self._escape_xml_content(ctx.content)
parts.append(f'<message role="{role}">')
parts.append(content)
parts.append("</message>")
parts.append("</conversation_history>")
return "\n".join(parts)
def _format_tool(self, contexts: list[BaseContext]) -> str:
"""
Format tool contexts as tool results.
Each tool result is wrapped with the tool name.
All content is XML-escaped to prevent injection.
"""
parts = ["<tool_results>"]
for ctx in contexts:
tool_name = self._escape_xml(ctx.metadata.get("tool_name", "unknown"))
status = ctx.metadata.get("status", "")
if status:
parts.append(
f'<tool_result name="{tool_name}" status="{self._escape_xml(status)}">'
)
else:
parts.append(f'<tool_result name="{tool_name}">')
# Escape content to prevent injection
parts.append(self._escape_xml_content(ctx.content))
parts.append("</tool_result>")
parts.append("</tool_results>")
return "\n".join(parts)
@staticmethod
def _escape_xml(text: str) -> str:
"""Escape XML special characters in attribute values."""
return (
text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&apos;")
)
@staticmethod
def _escape_xml_content(text: str) -> str:
"""
Escape XML special characters in element content.
This prevents XML injection attacks where malicious content
could break out of XML tags or inject fake tags for prompt injection.
Only escapes &, <, > since quotes don't need escaping in content.
Args:
text: Content text to escape
Returns:
XML-safe content string
"""
return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")

View File

@@ -1,160 +0,0 @@
"""
OpenAI Model Adapter.
Provides OpenAI-specific context formatting using markdown
which GPT models understand well.
"""
from typing import Any, ClassVar
from ..types import BaseContext, ContextType
from .base import ModelAdapter
class OpenAIAdapter(ModelAdapter):
"""
OpenAI-specific context formatting adapter.
GPT models work well with markdown formatting,
so we use headers and structured markdown for clarity.
Features:
- Markdown headers for each context type
- Bulleted lists for document sources
- Bold role labels for conversations
- Code blocks for tool outputs
"""
MODEL_PATTERNS: ClassVar[list[str]] = ["gpt", "openai", "o1", "o3"]
def format(
self,
contexts: list[BaseContext],
**kwargs: Any,
) -> str:
"""
Format contexts for OpenAI models.
Uses markdown formatting for structured content.
Args:
contexts: List of contexts to format
**kwargs: Additional formatting options
Returns:
Markdown-structured context string
"""
if not contexts:
return ""
by_type = self.group_by_type(contexts)
parts: list[str] = []
for ct in self.get_type_order():
if ct in by_type:
formatted = self.format_type(by_type[ct], ct, **kwargs)
if formatted:
parts.append(formatted)
return self.get_separator().join(parts)
def format_type(
self,
contexts: list[BaseContext],
context_type: ContextType,
**kwargs: Any,
) -> str:
"""
Format contexts of a specific type for OpenAI.
Args:
contexts: List of contexts of the same type
context_type: The type of contexts
**kwargs: Additional formatting options
Returns:
Markdown-formatted string for this context type
"""
if not contexts:
return ""
if context_type == ContextType.SYSTEM:
return self._format_system(contexts)
elif context_type == ContextType.TASK:
return self._format_task(contexts)
elif context_type == ContextType.KNOWLEDGE:
return self._format_knowledge(contexts)
elif context_type == ContextType.CONVERSATION:
return self._format_conversation(contexts)
elif context_type == ContextType.TOOL:
return self._format_tool(contexts)
return "\n".join(c.content for c in contexts)
def _format_system(self, contexts: list[BaseContext]) -> str:
"""Format system contexts."""
content = "\n\n".join(c.content for c in contexts)
return content
def _format_task(self, contexts: list[BaseContext]) -> str:
"""Format task contexts."""
content = "\n\n".join(c.content for c in contexts)
return f"## Current Task\n\n{content}"
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
"""
Format knowledge contexts as structured documents.
Each knowledge context becomes a section with source attribution.
"""
parts = ["## Reference Documents\n"]
for ctx in contexts:
source = ctx.source
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
if score:
parts.append(f"### Source: {source} (relevance: {score})\n")
else:
parts.append(f"### Source: {source}\n")
parts.append(ctx.content)
parts.append("")
return "\n".join(parts)
def _format_conversation(self, contexts: list[BaseContext]) -> str:
"""
Format conversation contexts as message history.
Uses bold role labels for clear turn delineation.
"""
parts = []
for ctx in contexts:
role = ctx.metadata.get("role", "user").upper()
parts.append(f"**{role}**: {ctx.content}")
return "\n\n".join(parts)
def _format_tool(self, contexts: list[BaseContext]) -> str:
"""
Format tool contexts as tool results.
Each tool result is in a code block with the tool name.
"""
parts = ["## Recent Tool Results\n"]
for ctx in contexts:
tool_name = ctx.metadata.get("tool_name", "unknown")
status = ctx.metadata.get("status", "")
if status:
parts.append(f"### Tool: {tool_name} ({status})\n")
else:
parts.append(f"### Tool: {tool_name}\n")
parts.append(f"```\n{ctx.content}\n```")
parts.append("")
return "\n".join(parts)

View File

@@ -1,12 +0,0 @@
"""
Context Assembly Module.
Provides the assembly pipeline and formatting.
"""
from .pipeline import ContextPipeline, PipelineMetrics
__all__ = [
"ContextPipeline",
"PipelineMetrics",
]

View File

@@ -1,362 +0,0 @@
"""
Context Assembly Pipeline.
Orchestrates the full context assembly workflow:
Gather → Count → Score → Rank → Compress → Format
"""
import asyncio
import logging
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from ..adapters import get_adapter
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
from ..compression.truncation import ContextCompressor
from ..config import ContextSettings, get_context_settings
from ..exceptions import AssemblyTimeoutError
from ..prioritization import ContextRanker
from ..scoring import CompositeScorer
from ..types import AssembledContext, BaseContext, ContextType
if TYPE_CHECKING:
from app.services.mcp.client_manager import MCPClientManager
logger = logging.getLogger(__name__)
@dataclass
class PipelineMetrics:
"""Metrics from pipeline execution."""
start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
end_time: datetime | None = None
total_contexts: int = 0
selected_contexts: int = 0
excluded_contexts: int = 0
compressed_contexts: int = 0
total_tokens: int = 0
assembly_time_ms: float = 0.0
scoring_time_ms: float = 0.0
compression_time_ms: float = 0.0
formatting_time_ms: float = 0.0
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"start_time": self.start_time.isoformat(),
"end_time": self.end_time.isoformat() if self.end_time else None,
"total_contexts": self.total_contexts,
"selected_contexts": self.selected_contexts,
"excluded_contexts": self.excluded_contexts,
"compressed_contexts": self.compressed_contexts,
"total_tokens": self.total_tokens,
"assembly_time_ms": round(self.assembly_time_ms, 2),
"scoring_time_ms": round(self.scoring_time_ms, 2),
"compression_time_ms": round(self.compression_time_ms, 2),
"formatting_time_ms": round(self.formatting_time_ms, 2),
}
class ContextPipeline:
"""
Context assembly pipeline.
Orchestrates the full workflow of context assembly:
1. Validate and count tokens for all contexts
2. Score contexts based on relevance, recency, and priority
3. Rank and select contexts within budget
4. Compress if needed to fit remaining budget
5. Format for the target model
"""
def __init__(
self,
mcp_manager: "MCPClientManager | None" = None,
settings: ContextSettings | None = None,
calculator: TokenCalculator | None = None,
scorer: CompositeScorer | None = None,
ranker: ContextRanker | None = None,
compressor: ContextCompressor | None = None,
) -> None:
"""
Initialize the context pipeline.
Args:
mcp_manager: MCP client manager for LLM Gateway integration
settings: Context settings
calculator: Token calculator
scorer: Context scorer
ranker: Context ranker
compressor: Context compressor
"""
self._settings = settings or get_context_settings()
self._mcp = mcp_manager
# Initialize components
self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager)
self._scorer = scorer or CompositeScorer(
mcp_manager=mcp_manager, settings=self._settings
)
self._ranker = ranker or ContextRanker(
scorer=self._scorer, calculator=self._calculator
)
self._compressor = compressor or ContextCompressor(calculator=self._calculator)
self._allocator = BudgetAllocator(self._settings)
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
"""Set MCP manager for all components."""
self._mcp = mcp_manager
self._calculator.set_mcp_manager(mcp_manager)
self._scorer.set_mcp_manager(mcp_manager)
async def assemble(
self,
contexts: list[BaseContext],
query: str,
model: str,
max_tokens: int | None = None,
custom_budget: TokenBudget | None = None,
compress: bool = True,
format_output: bool = True,
timeout_ms: int | None = None,
) -> AssembledContext:
"""
Assemble context for an LLM request.
This is the main entry point for context assembly.
Args:
contexts: List of contexts to assemble
query: Query to optimize for
model: Target model name
max_tokens: Maximum total tokens (uses model default if None)
custom_budget: Optional pre-configured budget
compress: Whether to compress oversized contexts
format_output: Whether to format the final output
timeout_ms: Maximum assembly time in milliseconds
Returns:
AssembledContext with optimized content
Raises:
AssemblyTimeoutError: If assembly exceeds timeout
"""
timeout = timeout_ms or self._settings.max_assembly_time_ms
start = time.perf_counter()
metrics = PipelineMetrics(total_contexts=len(contexts))
try:
# Create or use budget
if custom_budget:
budget = custom_budget
elif max_tokens:
budget = self._allocator.create_budget(max_tokens)
else:
budget = self._allocator.create_budget_for_model(model)
# 1. Count tokens for all contexts (with timeout enforcement)
try:
await asyncio.wait_for(
self._ensure_token_counts(contexts, model),
timeout=self._remaining_timeout(start, timeout),
)
except TimeoutError:
elapsed_ms = (time.perf_counter() - start) * 1000
raise AssemblyTimeoutError(
message="Context assembly timed out during token counting",
elapsed_ms=elapsed_ms,
timeout_ms=timeout,
)
# Check timeout (handles edge case where operation finished just at limit)
self._check_timeout(start, timeout, "token counting")
# 2. Score and rank contexts (with timeout enforcement)
scoring_start = time.perf_counter()
try:
ranking_result = await asyncio.wait_for(
self._ranker.rank(
contexts=contexts,
query=query,
budget=budget,
model=model,
),
timeout=self._remaining_timeout(start, timeout),
)
except TimeoutError:
elapsed_ms = (time.perf_counter() - start) * 1000
raise AssemblyTimeoutError(
message="Context assembly timed out during scoring/ranking",
elapsed_ms=elapsed_ms,
timeout_ms=timeout,
)
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
selected_contexts = ranking_result.selected_contexts
metrics.selected_contexts = len(selected_contexts)
metrics.excluded_contexts = len(ranking_result.excluded)
# Check timeout
self._check_timeout(start, timeout, "scoring")
# 3. Compress if needed and enabled (with timeout enforcement)
if compress and self._needs_compression(selected_contexts, budget):
compression_start = time.perf_counter()
try:
selected_contexts = await asyncio.wait_for(
self._compressor.compress_contexts(
selected_contexts, budget, model
),
timeout=self._remaining_timeout(start, timeout),
)
except TimeoutError:
elapsed_ms = (time.perf_counter() - start) * 1000
raise AssemblyTimeoutError(
message="Context assembly timed out during compression",
elapsed_ms=elapsed_ms,
timeout_ms=timeout,
)
metrics.compression_time_ms = (
time.perf_counter() - compression_start
) * 1000
metrics.compressed_contexts = sum(
1 for c in selected_contexts if c.metadata.get("truncated", False)
)
# Check timeout
self._check_timeout(start, timeout, "compression")
# 4. Format output
formatting_start = time.perf_counter()
if format_output:
formatted_content = self._format_contexts(selected_contexts, model)
else:
formatted_content = "\n\n".join(c.content for c in selected_contexts)
metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000
# Calculate final metrics
total_tokens = sum(c.token_count or 0 for c in selected_contexts)
metrics.total_tokens = total_tokens
metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
metrics.end_time = datetime.now(UTC)
return AssembledContext(
content=formatted_content,
total_tokens=total_tokens,
context_count=len(selected_contexts),
assembly_time_ms=metrics.assembly_time_ms,
model=model,
contexts=selected_contexts,
excluded_count=metrics.excluded_contexts,
metadata={
"metrics": metrics.to_dict(),
"query": query,
"budget": budget.to_dict(),
},
)
except AssemblyTimeoutError:
raise
except Exception as e:
logger.error(f"Context assembly failed: {e}", exc_info=True)
raise
async def _ensure_token_counts(
self,
contexts: list[BaseContext],
model: str | None = None,
) -> None:
"""Ensure all contexts have token counts."""
tasks = []
for context in contexts:
if context.token_count is None:
tasks.append(self._count_and_set(context, model))
if tasks:
await asyncio.gather(*tasks)
async def _count_and_set(
self,
context: BaseContext,
model: str | None = None,
) -> None:
"""Count tokens and set on context."""
count = await self._calculator.count_tokens(context.content, model)
context.token_count = count
def _needs_compression(
self,
contexts: list[BaseContext],
budget: TokenBudget,
) -> bool:
"""Check if any contexts exceed their type budget."""
# Group by type and check totals
by_type: dict[ContextType, int] = {}
for context in contexts:
ct = context.get_type()
by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)
for ct, total in by_type.items():
if total > budget.get_allocation(ct):
return True
# Also check if utilization exceeds threshold
return budget.utilization() > self._settings.compression_threshold
def _format_contexts(
self,
contexts: list[BaseContext],
model: str,
) -> str:
"""
Format contexts for the target model.
Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
to format contexts optimally for each model family.
Args:
contexts: Contexts to format
model: Target model name
Returns:
Formatted context string
"""
adapter = get_adapter(model)
return adapter.format(contexts)
def _check_timeout(
self,
start: float,
timeout_ms: int,
phase: str,
) -> None:
"""Check if timeout exceeded and raise if so."""
elapsed_ms = (time.perf_counter() - start) * 1000
if elapsed_ms >= timeout_ms:
raise AssemblyTimeoutError(
message=f"Context assembly timed out during {phase}",
elapsed_ms=elapsed_ms,
timeout_ms=timeout_ms,
)
def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
"""
Calculate remaining timeout in seconds for asyncio.wait_for.
Returns at least a small positive value to avoid immediate timeout
edge cases with wait_for.
Args:
start: Start time from time.perf_counter()
timeout_ms: Total timeout in milliseconds
Returns:
Remaining timeout in seconds (minimum 0.001)
"""
elapsed_ms = (time.perf_counter() - start) * 1000
remaining_ms = timeout_ms - elapsed_ms
# Return at least 1ms to avoid zero/negative timeout edge cases
return max(remaining_ms / 1000.0, 0.001)

View File

@@ -1,14 +0,0 @@
"""
Token Budget Management Module.
Provides token counting and budget allocation.
"""
from .allocator import BudgetAllocator, TokenBudget
from .calculator import TokenCalculator
__all__ = [
"BudgetAllocator",
"TokenBudget",
"TokenCalculator",
]

View File

@@ -1,433 +0,0 @@
"""
Token Budget Allocator for Context Management.
Manages token budget allocation across context types.
"""
from dataclasses import dataclass, field
from typing import Any
from ..config import ContextSettings, get_context_settings
from ..exceptions import BudgetExceededError
from ..types import ContextType
@dataclass
class TokenBudget:
"""
Token budget allocation and tracking.
Tracks allocated tokens per context type and
monitors usage to prevent overflows.
"""
# Total budget
total: int
# Allocated per type
system: int = 0
task: int = 0
knowledge: int = 0
conversation: int = 0
tools: int = 0
response_reserve: int = 0
buffer: int = 0
# Usage tracking
used: dict[str, int] = field(default_factory=dict)
def __post_init__(self) -> None:
"""Initialize usage tracking."""
if not self.used:
self.used = {ct.value: 0 for ct in ContextType}
def get_allocation(self, context_type: ContextType | str) -> int:
"""
Get allocated tokens for a context type.
Args:
context_type: Context type to get allocation for
Returns:
Allocated token count
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
allocation_map = {
"system": self.system,
"task": self.task,
"knowledge": self.knowledge,
"conversation": self.conversation,
"tool": self.tools,
}
return allocation_map.get(context_type, 0)
def get_used(self, context_type: ContextType | str) -> int:
"""
Get used tokens for a context type.
Args:
context_type: Context type to check
Returns:
Used token count
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
return self.used.get(context_type, 0)
def remaining(self, context_type: ContextType | str) -> int:
"""
Get remaining tokens for a context type.
Args:
context_type: Context type to check
Returns:
Remaining token count
"""
allocated = self.get_allocation(context_type)
used = self.get_used(context_type)
return max(0, allocated - used)
def total_remaining(self) -> int:
"""
Get total remaining tokens across all types.
Returns:
Total remaining tokens
"""
total_used = sum(self.used.values())
usable = self.total - self.response_reserve - self.buffer
return max(0, usable - total_used)
def total_used(self) -> int:
"""
Get total used tokens.
Returns:
Total used tokens
"""
return sum(self.used.values())
def can_fit(self, context_type: ContextType | str, tokens: int) -> bool:
"""
Check if tokens fit within budget for a type.
Args:
context_type: Context type to check
tokens: Number of tokens to fit
Returns:
True if tokens fit within remaining budget
"""
return tokens <= self.remaining(context_type)
def allocate(
self,
context_type: ContextType | str,
tokens: int,
force: bool = False,
) -> bool:
"""
Allocate (use) tokens from a context type's budget.
Args:
context_type: Context type to allocate from
tokens: Number of tokens to allocate
force: If True, allow exceeding budget
Returns:
True if allocation succeeded
Raises:
BudgetExceededError: If tokens exceed budget and force=False
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
if not force and not self.can_fit(context_type, tokens):
raise BudgetExceededError(
message=f"Token budget exceeded for {context_type}",
allocated=self.get_allocation(context_type),
requested=self.get_used(context_type) + tokens,
context_type=context_type,
)
self.used[context_type] = self.used.get(context_type, 0) + tokens
return True
def deallocate(
self,
context_type: ContextType | str,
tokens: int,
) -> None:
"""
Deallocate (return) tokens to a context type's budget.
Args:
context_type: Context type to return to
tokens: Number of tokens to return
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
current = self.used.get(context_type, 0)
self.used[context_type] = max(0, current - tokens)
def reset(self) -> None:
"""Reset all usage tracking."""
self.used = {ct.value: 0 for ct in ContextType}
def utilization(self, context_type: ContextType | str | None = None) -> float:
"""
Get budget utilization percentage.
Args:
context_type: Specific type or None for total
Returns:
Utilization as a fraction (0.0 to 1.0+)
"""
if context_type is None:
usable = self.total - self.response_reserve - self.buffer
if usable <= 0:
return 0.0
return self.total_used() / usable
allocated = self.get_allocation(context_type)
if allocated <= 0:
return 0.0
return self.get_used(context_type) / allocated
def to_dict(self) -> dict[str, Any]:
"""Convert budget to dictionary."""
return {
"total": self.total,
"allocations": {
"system": self.system,
"task": self.task,
"knowledge": self.knowledge,
"conversation": self.conversation,
"tools": self.tools,
"response_reserve": self.response_reserve,
"buffer": self.buffer,
},
"used": dict(self.used),
"remaining": {ct.value: self.remaining(ct) for ct in ContextType},
"total_used": self.total_used(),
"total_remaining": self.total_remaining(),
"utilization": round(self.utilization(), 3),
}
class BudgetAllocator:
"""
Budget allocator for context management.
Creates token budgets based on configuration and
model context window sizes.
"""
def __init__(self, settings: ContextSettings | None = None) -> None:
"""
Initialize budget allocator.
Args:
settings: Context settings (uses default if None)
"""
self._settings = settings or get_context_settings()
def create_budget(
self,
total_tokens: int,
custom_allocations: dict[str, float] | None = None,
) -> TokenBudget:
"""
Create a token budget with allocations.
Args:
total_tokens: Total available tokens
custom_allocations: Optional custom allocation percentages
Returns:
TokenBudget with allocations set
"""
# Use custom or default allocations
if custom_allocations:
alloc = custom_allocations
else:
alloc = self._settings.get_budget_allocation()
return TokenBudget(
total=total_tokens,
system=int(total_tokens * alloc.get("system", 0.05)),
task=int(total_tokens * alloc.get("task", 0.10)),
knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
conversation=int(total_tokens * alloc.get("conversation", 0.20)),
tools=int(total_tokens * alloc.get("tools", 0.05)),
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
)
def adjust_budget(
self,
budget: TokenBudget,
context_type: ContextType | str,
adjustment: int,
) -> TokenBudget:
"""
Adjust a specific allocation in a budget.
Takes tokens from buffer and adds to specified type.
Args:
budget: Budget to adjust
context_type: Type to adjust
adjustment: Positive to increase, negative to decrease
Returns:
Adjusted budget
"""
if isinstance(context_type, ContextType):
context_type = context_type.value
# Calculate adjustment (limited by buffer for increases, by current allocation for decreases)
if adjustment > 0:
# Taking from buffer - limited by available buffer
actual_adjustment = min(adjustment, budget.buffer)
budget.buffer -= actual_adjustment
else:
# Returning to buffer - limited by current allocation of target type
current_allocation = budget.get_allocation(context_type)
# Can't return more than current allocation
actual_adjustment = max(adjustment, -current_allocation)
# Add returned tokens back to buffer (adjustment is negative, so subtract)
budget.buffer -= actual_adjustment
# Apply to target type
if context_type == "system":
budget.system = max(0, budget.system + actual_adjustment)
elif context_type == "task":
budget.task = max(0, budget.task + actual_adjustment)
elif context_type == "knowledge":
budget.knowledge = max(0, budget.knowledge + actual_adjustment)
elif context_type == "conversation":
budget.conversation = max(0, budget.conversation + actual_adjustment)
elif context_type == "tool":
budget.tools = max(0, budget.tools + actual_adjustment)
return budget
def rebalance_budget(
self,
budget: TokenBudget,
prioritize: list[ContextType] | None = None,
) -> TokenBudget:
"""
Rebalance budget based on actual usage.
Moves unused allocations to prioritized types.
Args:
budget: Budget to rebalance
prioritize: Types to prioritize (in order)
Returns:
Rebalanced budget
"""
if prioritize is None:
prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]
# Calculate unused tokens per type
unused: dict[str, int] = {}
for ct in ContextType:
remaining = budget.remaining(ct)
if remaining > 0:
unused[ct.value] = remaining
# Calculate total reclaimable (excluding prioritized types)
prioritize_values = {ct.value for ct in prioritize}
reclaimable = sum(
tokens for ct, tokens in unused.items() if ct not in prioritize_values
)
# Redistribute to prioritized types that are near capacity
for ct in prioritize:
utilization = budget.utilization(ct)
if utilization > 0.8: # Near capacity
# Give more tokens from reclaimable pool
bonus = min(reclaimable, budget.get_allocation(ct) // 2)
self.adjust_budget(budget, ct, bonus)
reclaimable -= bonus
if reclaimable <= 0:
break
return budget
def get_model_context_size(self, model: str) -> int:
"""
Get context window size for a model.
Args:
model: Model name
Returns:
Context window size in tokens
"""
# Common model context sizes
context_sizes = {
"claude-3-opus": 200000,
"claude-3-sonnet": 200000,
"claude-3-haiku": 200000,
"claude-3-5-sonnet": 200000,
"claude-3-5-haiku": 200000,
"claude-opus-4": 200000,
"gpt-4-turbo": 128000,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4o": 128000,
"gpt-4o-mini": 128000,
"gpt-3.5-turbo": 16385,
"gemini-1.5-pro": 2000000,
"gemini-1.5-flash": 1000000,
"gemini-2.0-flash": 1000000,
"qwen-plus": 32000,
"qwen-turbo": 8000,
"deepseek-chat": 64000,
"deepseek-reasoner": 64000,
}
# Check exact match first
model_lower = model.lower()
if model_lower in context_sizes:
return context_sizes[model_lower]
# Check prefix match
for model_name, size in context_sizes.items():
if model_lower.startswith(model_name):
return size
# Default fallback
return 8192
def create_budget_for_model(
self,
model: str,
custom_allocations: dict[str, float] | None = None,
) -> TokenBudget:
"""
Create a budget based on model's context window.
Args:
model: Model name
custom_allocations: Optional custom allocation percentages
Returns:
TokenBudget sized for the model
"""
context_size = self.get_model_context_size(model)
return self.create_budget(context_size, custom_allocations)

Some files were not shown because too many files have changed in this diff Show More