Compare commits
20 Commits
dev
...
2104ae38ec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2104ae38ec | ||
|
|
2055320058 | ||
|
|
11da0d57a8 | ||
|
|
acfda1e9a9 | ||
|
|
3c24a8c522 | ||
|
|
ec111f9ce6 | ||
|
|
520a4d60fb | ||
|
|
fcda8f0f96 | ||
|
|
d6db6af964 | ||
|
|
88cf4e0abc | ||
|
|
f138417486 | ||
|
|
de47d9ee43 | ||
|
|
406b25cda0 | ||
|
|
bd702734c2 | ||
|
|
5594655fba | ||
|
|
ebd307cab4 | ||
|
|
6e3cdebbfb | ||
|
|
a6a336b66e | ||
|
|
9901dc7f51 | ||
|
|
ac64d9505e |
2
.github/workflows/README.md
vendored
2
.github/workflows/README.md
vendored
@@ -41,7 +41,7 @@ To enable CI/CD workflows:
|
||||
- Runs on: Push to main/develop, PRs affecting frontend code
|
||||
- Tests: Frontend unit tests (Jest)
|
||||
- Coverage: Uploads to Codecov
|
||||
- Fast: Uses bun cache
|
||||
- Fast: Uses npm cache
|
||||
|
||||
### `e2e-tests.yml`
|
||||
- Runs on: All pushes and PRs
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -187,7 +187,7 @@ coverage.xml
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
backend/.benchmarks
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
40
AGENTS.md
40
AGENTS.md
@@ -13,10 +13,10 @@ uv run uvicorn app.main:app --reload # Start dev server
|
||||
|
||||
# Frontend (Node.js)
|
||||
cd frontend
|
||||
bun install # Install dependencies
|
||||
bun run dev # Start dev server
|
||||
bun run generate:api # Generate API client from OpenAPI
|
||||
bun run test:e2e # Run E2E tests
|
||||
npm install # Install dependencies
|
||||
npm run dev # Start dev server
|
||||
npm run generate:api # Generate API client from OpenAPI
|
||||
npm run test:e2e # Run E2E tests
|
||||
```
|
||||
|
||||
**Access points:**
|
||||
@@ -37,7 +37,7 @@ Default superuser (change in production):
|
||||
│ ├── app/
|
||||
│ │ ├── api/ # API routes (auth, users, organizations, admin)
|
||||
│ │ ├── core/ # Core functionality (auth, config, database)
|
||||
│ │ ├── repositories/ # Repository pattern (database operations)
|
||||
│ │ ├── crud/ # Database CRUD operations
|
||||
│ │ ├── models/ # SQLAlchemy ORM models
|
||||
│ │ ├── schemas/ # Pydantic request/response schemas
|
||||
│ │ ├── services/ # Business logic layer
|
||||
@@ -113,7 +113,7 @@ OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in
|
||||
### Database Pattern
|
||||
- **Async SQLAlchemy 2.0** with PostgreSQL
|
||||
- **Connection pooling**: 20 base connections, 50 max overflow
|
||||
- **Repository base class**: `repositories/base.py` with common operations
|
||||
- **CRUD base class**: `crud/base.py` with common operations
|
||||
- **Migrations**: Alembic with helper script `migrate.py`
|
||||
- `python migrate.py auto "message"` - Generate and apply
|
||||
- `python migrate.py list` - View history
|
||||
@@ -121,7 +121,7 @@ OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in
|
||||
### Frontend State Management
|
||||
- **Zustand stores**: Lightweight state management
|
||||
- **TanStack Query**: API data fetching/caching
|
||||
- **Auto-generated client**: From OpenAPI spec via `bun run generate:api`
|
||||
- **Auto-generated client**: From OpenAPI spec via `npm run generate:api`
|
||||
- **Dependency Injection**: ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly
|
||||
|
||||
### Internationalization (i18n)
|
||||
@@ -165,25 +165,21 @@ Permission dependencies in `api/dependencies/permissions.py`:
|
||||
**Frontend Unit Tests (Jest):**
|
||||
- 97% coverage
|
||||
- Component, hook, and utility testing
|
||||
- Run: `bun run test`
|
||||
- Coverage: `bun run test:coverage`
|
||||
- Run: `npm test`
|
||||
- Coverage: `npm run test:coverage`
|
||||
|
||||
**Frontend E2E Tests (Playwright):**
|
||||
- 56 passing, 1 skipped (zero flaky tests)
|
||||
- Complete user flows (auth, navigation, settings)
|
||||
- Run: `bun run test:e2e`
|
||||
- UI mode: `bun run test:e2e:ui`
|
||||
- Run: `npm run test:e2e`
|
||||
- UI mode: `npm run test:e2e:ui`
|
||||
|
||||
### Development Tooling
|
||||
|
||||
**Backend:**
|
||||
- **uv**: Modern Python package manager (10-100x faster than pip)
|
||||
- **Ruff**: All-in-one linting/formatting (replaces Black, Flake8, isort)
|
||||
- **Pyright**: Static type checking (strict mode)
|
||||
- **pip-audit**: Dependency vulnerability scanning (OSV database)
|
||||
- **detect-secrets**: Hardcoded secrets detection
|
||||
- **pip-licenses**: License compliance checking
|
||||
- **pre-commit**: Git hook framework (Ruff, detect-secrets, standard checks)
|
||||
- **mypy**: Type checking with Pydantic plugin
|
||||
- **Makefile**: `make help` for all commands
|
||||
|
||||
**Frontend:**
|
||||
@@ -222,11 +218,11 @@ NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
|
||||
### Adding a New API Endpoint
|
||||
|
||||
1. **Define schema** in `backend/app/schemas/`
|
||||
2. **Create repository** in `backend/app/repositories/`
|
||||
2. **Create CRUD operations** in `backend/app/crud/`
|
||||
3. **Implement route** in `backend/app/api/routes/`
|
||||
4. **Register router** in `backend/app/api/main.py`
|
||||
5. **Write tests** in `backend/tests/api/`
|
||||
6. **Generate frontend client**: `bun run generate:api`
|
||||
6. **Generate frontend client**: `npm run generate:api`
|
||||
|
||||
### Database Migrations
|
||||
|
||||
@@ -243,7 +239,7 @@ python migrate.py auto "description" # Generate + apply
|
||||
2. **Follow design system** (see `frontend/docs/design-system/`)
|
||||
3. **Use dependency injection** for auth (`useAuth()` not `useAuthStore`)
|
||||
4. **Write tests** in `frontend/tests/` or `__tests__/`
|
||||
5. **Run type check**: `bun run type-check`
|
||||
5. **Run type check**: `npm run type-check`
|
||||
|
||||
## Security Features
|
||||
|
||||
@@ -253,10 +249,6 @@ python migrate.py auto "description" # Generate + apply
|
||||
- **CSRF protection**: Built into FastAPI
|
||||
- **Session revocation**: Database-backed session tracking
|
||||
- **Comprehensive security tests**: JWT algorithm attacks, session hijacking, privilege escalation
|
||||
- **Dependency vulnerability scanning**: `make dep-audit` (pip-audit against OSV database)
|
||||
- **License compliance**: `make license-check` (blocks GPL-3.0/AGPL)
|
||||
- **Secrets detection**: Pre-commit hook blocks hardcoded secrets
|
||||
- **Unified security pipeline**: `make audit` (all security checks), `make check` (quality + security + tests)
|
||||
|
||||
## Docker Deployment
|
||||
|
||||
@@ -289,7 +281,7 @@ docker-compose exec backend python -c "from app.init_db import init_db; import a
|
||||
- Authentication system (JWT with refresh tokens, OAuth/social login)
|
||||
- **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server
|
||||
- Session management (device tracking, revocation)
|
||||
- User management (full lifecycle, password change)
|
||||
- User management (CRUD, password change)
|
||||
- Organization system (multi-tenant with RBAC)
|
||||
- Admin panel (user/org management, bulk operations)
|
||||
- **Internationalization (i18n)** with English and Italian
|
||||
|
||||
97
CLAUDE.md
97
CLAUDE.md
@@ -1,8 +1,71 @@
|
||||
# CLAUDE.md
|
||||
|
||||
Claude Code context for FastAPI + Next.js Full-Stack Template.
|
||||
Claude Code context for **Syndarix** - AI-Powered Software Consulting Agency.
|
||||
|
||||
**See [AGENTS.md](./AGENTS.md) for project context, architecture, and development commands.**
|
||||
**Built on PragmaStack.** See [AGENTS.md](./AGENTS.md) for base template context.
|
||||
|
||||
---
|
||||
|
||||
## Syndarix Project Context
|
||||
|
||||
### Vision
|
||||
Syndarix is an autonomous platform that orchestrates specialized AI agents to deliver complete software solutions with minimal human intervention. It acts as a virtual consulting agency with AI agents playing roles like Product Owner, Architect, Engineers, QA, etc.
|
||||
|
||||
### Repository
|
||||
- **URL:** https://gitea.pragmazest.com/cardosofelipe/syndarix
|
||||
- **Issue Tracker:** Gitea Issues (primary)
|
||||
- **CI/CD:** Gitea Actions
|
||||
|
||||
### Core Concepts
|
||||
|
||||
**Agent Types & Instances:**
|
||||
- Agent Type = Template (base model, failover, expertise, personality)
|
||||
- Agent Instance = Spawned from type, assigned to project
|
||||
- Multiple instances of same type can work together
|
||||
|
||||
**Project Workflow:**
|
||||
1. Requirements discovery with Product Owner agent
|
||||
2. Architecture spike (PO + BA + Architect brainstorm)
|
||||
3. Implementation planning and backlog creation
|
||||
4. Autonomous sprint execution with checkpoints
|
||||
5. Demo and client feedback
|
||||
|
||||
**Autonomy Levels:**
|
||||
- `FULL_CONTROL`: Approve every action
|
||||
- `MILESTONE`: Approve sprint boundaries
|
||||
- `AUTONOMOUS`: Only major decisions
|
||||
|
||||
**MCP-First Architecture:**
|
||||
All integrations via Model Context Protocol servers with explicit scoping:
|
||||
```python
|
||||
# All tools take project_id for scoping
|
||||
search_knowledge(project_id="proj-123", query="auth flow")
|
||||
create_issue(project_id="proj-123", title="Add login")
|
||||
```
|
||||
|
||||
### Syndarix-Specific Directories
|
||||
```
|
||||
docs/
|
||||
├── requirements/ # Requirements documents
|
||||
├── architecture/ # Architecture documentation
|
||||
├── adrs/ # Architecture Decision Records
|
||||
└── spikes/ # Spike research documents
|
||||
```
|
||||
|
||||
### Current Phase
|
||||
**Architecture Spikes** - Validating key decisions before implementation.
|
||||
|
||||
### Key Extensions to Add (from PragmaStack base)
|
||||
- Celery + Redis for agent job queue
|
||||
- WebSocket/SSE for real-time updates
|
||||
- pgvector for RAG knowledge base
|
||||
- MCP server integration layer
|
||||
|
||||
---
|
||||
|
||||
## PragmaStack Development Guidelines
|
||||
|
||||
*The following guidelines are inherited from PragmaStack and remain applicable.*
|
||||
|
||||
## Claude Code-Specific Guidance
|
||||
|
||||
@@ -43,7 +106,7 @@ EOF
|
||||
- Check current state: `python migrate.py current`
|
||||
|
||||
**Frontend API Client Generation:**
|
||||
- Run `bun run generate:api` after backend schema changes
|
||||
- Run `npm run generate:api` after backend schema changes
|
||||
- Client is auto-generated from OpenAPI spec
|
||||
- Located in `frontend/src/lib/api/generated/`
|
||||
- NEVER manually edit generated files
|
||||
@@ -51,16 +114,10 @@ EOF
|
||||
**Testing Commands:**
|
||||
- Backend unit/integration: `IS_TEST=True uv run pytest` (always prefix with `IS_TEST=True`)
|
||||
- Backend E2E (requires Docker): `make test-e2e`
|
||||
- Frontend unit: `bun run test`
|
||||
- Frontend E2E: `bun run test:e2e`
|
||||
- Frontend unit: `npm test`
|
||||
- Frontend E2E: `npm run test:e2e`
|
||||
- Use `make test` or `make test-cov` in backend for convenience
|
||||
|
||||
**Security & Quality Commands (Backend):**
|
||||
- `make validate` — lint + format + type checks
|
||||
- `make audit` — dependency vulnerabilities + license compliance
|
||||
- `make validate-all` — quality + security checks
|
||||
- `make check` — **full pipeline**: quality + security + tests
|
||||
|
||||
**Backend E2E Testing (requires Docker):**
|
||||
- Install deps: `make install-e2e`
|
||||
- Run all E2E tests: `make test-e2e`
|
||||
@@ -148,7 +205,7 @@ async def mock_commit():
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await repo_method(session, obj_in=data)
|
||||
await crud_method(session, obj_in=data)
|
||||
mock_rollback.assert_called_once()
|
||||
```
|
||||
|
||||
@@ -163,18 +220,14 @@ with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
- Never skip security headers in production
|
||||
- Rate limiting is configured in route decorators: `@limiter.limit("10/minute")`
|
||||
- Session revocation is database-backed, not just JWT expiry
|
||||
- Run `make audit` to check for dependency vulnerabilities and license compliance
|
||||
- Run `make check` for the full pipeline: quality + security + tests
|
||||
- Pre-commit hooks enforce Ruff lint/format and detect-secrets on every commit
|
||||
- Setup hooks: `cd backend && uv run pre-commit install`
|
||||
|
||||
### Common Workflows Guidance
|
||||
|
||||
**When Adding a New Feature:**
|
||||
1. Start with backend schema and repository
|
||||
1. Start with backend schema and CRUD
|
||||
2. Implement API route with proper authorization
|
||||
3. Write backend tests (aim for >90% coverage)
|
||||
4. Generate frontend API client: `bun run generate:api`
|
||||
4. Generate frontend API client: `npm run generate:api`
|
||||
5. Implement frontend components
|
||||
6. Write frontend unit tests
|
||||
7. Add E2E tests for critical flows
|
||||
@@ -187,8 +240,8 @@ with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
|
||||
**When Debugging:**
|
||||
- Backend: Check `IS_TEST=True` environment variable is set
|
||||
- Frontend: Run `bun run type-check` first
|
||||
- E2E: Use `bun run test:e2e:debug` for step-by-step debugging
|
||||
- Frontend: Run `npm run type-check` first
|
||||
- E2E: Use `npm run test:e2e:debug` for step-by-step debugging
|
||||
- Check logs: Backend has detailed error logging
|
||||
|
||||
**Demo Mode (Frontend-Only Showcase):**
|
||||
@@ -196,7 +249,7 @@ with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
- Uses MSW (Mock Service Worker) to intercept API calls in browser
|
||||
- Zero backend required - perfect for Vercel deployments
|
||||
- **Fully Automated**: MSW handlers auto-generated from OpenAPI spec
|
||||
- Run `bun run generate:api` → updates both API client AND MSW handlers
|
||||
- Run `npm run generate:api` → updates both API client AND MSW handlers
|
||||
- No manual synchronization needed!
|
||||
- Demo credentials (any password ≥8 chars works):
|
||||
- User: `demo@example.com` / `DemoPass123`
|
||||
@@ -224,7 +277,7 @@ with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill.
|
||||
|
||||
**Potential skill ideas for this project:**
|
||||
- API endpoint generator workflow (schema → repository → route → tests → frontend client)
|
||||
- API endpoint generator workflow (schema → CRUD → route → tests → frontend client)
|
||||
- Component generator with design system compliance
|
||||
- Database migration troubleshooting helper
|
||||
- Test coverage analyzer and improvement suggester
|
||||
|
||||
@@ -91,10 +91,7 @@ Ready to write some code? Awesome!
|
||||
cd backend
|
||||
|
||||
# Install dependencies (uv manages virtual environment automatically)
|
||||
make install-dev
|
||||
|
||||
# Setup pre-commit hooks
|
||||
uv run pre-commit install
|
||||
uv sync
|
||||
|
||||
# Setup environment
|
||||
cp .env.example .env
|
||||
@@ -103,14 +100,8 @@ cp .env.example .env
|
||||
# Run migrations
|
||||
python migrate.py apply
|
||||
|
||||
# Run quality + security checks
|
||||
make validate-all
|
||||
|
||||
# Run tests
|
||||
make test
|
||||
|
||||
# Run full pipeline (quality + security + tests)
|
||||
make check
|
||||
IS_TEST=True uv run pytest
|
||||
|
||||
# Start dev server
|
||||
uvicorn app.main:app --reload
|
||||
@@ -122,20 +113,20 @@ uvicorn app.main:app --reload
|
||||
cd frontend
|
||||
|
||||
# Install dependencies
|
||||
bun install
|
||||
npm install
|
||||
|
||||
# Setup environment
|
||||
cp .env.local.example .env.local
|
||||
|
||||
# Generate API client
|
||||
bun run generate:api
|
||||
npm run generate:api
|
||||
|
||||
# Run tests
|
||||
bun run test
|
||||
bun run test:e2e:ui
|
||||
npm test
|
||||
npm run test:e2e:ui
|
||||
|
||||
# Start dev server
|
||||
bun run dev
|
||||
npm run dev
|
||||
```
|
||||
|
||||
---
|
||||
@@ -204,7 +195,7 @@ export function UserProfile({ userId }: UserProfileProps) {
|
||||
|
||||
### Key Patterns
|
||||
|
||||
- **Backend**: Use repository pattern, keep routes thin, business logic in services
|
||||
- **Backend**: Use CRUD pattern, keep routes thin, business logic in services
|
||||
- **Frontend**: Use React Query for server state, Zustand for client state
|
||||
- **Both**: Handle errors gracefully, log appropriately, write tests
|
||||
|
||||
@@ -325,7 +316,7 @@ Fixed stuff
|
||||
### Before Submitting
|
||||
|
||||
- [ ] Code follows project style guidelines
|
||||
- [ ] `make check` passes (quality + security + tests) in backend
|
||||
- [ ] All tests pass locally
|
||||
- [ ] New tests added for new features
|
||||
- [ ] Documentation updated if needed
|
||||
- [ ] No merge conflicts with `main`
|
||||
|
||||
25
Makefile
25
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy scan-images
|
||||
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
|
||||
|
||||
VERSION ?= latest
|
||||
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
||||
@@ -21,7 +21,6 @@ help:
|
||||
@echo " make prod - Start production stack"
|
||||
@echo " make deploy - Pull and deploy latest images"
|
||||
@echo " make push-images - Build and push images to registry"
|
||||
@echo " make scan-images - Scan production images for CVEs (requires trivy)"
|
||||
@echo " make logs - Follow production container logs"
|
||||
@echo ""
|
||||
@echo "Cleanup:"
|
||||
@@ -90,28 +89,6 @@ push-images:
|
||||
docker push $(REGISTRY)/backend:$(VERSION)
|
||||
docker push $(REGISTRY)/frontend:$(VERSION)
|
||||
|
||||
scan-images:
|
||||
@docker info > /dev/null 2>&1 || (echo "❌ Docker is not running!"; exit 1)
|
||||
@echo "🐳 Building and scanning production images for CVEs..."
|
||||
docker build -t $(REGISTRY)/backend:scan --target production ./backend
|
||||
docker build -t $(REGISTRY)/frontend:scan --target runner ./frontend
|
||||
@echo ""
|
||||
@echo "=== Backend Image Scan ==="
|
||||
@if command -v trivy > /dev/null 2>&1; then \
|
||||
trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
|
||||
else \
|
||||
echo "ℹ️ Trivy not found locally, using Docker to run Trivy..."; \
|
||||
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
|
||||
fi
|
||||
@echo ""
|
||||
@echo "=== Frontend Image Scan ==="
|
||||
@if command -v trivy > /dev/null 2>&1; then \
|
||||
trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
|
||||
else \
|
||||
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
|
||||
fi
|
||||
@echo "✅ No HIGH/CRITICAL CVEs found in production images!"
|
||||
|
||||
# ============================================================================
|
||||
# Cleanup
|
||||
# ============================================================================
|
||||
|
||||
724
README.md
724
README.md
@@ -1,659 +1,175 @@
|
||||
# <img src="frontend/public/logo.svg" alt="PragmaStack" width="32" height="32" style="vertical-align: middle" /> PragmaStack
|
||||
# Syndarix
|
||||
|
||||
> **The Pragmatic Full-Stack Template. Production-ready, security-first, and opinionated.**
|
||||
> **Your AI-Powered Software Consulting Agency**
|
||||
>
|
||||
> An autonomous platform that orchestrates specialized AI agents to deliver complete software solutions with minimal human intervention.
|
||||
|
||||
[](./backend/tests)
|
||||
[](./frontend/tests)
|
||||
[](./frontend/e2e)
|
||||
[](https://gitea.pragmazest.com/cardosofelipe/fast-next-template)
|
||||
[](./LICENSE)
|
||||
[](./CONTRIBUTING.md)
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
## Why PragmaStack?
|
||||
## Vision
|
||||
|
||||
Building a modern full-stack application often leads to "analysis paralysis" or "boilerplate fatigue". You spend weeks setting up authentication, testing, and linting before writing a single line of business logic.
|
||||
Syndarix transforms the software development lifecycle by providing a **virtual consulting team** of AI agents that collaboratively plan, design, implement, test, and deliver complete software solutions.
|
||||
|
||||
**PragmaStack cuts through the noise.**
|
||||
**The Problem:** Even with AI coding assistants, developers spend as much time managing AI as doing the work themselves. Context switching, babysitting, and knowledge fragmentation limit productivity.
|
||||
|
||||
We provide a **pragmatic**, opinionated foundation that prioritizes:
|
||||
- **Speed**: Ship features, not config files.
|
||||
- **Robustness**: Security and testing are not optional.
|
||||
- **Clarity**: Code that is easy to read and maintain.
|
||||
|
||||
Whether you're building a SaaS, an internal tool, or a side project, PragmaStack gives you a solid starting point without the bloat.
|
||||
**The Solution:** A structured, autonomous agency where specialized AI agents handle different roles (Product Owner, Architect, Engineers, QA, etc.) with proper workflows, reviews, and quality gates.
|
||||
|
||||
---
|
||||
|
||||
## ✨ Features
|
||||
## Key Features
|
||||
|
||||
### 🔐 **Authentication & Security**
|
||||
- JWT-based authentication with access + refresh tokens
|
||||
- **OAuth/Social Login** (Google, GitHub) with PKCE support
|
||||
- **OAuth 2.0 Authorization Server** (MCP-ready) for third-party integrations
|
||||
- Session management with device tracking and revocation
|
||||
- Password reset flow (email integration ready)
|
||||
- Secure password hashing (bcrypt)
|
||||
- CSRF protection, rate limiting, and security headers
|
||||
- Comprehensive security tests (JWT algorithm attacks, session hijacking, privilege escalation)
|
||||
### Multi-Agent Orchestration
|
||||
- Configurable agent **types** with base model, failover, expertise, and personality
|
||||
- Spawn multiple **instances** from the same type (e.g., Dave, Ellis, Kate as Software Developers)
|
||||
- Agent-to-agent communication and collaboration
|
||||
- Per-instance customization with domain-specific knowledge
|
||||
|
||||
### 🔌 **OAuth Provider Mode (MCP Integration)**
|
||||
Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-party clients:
|
||||
- **RFC 7636**: Authorization Code Flow with PKCE (S256 only)
|
||||
- **RFC 8414**: Server metadata discovery at `/.well-known/oauth-authorization-server`
|
||||
- **RFC 7662**: Token introspection endpoint
|
||||
- **RFC 7009**: Token revocation endpoint
|
||||
- **JWT access tokens**: Self-contained, configurable lifetime
|
||||
- **Opaque refresh tokens**: Secure rotation, database-backed revocation
|
||||
- **Consent management**: Users can review and revoke app permissions
|
||||
- **Client management**: Admin endpoints for registering OAuth clients
|
||||
- **Scopes**: `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
|
||||
### Complete SDLC Support
|
||||
- **Requirements Discovery** → **Architecture Spike** → **Implementation Planning**
|
||||
- **Sprint Management** with automated ceremonies
|
||||
- **Issue Tracking** with Epic/Story/Task hierarchy
|
||||
- **Git Integration** with proper branch/PR workflows
|
||||
- **CI/CD Pipelines** with automated testing
|
||||
|
||||
### 👥 **Multi-Tenancy & Organizations**
|
||||
- Full organization system with role-based access control (Owner, Admin, Member)
|
||||
- Invite/remove members, manage permissions
|
||||
- Organization-scoped data access
|
||||
- User can belong to multiple organizations
|
||||
### Configurable Autonomy
|
||||
- From `FULL_CONTROL` (approve everything) to `AUTONOMOUS` (only major milestones)
|
||||
- Client can intervene at any point
|
||||
- Transparent progress visibility
|
||||
|
||||
### 🛠️ **Admin Panel**
|
||||
- Complete user management (full lifecycle, activate/deactivate, bulk operations)
|
||||
- Organization management (create, edit, delete, member management)
|
||||
- Session monitoring across all users
|
||||
- Real-time statistics dashboard
|
||||
- Admin-only routes with proper authorization
|
||||
### MCP-First Architecture
|
||||
- All integrations via **Model Context Protocol (MCP)** servers
|
||||
- Unified Knowledge Base with project/agent scoping
|
||||
- Git providers (Gitea, GitHub, GitLab) via MCP
|
||||
- Extensible through custom MCP tools
|
||||
|
||||
### 🎨 **Modern Frontend**
|
||||
- Next.js 16 with App Router and React 19
|
||||
- **PragmaStack Design System** built on shadcn/ui + TailwindCSS
|
||||
- Pre-configured theme with dark mode support (coming soon)
|
||||
- Responsive, accessible components (WCAG AA compliant)
|
||||
- Rich marketing landing page with animated components
|
||||
- Live component showcase and documentation at `/dev`
|
||||
|
||||
### 🌍 **Internationalization (i18n)**
|
||||
- Built-in multi-language support with next-intl v4
|
||||
- Locale-based routing (`/en/*`, `/it/*`)
|
||||
- Seamless language switching with LocaleSwitcher component
|
||||
- SEO-friendly URLs and metadata per locale
|
||||
- Translation files for English and Italian (easily extensible)
|
||||
- Type-safe translations throughout the app
|
||||
|
||||
### 🎯 **Content & UX Features**
|
||||
- **Toast notifications** with Sonner for elegant user feedback
|
||||
- **Smooth animations** powered by Framer Motion
|
||||
- **Markdown rendering** with syntax highlighting (GitHub Flavored Markdown)
|
||||
- **Charts and visualizations** ready with Recharts
|
||||
- **SEO optimization** with dynamic sitemap and robots.txt generation
|
||||
- **Session tracking UI** with device information and revocation controls
|
||||
|
||||
### 🧪 **Comprehensive Testing**
|
||||
- **Backend Testing**: ~97% unit test coverage
|
||||
- Unit, integration, and security tests
|
||||
- Async database testing with SQLAlchemy
|
||||
- API endpoint testing with fixtures
|
||||
- Security vulnerability tests (JWT attacks, session hijacking, privilege escalation)
|
||||
- **Frontend Unit Tests**: ~97% coverage with Jest
|
||||
- Component testing
|
||||
- Hook testing
|
||||
- Utility function testing
|
||||
- **End-to-End Tests**: Playwright with zero flaky tests
|
||||
- Complete user flows (auth, navigation, settings)
|
||||
- Parallel execution for speed
|
||||
- Visual regression testing ready
|
||||
|
||||
### 📚 **Developer Experience**
|
||||
- Auto-generated TypeScript API client from OpenAPI spec
|
||||
- Interactive API documentation (Swagger + ReDoc)
|
||||
- Database migrations with Alembic helper script
|
||||
- Hot reload in development for both frontend and backend
|
||||
- Comprehensive code documentation and design system docs
|
||||
- Live component playground at `/dev` with code examples
|
||||
- Docker support for easy deployment
|
||||
- VSCode workspace settings included
|
||||
|
||||
### 📊 **Ready for Production**
|
||||
- Docker + docker-compose setup
|
||||
- Environment-based configuration
|
||||
- Database connection pooling
|
||||
- Error handling and logging
|
||||
- Health check endpoints
|
||||
- Production security headers
|
||||
- Rate limiting on sensitive endpoints
|
||||
- SEO optimization with dynamic sitemaps and robots.txt
|
||||
- Multi-language SEO with locale-specific metadata
|
||||
- Performance monitoring and bundle analysis
|
||||
### Project Complexity Wizard
|
||||
- **Script** → Minimal process, no repo needed
|
||||
- **Simple** → Single sprint, basic backlog
|
||||
- **Medium/Complex** → Full AGILE workflow with multiple sprints
|
||||
|
||||
---
|
||||
|
||||
## 📸 Screenshots
|
||||
## Technology Stack
|
||||
|
||||
<details>
|
||||
<summary>Click to view screenshots</summary>
|
||||
Built on [PragmaStack](https://gitea.pragmazest.com/cardosofelipe/fast-next-template):
|
||||
|
||||
### Landing Page
|
||||

|
||||
| Component | Technology |
|
||||
|-----------|------------|
|
||||
| Backend | FastAPI 0.115+ (Python 3.11+) |
|
||||
| Frontend | Next.js 16 (React 19) |
|
||||
| Database | PostgreSQL 15+ with pgvector |
|
||||
| ORM | SQLAlchemy 2.0 |
|
||||
| State Management | Zustand + TanStack Query |
|
||||
| UI | shadcn/ui + Tailwind 4 |
|
||||
| Auth | JWT dual-token + OAuth 2.0 |
|
||||
| Testing | pytest + Jest + Playwright |
|
||||
|
||||
|
||||
|
||||
### Authentication
|
||||

|
||||
|
||||
|
||||
|
||||
### Admin Dashboard
|
||||

|
||||
|
||||
|
||||
|
||||
### Design System
|
||||

|
||||
|
||||
</details>
|
||||
### Syndarix Extensions
|
||||
| Component | Technology |
|
||||
|-----------|------------|
|
||||
| Task Queue | Celery + Redis |
|
||||
| Real-time | FastAPI WebSocket / SSE |
|
||||
| Vector DB | pgvector (PostgreSQL extension) |
|
||||
| MCP SDK | Anthropic MCP SDK |
|
||||
|
||||
---
|
||||
|
||||
## 🎭 Demo Mode
|
||||
## Project Status
|
||||
|
||||
**Try the frontend without a backend!** Perfect for:
|
||||
- **Free deployment** on Vercel (no backend costs)
|
||||
- **Portfolio showcasing** with live demos
|
||||
- **Client presentations** without infrastructure setup
|
||||
**Phase:** Architecture & Planning
|
||||
|
||||
See [docs/requirements/](./docs/requirements/) for the comprehensive requirements document.
|
||||
|
||||
### Current Milestones
|
||||
- [x] Fork PragmaStack as foundation
|
||||
- [x] Create requirements document
|
||||
- [ ] Execute architecture spikes
|
||||
- [ ] Create ADRs for key decisions
|
||||
- [ ] Begin MVP implementation
|
||||
|
||||
---
|
||||
|
||||
## Documentation
|
||||
|
||||
- [Requirements Document](./docs/requirements/SYNDARIX_REQUIREMENTS.md)
|
||||
- [Architecture Decisions](./docs/adrs/) (coming soon)
|
||||
- [Spike Research](./docs/spikes/) (coming soon)
|
||||
- [Architecture Overview](./docs/architecture/) (coming soon)
|
||||
|
||||
---
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
- Docker & Docker Compose
|
||||
- Node.js 20+
|
||||
- Python 3.11+
|
||||
- PostgreSQL 15+ (or use Docker)
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
echo "NEXT_PUBLIC_DEMO_MODE=true" > .env.local
|
||||
bun run dev
|
||||
```
|
||||
|
||||
**Demo Credentials:**
|
||||
- Regular user: `demo@example.com` / `DemoPass123`
|
||||
- Admin user: `admin@example.com` / `AdminPass123`
|
||||
|
||||
Demo mode uses [Mock Service Worker (MSW)](https://mswjs.io/) to intercept API calls in the browser. Your code remains unchanged - the same components work with both real and mocked backends.
|
||||
|
||||
**Key Features:**
|
||||
- ✅ Zero backend required
|
||||
- ✅ All features functional (auth, admin, stats)
|
||||
- ✅ Realistic network delays and errors
|
||||
- ✅ Does NOT interfere with tests (97%+ coverage maintained)
|
||||
- ✅ One-line toggle: `NEXT_PUBLIC_DEMO_MODE=true`
|
||||
|
||||
📖 **[Complete Demo Mode Documentation](./frontend/docs/DEMO_MODE.md)**
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Tech Stack
|
||||
|
||||
### Backend
|
||||
- **[FastAPI](https://fastapi.tiangolo.com/)** - Modern async Python web framework
|
||||
- **[SQLAlchemy 2.0](https://www.sqlalchemy.org/)** - Powerful ORM with async support
|
||||
- **[PostgreSQL](https://www.postgresql.org/)** - Robust relational database
|
||||
- **[Alembic](https://alembic.sqlalchemy.org/)** - Database migrations
|
||||
- **[Pydantic v2](https://docs.pydantic.dev/)** - Data validation with type hints
|
||||
- **[pytest](https://pytest.org/)** - Testing framework with async support
|
||||
|
||||
### Frontend
|
||||
- **[Next.js 16](https://nextjs.org/)** - React framework with App Router
|
||||
- **[React 19](https://react.dev/)** - UI library
|
||||
- **[TypeScript](https://www.typescriptlang.org/)** - Type-safe JavaScript
|
||||
- **[TailwindCSS](https://tailwindcss.com/)** - Utility-first CSS framework
|
||||
- **[shadcn/ui](https://ui.shadcn.com/)** - Beautiful, accessible component library
|
||||
- **[next-intl](https://next-intl.dev/)** - Internationalization (i18n) with type safety
|
||||
- **[TanStack Query](https://tanstack.com/query)** - Powerful data fetching/caching
|
||||
- **[Zustand](https://zustand-demo.pmnd.rs/)** - Lightweight state management
|
||||
- **[Framer Motion](https://www.framer.com/motion/)** - Production-ready animation library
|
||||
- **[Sonner](https://sonner.emilkowal.ski/)** - Beautiful toast notifications
|
||||
- **[Recharts](https://recharts.org/)** - Composable charting library
|
||||
- **[React Markdown](https://github.com/remarkjs/react-markdown)** - Markdown rendering with GFM support
|
||||
- **[Playwright](https://playwright.dev/)** - End-to-end testing
|
||||
|
||||
### DevOps
|
||||
- **[Docker](https://www.docker.com/)** - Containerization
|
||||
- **[docker-compose](https://docs.docker.com/compose/)** - Multi-container orchestration
|
||||
- **GitHub Actions** (coming soon) - CI/CD pipelines
|
||||
|
||||
---
|
||||
|
||||
## 📋 Prerequisites
|
||||
|
||||
- **Docker & Docker Compose** (recommended) - [Install Docker](https://docs.docker.com/get-docker/)
|
||||
- **OR manually:**
|
||||
- Python 3.12+
|
||||
- Node.js 18+ (Node 20+ recommended)
|
||||
- PostgreSQL 15+
|
||||
|
||||
---
|
||||
|
||||
## 🏃 Quick Start (Docker)
|
||||
|
||||
The fastest way to get started is with Docker:
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/cardosofelipe/pragma-stack.git
|
||||
cd fast-next-template
|
||||
git clone https://gitea.pragmazest.com/cardosofelipe/syndarix.git
|
||||
cd syndarix
|
||||
|
||||
# Copy environment file
|
||||
# Copy environment template
|
||||
cp .env.template .env
|
||||
|
||||
# Start all services (backend, frontend, database)
|
||||
docker-compose up
|
||||
# Start development environment
|
||||
docker-compose -f docker-compose.dev.yml up -d
|
||||
|
||||
# In another terminal, run database migrations
|
||||
docker-compose exec backend alembic upgrade head
|
||||
# Run database migrations
|
||||
make migrate
|
||||
|
||||
# Create first superuser (optional)
|
||||
docker-compose exec backend python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
|
||||
```
|
||||
|
||||
**That's it! 🎉**
|
||||
|
||||
- Frontend: http://localhost:3000
|
||||
- Backend API: http://localhost:8000
|
||||
- API Docs: http://localhost:8000/docs
|
||||
|
||||
Default superuser credentials:
|
||||
- Email: `admin@example.com`
|
||||
- Password: `admin123`
|
||||
|
||||
**⚠️ Change these immediately in production!**
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ Manual Setup (Development)
|
||||
|
||||
### Backend Setup
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# Create virtual environment
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Setup environment
|
||||
cp .env.example .env
|
||||
# Edit .env with your database credentials
|
||||
|
||||
# Run migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Initialize database with first superuser
|
||||
python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
|
||||
|
||||
# Start development server
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### Frontend Setup
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
|
||||
# Install dependencies
|
||||
bun install
|
||||
|
||||
# Setup environment
|
||||
cp .env.local.example .env.local
|
||||
# Edit .env.local with your backend URL
|
||||
|
||||
# Generate API client
|
||||
bun run generate:api
|
||||
|
||||
# Start development server
|
||||
bun run dev
|
||||
```
|
||||
|
||||
Visit http://localhost:3000 to see your app!
|
||||
|
||||
---
|
||||
|
||||
## 📂 Project Structure
|
||||
|
||||
```
|
||||
├── backend/ # FastAPI backend
|
||||
│ ├── app/
|
||||
│ │ ├── api/ # API routes and dependencies
|
||||
│ │ ├── core/ # Core functionality (auth, config, database)
|
||||
│ │ ├── repositories/ # Repository pattern (database operations)
|
||||
│ │ ├── models/ # SQLAlchemy models
|
||||
│ │ ├── schemas/ # Pydantic schemas
|
||||
│ │ ├── services/ # Business logic
|
||||
│ │ └── utils/ # Utilities
|
||||
│ ├── tests/ # Backend tests (97% coverage)
|
||||
│ ├── alembic/ # Database migrations
|
||||
│ └── docs/ # Backend documentation
|
||||
│
|
||||
├── frontend/ # Next.js frontend
|
||||
│ ├── src/
|
||||
│ │ ├── app/ # Next.js App Router pages
|
||||
│ │ ├── components/ # React components
|
||||
│ │ ├── lib/ # Libraries and utilities
|
||||
│ │ │ ├── api/ # API client (auto-generated)
|
||||
│ │ │ └── stores/ # Zustand stores
|
||||
│ │ └── hooks/ # Custom React hooks
|
||||
│ ├── e2e/ # Playwright E2E tests
|
||||
│ ├── tests/ # Unit tests (Jest)
|
||||
│ └── docs/ # Frontend documentation
|
||||
│ └── design-system/ # Comprehensive design system docs
|
||||
│
|
||||
├── docker-compose.yml # Docker orchestration
|
||||
├── docker-compose.dev.yml # Development with hot reload
|
||||
└── README.md # You are here!
|
||||
# Start the development servers
|
||||
make dev
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🧪 Testing
|
||||
## Architecture Overview
|
||||
|
||||
This template takes testing seriously with comprehensive coverage across all layers:
|
||||
|
||||
### Backend Unit & Integration Tests
|
||||
|
||||
**High coverage (~97%)** across all critical paths including security-focused tests.
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# Run all tests
|
||||
IS_TEST=True pytest
|
||||
|
||||
# Run with coverage report
|
||||
IS_TEST=True pytest --cov=app --cov-report=term-missing
|
||||
|
||||
# Run specific test file
|
||||
IS_TEST=True pytest tests/api/test_auth.py -v
|
||||
|
||||
# Generate HTML coverage report
|
||||
IS_TEST=True pytest --cov=app --cov-report=html
|
||||
open htmlcov/index.html
|
||||
```
|
||||
|
||||
**Test types:**
|
||||
- **Unit tests**: Repository operations, utilities, business logic
|
||||
- **Integration tests**: API endpoints with database
|
||||
- **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation
|
||||
- **Error handling tests**: Database failures, validation errors
|
||||
|
||||
### Frontend Unit Tests
|
||||
|
||||
**High coverage (~97%)** with Jest and React Testing Library.
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
|
||||
# Run unit tests
|
||||
bun run test
|
||||
|
||||
# Run with coverage
|
||||
bun run test:coverage
|
||||
|
||||
# Watch mode
|
||||
bun run test:watch
|
||||
```
|
||||
|
||||
**Test types:**
|
||||
- Component rendering and interactions
|
||||
- Custom hooks behavior
|
||||
- State management
|
||||
- Utility functions
|
||||
- API integration mocks
|
||||
|
||||
### End-to-End Tests
|
||||
|
||||
**Zero flaky tests** with Playwright covering complete user journeys.
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
|
||||
# Run E2E tests
|
||||
bun run test:e2e
|
||||
|
||||
# Run E2E tests in UI mode (recommended for development)
|
||||
bun run test:e2e:ui
|
||||
|
||||
# Run specific test file
|
||||
npx playwright test auth-login.spec.ts
|
||||
|
||||
# Generate test report
|
||||
npx playwright show-report
|
||||
```
|
||||
|
||||
**Test coverage:**
|
||||
- Complete authentication flows
|
||||
- Navigation and routing
|
||||
- Form submissions and validation
|
||||
- Settings and profile management
|
||||
- Session management
|
||||
- Admin panel workflows (in progress)
|
||||
|
||||
---
|
||||
|
||||
## 🤖 AI-Friendly Documentation
|
||||
|
||||
This project includes comprehensive documentation designed for AI coding assistants:
|
||||
|
||||
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI assistant context for PragmaStack
|
||||
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance
|
||||
|
||||
These files provide AI assistants with the **PragmaStack** architecture, patterns, and best practices.
|
||||
|
||||
---
|
||||
|
||||
## 🗄️ Database Migrations
|
||||
|
||||
The template uses Alembic for database migrations:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# Generate migration from model changes
|
||||
python migrate.py generate "description of changes"
|
||||
|
||||
# Apply migrations
|
||||
python migrate.py apply
|
||||
|
||||
# Or do both in one command
|
||||
python migrate.py auto "description"
|
||||
|
||||
# View migration history
|
||||
python migrate.py list
|
||||
|
||||
# Check current revision
|
||||
python migrate.py current
|
||||
+====================================================================+
|
||||
| SYNDARIX CORE |
|
||||
+====================================================================+
|
||||
| +------------------+ +------------------+ +------------------+ |
|
||||
| | Agent Orchestrator| | Project Manager | | Workflow Engine | |
|
||||
| +------------------+ +------------------+ +------------------+ |
|
||||
+====================================================================+
|
||||
|
|
||||
v
|
||||
+====================================================================+
|
||||
| MCP ORCHESTRATION LAYER |
|
||||
| All integrations via unified MCP servers with project scoping |
|
||||
+====================================================================+
|
||||
|
|
||||
+------------------------+------------------------+
|
||||
| | |
|
||||
+----v----+ +----v----+ +----v----+ +----v----+ +----v----+
|
||||
| LLM | | Git | |Knowledge| | File | | Code |
|
||||
| Providers| | MCP | |Base MCP | |Sys. MCP | |Analysis |
|
||||
+---------+ +---------+ +---------+ +---------+ +---------+
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📖 Documentation
|
||||
## Contributing
|
||||
|
||||
### AI Assistant Documentation
|
||||
|
||||
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI coding assistant context
|
||||
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance and preferences
|
||||
|
||||
### Backend Documentation
|
||||
|
||||
- **[ARCHITECTURE.md](./backend/docs/ARCHITECTURE.md)** - System architecture and design patterns
|
||||
- **[CODING_STANDARDS.md](./backend/docs/CODING_STANDARDS.md)** - Code quality standards
|
||||
- **[COMMON_PITFALLS.md](./backend/docs/COMMON_PITFALLS.md)** - Common mistakes to avoid
|
||||
- **[FEATURE_EXAMPLE.md](./backend/docs/FEATURE_EXAMPLE.md)** - Step-by-step feature guide
|
||||
|
||||
### Frontend Documentation
|
||||
|
||||
- **[PragmaStack Design System](./frontend/docs/design-system/)** - Complete design system guide
|
||||
- Quick start, foundations (colors, typography, spacing)
|
||||
- Component library guide
|
||||
- Layout patterns, spacing philosophy
|
||||
- Forms, accessibility, AI guidelines
|
||||
- **[E2E Testing Guide](./frontend/e2e/README.md)** - E2E testing setup and best practices
|
||||
|
||||
### API Documentation
|
||||
|
||||
When the backend is running:
|
||||
- **Swagger UI**: http://localhost:8000/docs
|
||||
- **ReDoc**: http://localhost:8000/redoc
|
||||
- **OpenAPI JSON**: http://localhost:8000/api/v1/openapi.json
|
||||
See [CONTRIBUTING.md](./CONTRIBUTING.md) for guidelines.
|
||||
|
||||
---
|
||||
|
||||
## 🚢 Deployment
|
||||
## License
|
||||
|
||||
### Docker Production Deployment
|
||||
|
||||
```bash
|
||||
# Build and start all services
|
||||
docker-compose up -d
|
||||
|
||||
# Run migrations
|
||||
docker-compose exec backend alembic upgrade head
|
||||
|
||||
# View logs
|
||||
docker-compose logs -f
|
||||
|
||||
# Stop services
|
||||
docker-compose down
|
||||
```
|
||||
|
||||
### Production Checklist
|
||||
|
||||
- [ ] Change default superuser credentials
|
||||
- [ ] Set strong `SECRET_KEY` in backend `.env`
|
||||
- [ ] Configure production database (PostgreSQL)
|
||||
- [ ] Set `ENVIRONMENT=production` in backend
|
||||
- [ ] Configure CORS origins for your domain
|
||||
- [ ] Setup SSL/TLS certificates
|
||||
- [ ] Configure email service for password resets
|
||||
- [ ] Setup monitoring and logging
|
||||
- [ ] Configure backup strategy
|
||||
- [ ] Review and adjust rate limits
|
||||
- [ ] Test security headers
|
||||
MIT License - see [LICENSE](./LICENSE) for details.
|
||||
|
||||
---
|
||||
|
||||
## 🛣️ Roadmap & Status
|
||||
## Acknowledgments
|
||||
|
||||
### ✅ Completed
|
||||
- [x] Authentication system (JWT, refresh tokens, session management, OAuth)
|
||||
- [x] User management (full lifecycle, profile, password change)
|
||||
- [x] Organization system with RBAC (Owner, Admin, Member)
|
||||
- [x] Admin panel (users, organizations, sessions, statistics)
|
||||
- [x] **Internationalization (i18n)** with next-intl (English + Italian)
|
||||
- [x] Backend testing infrastructure (~97% coverage)
|
||||
- [x] Frontend unit testing infrastructure (~97% coverage)
|
||||
- [x] Frontend E2E testing (Playwright, zero flaky tests)
|
||||
- [x] Design system documentation
|
||||
- [x] **Marketing landing page** with animated components
|
||||
- [x] **`/dev` documentation portal** with live component examples
|
||||
- [x] **Toast notifications** system (Sonner)
|
||||
- [x] **Charts and visualizations** (Recharts)
|
||||
- [x] **Animation system** (Framer Motion)
|
||||
- [x] **Markdown rendering** with syntax highlighting
|
||||
- [x] **SEO optimization** (sitemap, robots.txt, locale-aware metadata)
|
||||
- [x] Database migrations with helper script
|
||||
- [x] Docker deployment
|
||||
- [x] API documentation (OpenAPI/Swagger)
|
||||
|
||||
### 🚧 In Progress
|
||||
- [ ] Email integration (templates ready, SMTP pending)
|
||||
|
||||
### 🔮 Planned
|
||||
- [ ] GitHub Actions CI/CD pipelines
|
||||
- [ ] Dynamic test coverage badges from CI
|
||||
- [ ] E2E test coverage reporting
|
||||
- [ ] OAuth token encryption at rest (security hardening)
|
||||
- [ ] Additional languages (Spanish, French, German, etc.)
|
||||
- [ ] SSO/SAML authentication
|
||||
- [ ] Real-time notifications with WebSockets
|
||||
- [ ] Webhook system
|
||||
- [ ] File upload/storage (S3-compatible)
|
||||
- [ ] Audit logging system
|
||||
- [ ] API versioning example
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
Contributions are welcome! Whether you're fixing bugs, improving documentation, or proposing new features, we'd love your help.
|
||||
|
||||
### How to Contribute
|
||||
|
||||
1. **Fork the repository**
|
||||
2. **Create a feature branch** (`git checkout -b feature/amazing-feature`)
|
||||
3. **Make your changes**
|
||||
- Follow existing code style
|
||||
- Add tests for new features
|
||||
- Update documentation as needed
|
||||
4. **Run tests** to ensure everything works
|
||||
5. **Commit your changes** (`git commit -m 'Add amazing feature'`)
|
||||
6. **Push to your branch** (`git push origin feature/amazing-feature`)
|
||||
7. **Open a Pull Request**
|
||||
|
||||
### Development Guidelines
|
||||
|
||||
- Write tests for new features (aim for >90% coverage)
|
||||
- Follow the existing architecture patterns
|
||||
- Update documentation when adding features
|
||||
- Keep commits atomic and well-described
|
||||
- Be respectful and constructive in discussions
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Found a bug? Have a suggestion? [Open an issue](https://github.com/cardosofelipe/pragma-stack/issues)!
|
||||
|
||||
Please include:
|
||||
- Clear description of the issue/suggestion
|
||||
- Steps to reproduce (for bugs)
|
||||
- Expected vs. actual behavior
|
||||
- Environment details (OS, Python/Node version, etc.)
|
||||
|
||||
---
|
||||
|
||||
## 📄 License
|
||||
|
||||
This project is licensed under the **MIT License** - see the [LICENSE](./LICENSE) file for details.
|
||||
|
||||
**TL;DR**: You can use this template for any purpose, commercial or non-commercial. Attribution is appreciated but not required!
|
||||
|
||||
---
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
This template is built on the shoulders of giants:
|
||||
|
||||
- [FastAPI](https://fastapi.tiangolo.com/) by Sebastián Ramírez
|
||||
- [Next.js](https://nextjs.org/) by Vercel
|
||||
- [shadcn/ui](https://ui.shadcn.com/) by shadcn
|
||||
- [TanStack Query](https://tanstack.com/query) by Tanner Linsley
|
||||
- [Playwright](https://playwright.dev/) by Microsoft
|
||||
- And countless other open-source projects that make modern development possible
|
||||
|
||||
---
|
||||
|
||||
## 💬 Questions?
|
||||
|
||||
- **Documentation**: Check the `/docs` folders in backend and frontend
|
||||
- **Issues**: [GitHub Issues](https://github.com/cardosofelipe/pragma-stack/issues)
|
||||
- **Discussions**: [GitHub Discussions](https://github.com/cardosofelipe/pragma-stack/discussions)
|
||||
|
||||
---
|
||||
|
||||
## ⭐ Star This Repo
|
||||
|
||||
If this template saves you time, consider giving it a star! It helps others discover the project and motivates continued development.
|
||||
|
||||
**Happy coding! 🚀**
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
Made with ❤️ by a developer who got tired of rebuilding the same boilerplate
|
||||
</div>
|
||||
- Built on [PragmaStack](https://gitea.pragmazest.com/cardosofelipe/fast-next-template)
|
||||
- Powered by Claude and the Anthropic API
|
||||
|
||||
@@ -11,7 +11,7 @@ omit =
|
||||
app/utils/auth_test_utils.py
|
||||
|
||||
# Async implementations not yet in use
|
||||
app/repositories/base_async.py
|
||||
app/crud/base_async.py
|
||||
app/core/database_async.py
|
||||
|
||||
# CLI scripts - run manually, not tested
|
||||
@@ -23,7 +23,7 @@ omit =
|
||||
app/api/routes/__init__.py
|
||||
app/api/dependencies/__init__.py
|
||||
app/core/__init__.py
|
||||
app/repositories/__init__.py
|
||||
app/crud/__init__.py
|
||||
app/models/__init__.py
|
||||
app/schemas/__init__.py
|
||||
app/services/__init__.py
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
# Pre-commit hooks for backend quality and security checks.
|
||||
#
|
||||
# Install:
|
||||
# cd backend && uv run pre-commit install
|
||||
#
|
||||
# Run manually on all files:
|
||||
# cd backend && uv run pre-commit run --all-files
|
||||
#
|
||||
# Skip hooks temporarily:
|
||||
# git commit --no-verify
|
||||
#
|
||||
repos:
|
||||
# ── Code Quality ──────────────────────────────────────────────────────────
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
- id: ruff-format
|
||||
|
||||
# ── General File Hygiene ──────────────────────────────────────────────────
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-toml
|
||||
- id: check-merge-conflict
|
||||
- id: check-added-large-files
|
||||
args: [--maxkb=500]
|
||||
- id: debug-statements
|
||||
|
||||
# ── Security ──────────────────────────────────────────────────────────────
|
||||
- repo: https://github.com/Yelp/detect-secrets
|
||||
rev: v1.5.0
|
||||
hooks:
|
||||
- id: detect-secrets
|
||||
args: ['--baseline', '.secrets.baseline']
|
||||
exclude: |
|
||||
(?x)^(
|
||||
.*\.lock$|
|
||||
.*\.svg$
|
||||
)$
|
||||
File diff suppressed because it is too large
Load Diff
@@ -33,11 +33,11 @@ RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
|
||||
# Production stage — Alpine eliminates glibc CVEs (e.g. CVE-2026-0861)
|
||||
FROM python:3.12-alpine AS production
|
||||
# Production stage
|
||||
FROM python:3.12-slim AS production
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -S appuser && adduser -S -G appuser appuser
|
||||
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
||||
|
||||
WORKDIR /app
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
@@ -48,18 +48,18 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
UV_NO_CACHE=1
|
||||
|
||||
# Install system dependencies and uv
|
||||
RUN apk add --no-cache postgresql-client curl ca-certificates && \
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends postgresql-client curl ca-certificates && \
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
||||
mv /root/.local/bin/uv* /usr/local/bin/
|
||||
mv /root/.local/bin/uv* /usr/local/bin/ && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
# Install build dependencies, compile Python packages, then remove build deps
|
||||
RUN apk add --no-cache --virtual .build-deps \
|
||||
gcc g++ musl-dev python3-dev linux-headers libffi-dev openssl-dev && \
|
||||
uv sync --frozen --no-dev && \
|
||||
apk del .build-deps
|
||||
# Install only production dependencies using uv (no dev dependencies)
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all dep-audit license-check audit validate-all check benchmark benchmark-check benchmark-save scan-image test-api-security
|
||||
|
||||
# Prevent a stale VIRTUAL_ENV in the caller's shell from confusing uv
|
||||
unexport VIRTUAL_ENV
|
||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@@ -17,21 +14,8 @@ help:
|
||||
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||
@echo " make format - Format code with Ruff"
|
||||
@echo " make format-check - Check if code is formatted"
|
||||
@echo " make type-check - Run pyright type checking"
|
||||
@echo " make validate - Run all checks (lint + format + types + schema fuzz)"
|
||||
@echo ""
|
||||
@echo "Performance:"
|
||||
@echo " make benchmark - Run performance benchmarks"
|
||||
@echo " make benchmark-save - Run benchmarks and save as baseline"
|
||||
@echo " make benchmark-check - Run benchmarks and compare against baseline"
|
||||
@echo ""
|
||||
@echo "Security & Audit:"
|
||||
@echo " make dep-audit - Scan dependencies for known vulnerabilities"
|
||||
@echo " make license-check - Check dependency license compliance"
|
||||
@echo " make audit - Run all security audits (deps + licenses)"
|
||||
@echo " make scan-image - Scan Docker image for CVEs (requires trivy)"
|
||||
@echo " make validate-all - Run all quality + security checks"
|
||||
@echo " make check - Full pipeline: quality + security + tests"
|
||||
@echo " make type-check - Run mypy type checking"
|
||||
@echo " make validate - Run all checks (lint + format + types)"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@echo " make test - Run pytest (unit/integration, SQLite)"
|
||||
@@ -40,7 +24,6 @@ help:
|
||||
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
|
||||
@echo " make test-all - Run all tests (unit + E2E)"
|
||||
@echo " make check-docker - Check if Docker is available"
|
||||
@echo " make check - Full pipeline: quality + security + tests"
|
||||
@echo ""
|
||||
@echo "Cleanup:"
|
||||
@echo " make clean - Remove cache and build artifacts"
|
||||
@@ -80,52 +63,12 @@ format-check:
|
||||
@uv run ruff format --check app/ tests/
|
||||
|
||||
type-check:
|
||||
@echo "🔎 Running pyright type checking..."
|
||||
@uv run pyright app/
|
||||
@echo "🔎 Running mypy type checking..."
|
||||
@uv run mypy app/
|
||||
|
||||
validate: lint format-check type-check test-api-security
|
||||
validate: lint format-check type-check
|
||||
@echo "✅ All quality checks passed!"
|
||||
|
||||
# API Security Testing (Schemathesis property-based fuzzing)
|
||||
test-api-security: check-docker
|
||||
@echo "🔐 Running Schemathesis API security fuzzing..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
|
||||
@echo "✅ API schema security tests passed!"
|
||||
|
||||
# ============================================================================
|
||||
# Security & Audit
|
||||
# ============================================================================
|
||||
|
||||
dep-audit:
|
||||
@echo "🔒 Scanning dependencies for known vulnerabilities..."
|
||||
@uv run pip-audit --desc --skip-editable
|
||||
@echo "✅ No known vulnerabilities found!"
|
||||
|
||||
license-check:
|
||||
@echo "📜 Checking dependency license compliance..."
|
||||
@uv run pip-licenses --fail-on="GPL-3.0-or-later;AGPL-3.0-or-later" --format=plain > /dev/null
|
||||
@echo "✅ All dependency licenses are compliant!"
|
||||
|
||||
audit: dep-audit license-check
|
||||
@echo "✅ All security audits passed!"
|
||||
|
||||
scan-image: check-docker
|
||||
@echo "🐳 Scanning Docker image for OS-level CVEs with Trivy..."
|
||||
@docker build -t pragma-backend:scan -q --target production .
|
||||
@if command -v trivy > /dev/null 2>&1; then \
|
||||
trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
|
||||
else \
|
||||
echo "ℹ️ Trivy not found locally, using Docker to run Trivy..."; \
|
||||
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
|
||||
fi
|
||||
@echo "✅ No HIGH/CRITICAL CVEs found in Docker image!"
|
||||
|
||||
validate-all: validate audit
|
||||
@echo "✅ All quality + security checks passed!"
|
||||
|
||||
check: validate-all test
|
||||
@echo "✅ Full validation pipeline complete!"
|
||||
|
||||
# ============================================================================
|
||||
# Testing
|
||||
# ============================================================================
|
||||
@@ -171,31 +114,6 @@ test-e2e-schema: check-docker
|
||||
@echo "🧪 Running Schemathesis API schema tests..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
|
||||
|
||||
# ============================================================================
|
||||
# Performance Benchmarks
|
||||
# ============================================================================
|
||||
|
||||
benchmark:
|
||||
@echo "⏱️ Running performance benchmarks..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-sort=mean -p no:xdist --override-ini='addopts='
|
||||
|
||||
benchmark-save:
|
||||
@echo "⏱️ Running benchmarks and saving baseline..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='
|
||||
@echo "✅ Benchmark baseline saved to .benchmarks/"
|
||||
|
||||
benchmark-check:
|
||||
@echo "⏱️ Running benchmarks and comparing against baseline..."
|
||||
@if find .benchmarks -name '*_baseline*' -print -quit 2>/dev/null | grep -q .; then \
|
||||
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-compare=0001_baseline --benchmark-sort=mean --benchmark-compare-fail=mean:200% -p no:xdist --override-ini='addopts='; \
|
||||
echo "✅ No performance regressions detected!"; \
|
||||
else \
|
||||
echo "⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one."; \
|
||||
echo " Running benchmarks without comparison..."; \
|
||||
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='; \
|
||||
echo "✅ Benchmark baseline created. Future runs of 'make benchmark-check' will compare against it."; \
|
||||
fi
|
||||
|
||||
test-all:
|
||||
@echo "🧪 Running ALL tests (unit + E2E)..."
|
||||
@$(MAKE) test
|
||||
@@ -209,7 +127,7 @@ clean:
|
||||
@echo "🧹 Cleaning up..."
|
||||
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name ".pyright" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# PragmaStack Backend API
|
||||
# Syndarix Backend API
|
||||
|
||||
> The pragmatic, production-ready FastAPI backend for PragmaStack.
|
||||
> The pragmatic, production-ready FastAPI backend for Syndarix.
|
||||
|
||||
## Overview
|
||||
|
||||
@@ -14,9 +14,7 @@ Features:
|
||||
- **Multi-tenancy**: Organization-based access control with roles (Owner/Admin/Member)
|
||||
- **Testing**: 97%+ coverage with security-focused test suite
|
||||
- **Performance**: Async throughout, connection pooling, optimized queries
|
||||
- **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, Pyright for type checking
|
||||
- **Security Auditing**: Automated dependency vulnerability scanning, license compliance, secrets detection
|
||||
- **Pre-commit Hooks**: Ruff, detect-secrets, and standard checks on every commit
|
||||
- **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, mypy for type checking
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -151,7 +149,7 @@ uv pip list --outdated
|
||||
# Run any Python command via uv (no activation needed)
|
||||
uv run python script.py
|
||||
uv run pytest
|
||||
uv run pyright app/
|
||||
uv run mypy app/
|
||||
|
||||
# Or activate the virtual environment
|
||||
source .venv/bin/activate
|
||||
@@ -173,22 +171,12 @@ make lint # Run Ruff linter (check only)
|
||||
make lint-fix # Run Ruff with auto-fix
|
||||
make format # Format code with Ruff
|
||||
make format-check # Check if code is formatted
|
||||
make type-check # Run Pyright type checking
|
||||
make type-check # Run mypy type checking
|
||||
make validate # Run all checks (lint + format + types)
|
||||
|
||||
# Security & Audit
|
||||
make dep-audit # Scan dependencies for known vulnerabilities (CVEs)
|
||||
make license-check # Check dependency license compliance
|
||||
make audit # Run all security audits (deps + licenses)
|
||||
make validate-all # Run all quality + security checks
|
||||
make check # Full pipeline: quality + security + tests
|
||||
|
||||
# Testing
|
||||
make test # Run all tests
|
||||
make test-cov # Run tests with coverage report
|
||||
make test-e2e # Run E2E tests (PostgreSQL, requires Docker)
|
||||
make test-e2e-schema # Run Schemathesis API schema tests
|
||||
make test-all # Run all tests (unit + E2E)
|
||||
|
||||
# Utilities
|
||||
make clean # Remove cache and build artifacts
|
||||
@@ -264,7 +252,7 @@ app/
|
||||
│ ├── database.py # Database engine setup
|
||||
│ ├── auth.py # JWT token handling
|
||||
│ └── exceptions.py # Custom exceptions
|
||||
├── repositories/ # Repository pattern (database operations)
|
||||
├── crud/ # Database operations
|
||||
├── models/ # SQLAlchemy ORM models
|
||||
├── schemas/ # Pydantic request/response schemas
|
||||
├── services/ # Business logic layer
|
||||
@@ -364,29 +352,18 @@ open htmlcov/index.html
|
||||
# Using Makefile (recommended)
|
||||
make lint # Ruff linting
|
||||
make format # Ruff formatting
|
||||
make type-check # Pyright type checking
|
||||
make type-check # mypy type checking
|
||||
make validate # All checks at once
|
||||
|
||||
# Security audits
|
||||
make dep-audit # Scan dependencies for CVEs
|
||||
make license-check # Check license compliance
|
||||
make audit # All security audits
|
||||
make validate-all # Quality + security checks
|
||||
make check # Full pipeline: quality + security + tests
|
||||
|
||||
# Using uv directly
|
||||
uv run ruff check app/ tests/
|
||||
uv run ruff format app/ tests/
|
||||
uv run pyright app/
|
||||
uv run mypy app/
|
||||
```
|
||||
|
||||
**Tools:**
|
||||
- **Ruff**: All-in-one linting, formatting, and import sorting (replaces Black, Flake8, isort)
|
||||
- **Pyright**: Static type checking (strict mode)
|
||||
- **pip-audit**: Dependency vulnerability scanning against the OSV database
|
||||
- **pip-licenses**: Dependency license compliance checking
|
||||
- **detect-secrets**: Hardcoded secrets/credentials detection
|
||||
- **pre-commit**: Git hook framework for automated checks on every commit
|
||||
- **mypy**: Static type checking with Pydantic plugin
|
||||
|
||||
All configurations are in `pyproject.toml`.
|
||||
|
||||
@@ -462,7 +439,7 @@ See [docs/FEATURE_EXAMPLE.md](docs/FEATURE_EXAMPLE.md) for step-by-step guide.
|
||||
|
||||
Quick overview:
|
||||
1. Create Pydantic schemas in `app/schemas/`
|
||||
2. Create repository in `app/repositories/`
|
||||
2. Create CRUD operations in `app/crud/`
|
||||
3. Create route in `app/api/routes/`
|
||||
4. Register router in `app/api/main.py`
|
||||
5. Write tests in `tests/api/`
|
||||
@@ -612,42 +589,13 @@ Configured in `app/core/config.py`:
|
||||
- **Security Headers**: CSP, HSTS, X-Frame-Options, etc.
|
||||
- **Input Validation**: Pydantic schemas, SQL injection prevention (ORM)
|
||||
|
||||
### Security Auditing
|
||||
|
||||
Automated, deterministic security checks are built into the development workflow:
|
||||
|
||||
```bash
|
||||
# Scan dependencies for known vulnerabilities (CVEs)
|
||||
make dep-audit
|
||||
|
||||
# Check dependency license compliance (blocks GPL-3.0/AGPL)
|
||||
make license-check
|
||||
|
||||
# Run all security audits
|
||||
make audit
|
||||
|
||||
# Full pipeline: quality + security + tests
|
||||
make check
|
||||
```
|
||||
|
||||
**Pre-commit hooks** automatically run on every commit:
|
||||
- **Ruff** lint + format checks
|
||||
- **detect-secrets** blocks commits containing hardcoded secrets
|
||||
- **Standard checks**: trailing whitespace, YAML/TOML validation, merge conflict detection, large file prevention
|
||||
|
||||
Setup pre-commit hooks:
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. **Never commit secrets**: Use `.env` files (git-ignored), enforced by detect-secrets pre-commit hook
|
||||
1. **Never commit secrets**: Use `.env` files (git-ignored)
|
||||
2. **Strong SECRET_KEY**: Min 32 chars, cryptographically random
|
||||
3. **HTTPS in production**: Required for token security
|
||||
4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`), run `make dep-audit` to check for CVEs
|
||||
4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`)
|
||||
5. **Audit logs**: Monitor authentication events
|
||||
6. **Run `make check` before pushing**: Validates quality, security, and tests in one command
|
||||
|
||||
---
|
||||
|
||||
@@ -697,11 +645,7 @@ logging.basicConfig(level=logging.INFO)
|
||||
**Built with modern Python tooling:**
|
||||
- 🚀 **uv** - 10-100x faster dependency management
|
||||
- ⚡ **Ruff** - 10-100x faster linting & formatting
|
||||
- 🔍 **Pyright** - Static type checking (strict mode)
|
||||
- 🔍 **mypy** - Static type checking
|
||||
- ✅ **pytest** - Comprehensive test suite
|
||||
- 🔒 **pip-audit** - Dependency vulnerability scanning
|
||||
- 🔑 **detect-secrets** - Hardcoded secrets detection
|
||||
- 📜 **pip-licenses** - License compliance checking
|
||||
- 🪝 **pre-commit** - Automated git hooks
|
||||
|
||||
**All configured in a single `pyproject.toml` file!**
|
||||
|
||||
@@ -40,7 +40,6 @@ def include_object(object, name, type_, reflected, compare_to):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
|
||||
@@ -1,446 +1,262 @@
|
||||
"""initial models
|
||||
|
||||
Revision ID: 0001
|
||||
Revises:
|
||||
Revises:
|
||||
Create Date: 2025-11-27 09:08:09.464506
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0001"
|
||||
down_revision: str | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
revision: str = '0001'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"oauth_states",
|
||||
sa.Column("state", sa.String(length=255), nullable=False),
|
||||
sa.Column("code_verifier", sa.String(length=128), nullable=True),
|
||||
sa.Column("nonce", sa.String(length=255), nullable=True),
|
||||
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||
sa.Column("redirect_uri", sa.String(length=500), nullable=True),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
op.create_table('oauth_states',
|
||||
sa.Column('state', sa.String(length=255), nullable=False),
|
||||
sa.Column('code_verifier', sa.String(length=128), nullable=True),
|
||||
sa.Column('nonce', sa.String(length=255), nullable=True),
|
||||
sa.Column('provider', sa.String(length=50), nullable=False),
|
||||
sa.Column('redirect_uri', sa.String(length=500), nullable=True),
|
||||
sa.Column('user_id', sa.UUID(), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_states_state"), "oauth_states", ["state"], unique=True
|
||||
op.create_index(op.f('ix_oauth_states_state'), 'oauth_states', ['state'], unique=True)
|
||||
op.create_table('organizations',
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('slug', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('settings', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table(
|
||||
"organizations",
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("slug", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("settings", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
op.create_index(op.f('ix_organizations_is_active'), 'organizations', ['is_active'], unique=False)
|
||||
op.create_index(op.f('ix_organizations_name'), 'organizations', ['name'], unique=False)
|
||||
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'], unique=False)
|
||||
op.create_index(op.f('ix_organizations_slug'), 'organizations', ['slug'], unique=True)
|
||||
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'], unique=False)
|
||||
op.create_table('users',
|
||||
sa.Column('email', sa.String(length=255), nullable=False),
|
||||
sa.Column('password_hash', sa.String(length=255), nullable=True),
|
||||
sa.Column('first_name', sa.String(length=100), nullable=False),
|
||||
sa.Column('last_name', sa.String(length=100), nullable=True),
|
||||
sa.Column('phone_number', sa.String(length=20), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
||||
sa.Column('preferences', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('locale', sa.String(length=10), nullable=True),
|
||||
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_organizations_is_active"), "organizations", ["is_active"], unique=False
|
||||
op.create_index(op.f('ix_users_deleted_at'), 'users', ['deleted_at'], unique=False)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
|
||||
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
|
||||
op.create_index(op.f('ix_users_locale'), 'users', ['locale'], unique=False)
|
||||
op.create_table('oauth_accounts',
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('provider', sa.String(length=50), nullable=False),
|
||||
sa.Column('provider_user_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('provider_email', sa.String(length=255), nullable=True),
|
||||
sa.Column('access_token_encrypted', sa.String(length=2048), nullable=True),
|
||||
sa.Column('refresh_token_encrypted', sa.String(length=2048), nullable=True),
|
||||
sa.Column('token_expires_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('provider', 'provider_user_id', name='uq_oauth_provider_user')
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_organizations_name"), "organizations", ["name"], unique=False
|
||||
op.create_index(op.f('ix_oauth_accounts_provider'), 'oauth_accounts', ['provider'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_accounts_provider_email'), 'oauth_accounts', ['provider_email'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_accounts_user_id'), 'oauth_accounts', ['user_id'], unique=False)
|
||||
op.create_index('ix_oauth_accounts_user_provider', 'oauth_accounts', ['user_id', 'provider'], unique=False)
|
||||
op.create_table('oauth_clients',
|
||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
||||
sa.Column('client_secret_hash', sa.String(length=255), nullable=True),
|
||||
sa.Column('client_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('client_description', sa.String(length=1000), nullable=True),
|
||||
sa.Column('client_type', sa.String(length=20), nullable=False),
|
||||
sa.Column('redirect_uris', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column('allowed_scopes', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column('access_token_lifetime', sa.String(length=10), nullable=False),
|
||||
sa.Column('refresh_token_lifetime', sa.String(length=10), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('owner_user_id', sa.UUID(), nullable=True),
|
||||
sa.Column('mcp_server_url', sa.String(length=2048), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['owner_user_id'], ['users.id'], ondelete='SET NULL'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(
|
||||
"ix_organizations_name_active",
|
||||
"organizations",
|
||||
["name", "is_active"],
|
||||
unique=False,
|
||||
op.create_index(op.f('ix_oauth_clients_client_id'), 'oauth_clients', ['client_id'], unique=True)
|
||||
op.create_index(op.f('ix_oauth_clients_is_active'), 'oauth_clients', ['is_active'], unique=False)
|
||||
op.create_table('user_organizations',
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('organization_id', sa.UUID(), nullable=False),
|
||||
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('custom_permissions', sa.String(length=500), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('user_id', 'organization_id')
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_organizations_slug"), "organizations", ["slug"], unique=True
|
||||
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'], unique=False)
|
||||
op.create_index('ix_user_org_role', 'user_organizations', ['role'], unique=False)
|
||||
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'], unique=False)
|
||||
op.create_index(op.f('ix_user_organizations_is_active'), 'user_organizations', ['is_active'], unique=False)
|
||||
op.create_table('user_sessions',
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
|
||||
sa.Column('device_name', sa.String(length=255), nullable=True),
|
||||
sa.Column('device_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.String(length=500), nullable=True),
|
||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('location_city', sa.String(length=100), nullable=True),
|
||||
sa.Column('location_country', sa.String(length=100), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(
|
||||
"ix_organizations_slug_active",
|
||||
"organizations",
|
||||
["slug", "is_active"],
|
||||
unique=False,
|
||||
op.create_index(op.f('ix_user_sessions_is_active'), 'user_sessions', ['is_active'], unique=False)
|
||||
op.create_index('ix_user_sessions_jti_active', 'user_sessions', ['refresh_token_jti', 'is_active'], unique=False)
|
||||
op.create_index(op.f('ix_user_sessions_refresh_token_jti'), 'user_sessions', ['refresh_token_jti'], unique=True)
|
||||
op.create_index('ix_user_sessions_user_active', 'user_sessions', ['user_id', 'is_active'], unique=False)
|
||||
op.create_index(op.f('ix_user_sessions_user_id'), 'user_sessions', ['user_id'], unique=False)
|
||||
op.create_table('oauth_authorization_codes',
|
||||
sa.Column('code', sa.String(length=128), nullable=False),
|
||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('redirect_uri', sa.String(length=2048), nullable=False),
|
||||
sa.Column('scope', sa.String(length=1000), nullable=False),
|
||||
sa.Column('code_challenge', sa.String(length=128), nullable=True),
|
||||
sa.Column('code_challenge_method', sa.String(length=10), nullable=True),
|
||||
sa.Column('state', sa.String(length=256), nullable=True),
|
||||
sa.Column('nonce', sa.String(length=256), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('used', sa.Boolean(), nullable=False),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("email", sa.String(length=255), nullable=False),
|
||||
sa.Column("password_hash", sa.String(length=255), nullable=True),
|
||||
sa.Column("first_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("last_name", sa.String(length=100), nullable=True),
|
||||
sa.Column("phone_number", sa.String(length=20), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"preferences", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||
),
|
||||
sa.Column("locale", sa.String(length=10), nullable=True),
|
||||
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
op.create_index('ix_oauth_authorization_codes_client_user', 'oauth_authorization_codes', ['client_id', 'user_id'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_authorization_codes_code'), 'oauth_authorization_codes', ['code'], unique=True)
|
||||
op.create_index('ix_oauth_authorization_codes_expires_at', 'oauth_authorization_codes', ['expires_at'], unique=False)
|
||||
op.create_table('oauth_consents',
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
||||
sa.Column('granted_scopes', sa.String(length=1000), nullable=False),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f("ix_users_deleted_at"), "users", ["deleted_at"], unique=False)
|
||||
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
|
||||
op.create_index(
|
||||
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
|
||||
)
|
||||
op.create_index(op.f("ix_users_locale"), "users", ["locale"], unique=False)
|
||||
op.create_table(
|
||||
"oauth_accounts",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||
sa.Column("provider_user_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("provider_email", sa.String(length=255), nullable=True),
|
||||
sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
|
||||
sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
|
||||
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"provider", "provider_user_id", name="uq_oauth_provider_user"
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_accounts_provider"), "oauth_accounts", ["provider"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_accounts_provider_email"),
|
||||
"oauth_accounts",
|
||||
["provider_email"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_accounts_user_id"), "oauth_accounts", ["user_id"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_accounts_user_provider",
|
||||
"oauth_accounts",
|
||||
["user_id", "provider"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_clients",
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
|
||||
sa.Column("client_name", sa.String(length=255), nullable=False),
|
||||
sa.Column("client_description", sa.String(length=1000), nullable=True),
|
||||
sa.Column("client_type", sa.String(length=20), nullable=False),
|
||||
sa.Column(
|
||||
"redirect_uris", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"allowed_scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
|
||||
sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("owner_user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["owner_user_id"], ["users.id"], ondelete="SET NULL"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=True
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_clients_is_active"), "oauth_clients", ["is_active"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"user_organizations",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("organization_id", sa.UUID(), nullable=False),
|
||||
sa.Column(
|
||||
"role",
|
||||
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("user_id", "organization_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_org_org_active",
|
||||
"user_organizations",
|
||||
["organization_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index("ix_user_org_role", "user_organizations", ["role"], unique=False)
|
||||
op.create_index(
|
||||
"ix_user_org_user_active",
|
||||
"user_organizations",
|
||||
["user_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_organizations_is_active"),
|
||||
"user_organizations",
|
||||
["is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"user_sessions",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
|
||||
sa.Column("device_name", sa.String(length=255), nullable=True),
|
||||
sa.Column("device_id", sa.String(length=255), nullable=True),
|
||||
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||
sa.Column("user_agent", sa.String(length=500), nullable=True),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("location_city", sa.String(length=100), nullable=True),
|
||||
sa.Column("location_country", sa.String(length=100), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_sessions_is_active"), "user_sessions", ["is_active"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_sessions_jti_active",
|
||||
"user_sessions",
|
||||
["refresh_token_jti", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_sessions_refresh_token_jti"),
|
||||
"user_sessions",
|
||||
["refresh_token_jti"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_sessions_user_active",
|
||||
"user_sessions",
|
||||
["user_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_sessions_user_id"), "user_sessions", ["user_id"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_authorization_codes",
|
||||
sa.Column("code", sa.String(length=128), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("redirect_uri", sa.String(length=2048), nullable=False),
|
||||
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||
sa.Column("code_challenge", sa.String(length=128), nullable=True),
|
||||
sa.Column("code_challenge_method", sa.String(length=10), nullable=True),
|
||||
sa.Column("state", sa.String(length=256), nullable=True),
|
||||
sa.Column("nonce", sa.String(length=256), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("used", sa.Boolean(), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_authorization_codes_client_user",
|
||||
"oauth_authorization_codes",
|
||||
["client_id", "user_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_authorization_codes_code"),
|
||||
"oauth_authorization_codes",
|
||||
["code"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_authorization_codes_expires_at",
|
||||
"oauth_authorization_codes",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_consents",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("granted_scopes", sa.String(length=1000), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_consents_user_client",
|
||||
"oauth_consents",
|
||||
["user_id", "client_id"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_provider_refresh_tokens",
|
||||
sa.Column("token_hash", sa.String(length=64), nullable=False),
|
||||
sa.Column("jti", sa.String(length=64), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("revoked", sa.Boolean(), nullable=False),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("device_info", sa.String(length=500), nullable=True),
|
||||
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_client_user",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["client_id", "user_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["jti"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["revoked"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["token_hash"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["user_id", "revoked"],
|
||||
unique=False,
|
||||
op.create_index('ix_oauth_consents_user_client', 'oauth_consents', ['user_id', 'client_id'], unique=True)
|
||||
op.create_table('oauth_provider_refresh_tokens',
|
||||
sa.Column('token_hash', sa.String(length=64), nullable=False),
|
||||
sa.Column('jti', sa.String(length=64), nullable=False),
|
||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('scope', sa.String(length=1000), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('revoked', sa.Boolean(), nullable=False),
|
||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('device_info', sa.String(length=500), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('ix_oauth_provider_refresh_tokens_client_user', 'oauth_provider_refresh_tokens', ['client_id', 'user_id'], unique=False)
|
||||
op.create_index('ix_oauth_provider_refresh_tokens_expires_at', 'oauth_provider_refresh_tokens', ['expires_at'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_jti'), 'oauth_provider_refresh_tokens', ['jti'], unique=True)
|
||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), 'oauth_provider_refresh_tokens', ['revoked'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), 'oauth_provider_refresh_tokens', ['token_hash'], unique=True)
|
||||
op.create_index('ix_oauth_provider_refresh_tokens_user_revoked', 'oauth_provider_refresh_tokens', ['user_id', 'revoked'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_oauth_provider_refresh_tokens_client_user",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_table("oauth_provider_refresh_tokens")
|
||||
op.drop_index("ix_oauth_consents_user_client", table_name="oauth_consents")
|
||||
op.drop_table("oauth_consents")
|
||||
op.drop_index(
|
||||
"ix_oauth_authorization_codes_expires_at",
|
||||
table_name="oauth_authorization_codes",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_authorization_codes_code"),
|
||||
table_name="oauth_authorization_codes",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_oauth_authorization_codes_client_user",
|
||||
table_name="oauth_authorization_codes",
|
||||
)
|
||||
op.drop_table("oauth_authorization_codes")
|
||||
op.drop_index(op.f("ix_user_sessions_user_id"), table_name="user_sessions")
|
||||
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
|
||||
op.drop_index(
|
||||
op.f("ix_user_sessions_refresh_token_jti"), table_name="user_sessions"
|
||||
)
|
||||
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
|
||||
op.drop_index(op.f("ix_user_sessions_is_active"), table_name="user_sessions")
|
||||
op.drop_table("user_sessions")
|
||||
op.drop_index(
|
||||
op.f("ix_user_organizations_is_active"), table_name="user_organizations"
|
||||
)
|
||||
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
|
||||
op.drop_index("ix_user_org_role", table_name="user_organizations")
|
||||
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
|
||||
op.drop_table("user_organizations")
|
||||
op.drop_index(op.f("ix_oauth_clients_is_active"), table_name="oauth_clients")
|
||||
op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients")
|
||||
op.drop_table("oauth_clients")
|
||||
op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
|
||||
op.drop_index(op.f("ix_oauth_accounts_user_id"), table_name="oauth_accounts")
|
||||
op.drop_index(op.f("ix_oauth_accounts_provider_email"), table_name="oauth_accounts")
|
||||
op.drop_index(op.f("ix_oauth_accounts_provider"), table_name="oauth_accounts")
|
||||
op.drop_table("oauth_accounts")
|
||||
op.drop_index(op.f("ix_users_locale"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_active"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_deleted_at"), table_name="users")
|
||||
op.drop_table("users")
|
||||
op.drop_index("ix_organizations_slug_active", table_name="organizations")
|
||||
op.drop_index(op.f("ix_organizations_slug"), table_name="organizations")
|
||||
op.drop_index("ix_organizations_name_active", table_name="organizations")
|
||||
op.drop_index(op.f("ix_organizations_name"), table_name="organizations")
|
||||
op.drop_index(op.f("ix_organizations_is_active"), table_name="organizations")
|
||||
op.drop_table("organizations")
|
||||
op.drop_index(op.f("ix_oauth_states_state"), table_name="oauth_states")
|
||||
op.drop_table("oauth_states")
|
||||
op.drop_index('ix_oauth_provider_refresh_tokens_user_revoked', table_name='oauth_provider_refresh_tokens')
|
||||
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), table_name='oauth_provider_refresh_tokens')
|
||||
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), table_name='oauth_provider_refresh_tokens')
|
||||
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_jti'), table_name='oauth_provider_refresh_tokens')
|
||||
op.drop_index('ix_oauth_provider_refresh_tokens_expires_at', table_name='oauth_provider_refresh_tokens')
|
||||
op.drop_index('ix_oauth_provider_refresh_tokens_client_user', table_name='oauth_provider_refresh_tokens')
|
||||
op.drop_table('oauth_provider_refresh_tokens')
|
||||
op.drop_index('ix_oauth_consents_user_client', table_name='oauth_consents')
|
||||
op.drop_table('oauth_consents')
|
||||
op.drop_index('ix_oauth_authorization_codes_expires_at', table_name='oauth_authorization_codes')
|
||||
op.drop_index(op.f('ix_oauth_authorization_codes_code'), table_name='oauth_authorization_codes')
|
||||
op.drop_index('ix_oauth_authorization_codes_client_user', table_name='oauth_authorization_codes')
|
||||
op.drop_table('oauth_authorization_codes')
|
||||
op.drop_index(op.f('ix_user_sessions_user_id'), table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions')
|
||||
op.drop_index(op.f('ix_user_sessions_refresh_token_jti'), table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions')
|
||||
op.drop_index(op.f('ix_user_sessions_is_active'), table_name='user_sessions')
|
||||
op.drop_table('user_sessions')
|
||||
op.drop_index(op.f('ix_user_organizations_is_active'), table_name='user_organizations')
|
||||
op.drop_index('ix_user_org_user_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_org_role', table_name='user_organizations')
|
||||
op.drop_index('ix_user_org_org_active', table_name='user_organizations')
|
||||
op.drop_table('user_organizations')
|
||||
op.drop_index(op.f('ix_oauth_clients_is_active'), table_name='oauth_clients')
|
||||
op.drop_index(op.f('ix_oauth_clients_client_id'), table_name='oauth_clients')
|
||||
op.drop_table('oauth_clients')
|
||||
op.drop_index('ix_oauth_accounts_user_provider', table_name='oauth_accounts')
|
||||
op.drop_index(op.f('ix_oauth_accounts_user_id'), table_name='oauth_accounts')
|
||||
op.drop_index(op.f('ix_oauth_accounts_provider_email'), table_name='oauth_accounts')
|
||||
op.drop_index(op.f('ix_oauth_accounts_provider'), table_name='oauth_accounts')
|
||||
op.drop_table('oauth_accounts')
|
||||
op.drop_index(op.f('ix_users_locale'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_is_active'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_deleted_at'), table_name='users')
|
||||
op.drop_table('users')
|
||||
op.drop_index('ix_organizations_slug_active', table_name='organizations')
|
||||
op.drop_index(op.f('ix_organizations_slug'), table_name='organizations')
|
||||
op.drop_index('ix_organizations_name_active', table_name='organizations')
|
||||
op.drop_index(op.f('ix_organizations_name'), table_name='organizations')
|
||||
op.drop_index(op.f('ix_organizations_is_active'), table_name='organizations')
|
||||
op.drop_table('organizations')
|
||||
op.drop_index(op.f('ix_oauth_states_state'), table_name='oauth_states')
|
||||
op.drop_table('oauth_states')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -114,13 +114,8 @@ def upgrade() -> None:
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes in reverse order
|
||||
op.drop_index(
|
||||
"ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes"
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_perf_oauth_refresh_tokens_expires",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index("ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes")
|
||||
op.drop_index("ix_perf_oauth_refresh_tokens_expires", table_name="oauth_provider_refresh_tokens")
|
||||
op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
|
||||
op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
|
||||
op.drop_index("ix_perf_users_active", table_name="users")
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Enable pgvector extension
|
||||
|
||||
Revision ID: 0003
|
||||
Revises: 0002
|
||||
Create Date: 2025-12-30
|
||||
|
||||
This migration enables the pgvector extension for PostgreSQL, which provides
|
||||
vector similarity search capabilities required for the RAG (Retrieval-Augmented
|
||||
Generation) knowledge base system.
|
||||
|
||||
Vector Dimension Reference (per ADR-008 and SPIKE-006):
|
||||
---------------------------------------------------------
|
||||
The dimension size depends on the embedding model used:
|
||||
|
||||
| Model | Dimensions | Use Case |
|
||||
|----------------------------|------------|------------------------------|
|
||||
| text-embedding-3-small | 1536 | General docs, conversations |
|
||||
| text-embedding-3-large | 256-3072 | High accuracy (configurable) |
|
||||
| voyage-code-3 | 1024 | Code files (Python, JS, etc) |
|
||||
| voyage-3-large | 1024 | High quality general purpose |
|
||||
| nomic-embed-text (Ollama) | 768 | Local/fallback embedding |
|
||||
|
||||
Recommended defaults for Syndarix:
|
||||
- Documentation/conversations: 1536 (text-embedding-3-small)
|
||||
- Code files: 1024 (voyage-code-3)
|
||||
|
||||
Prerequisites:
|
||||
--------------
|
||||
This migration requires PostgreSQL with the pgvector extension installed.
|
||||
The Docker Compose configuration uses `pgvector/pgvector:pg17` which includes
|
||||
the extension pre-installed.
|
||||
|
||||
References:
|
||||
-----------
|
||||
- ADR-008: Knowledge Base and RAG Architecture
|
||||
- SPIKE-006: Knowledge Base with pgvector for RAG System
|
||||
- https://github.com/pgvector/pgvector
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0003"
|
||||
down_revision: str | None = "0002"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Enable the pgvector extension.
|
||||
|
||||
The CREATE EXTENSION IF NOT EXISTS statement is idempotent - it will
|
||||
succeed whether the extension already exists or not.
|
||||
"""
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop the pgvector extension.
|
||||
|
||||
Note: This will fail if any tables with vector columns exist.
|
||||
Future migrations that create vector columns should be downgraded first.
|
||||
"""
|
||||
op.execute("DROP EXTENSION IF EXISTS vector")
|
||||
@@ -1,35 +0,0 @@
|
||||
"""rename oauth account token fields drop encrypted suffix
|
||||
|
||||
Revision ID: 0003
|
||||
Revises: 0002
|
||||
Create Date: 2026-02-27 01:03:18.869178
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0003"
|
||||
down_revision: str | None = "0002"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"oauth_accounts", "access_token_encrypted", new_column_name="access_token"
|
||||
)
|
||||
op.alter_column(
|
||||
"oauth_accounts", "refresh_token_encrypted", new_column_name="refresh_token"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"oauth_accounts", "access_token", new_column_name="access_token_encrypted"
|
||||
)
|
||||
op.alter_column(
|
||||
"oauth_accounts", "refresh_token", new_column_name="refresh_token_encrypted"
|
||||
)
|
||||
@@ -1,12 +1,12 @@
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.repositories.user import user_repo
|
||||
|
||||
# OAuth2 configuration
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
@@ -32,8 +32,9 @@ async def get_current_user(
|
||||
# Decode token and get user ID
|
||||
token_data = get_token_data(token)
|
||||
|
||||
# Get user from database via repository
|
||||
user = await user_repo.get(db, id=str(token_data.user_id))
|
||||
# Get user from database
|
||||
result = await db.execute(select(User).where(User.id == token_data.user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
@@ -143,7 +144,8 @@ async def get_optional_current_user(
|
||||
|
||||
try:
|
||||
token_data = get_token_data(token)
|
||||
user = await user_repo.get(db, id=str(token_data.user_id))
|
||||
result = await db.execute(select(User).where(User.id == token_data.user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
|
||||
36
backend/app/api/dependencies/event_bus.py
Normal file
36
backend/app/api/dependencies/event_bus.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Event bus dependency for FastAPI routes.
|
||||
|
||||
This module provides the FastAPI dependency for injecting the EventBus
|
||||
into route handlers. The event bus is a singleton that maintains
|
||||
Redis pub/sub connections for real-time event streaming.
|
||||
"""
|
||||
|
||||
from app.services.event_bus import (
|
||||
EventBus,
|
||||
get_connected_event_bus as _get_connected_event_bus,
|
||||
)
|
||||
|
||||
|
||||
async def get_event_bus() -> EventBus:
|
||||
"""
|
||||
FastAPI dependency that provides a connected EventBus instance.
|
||||
|
||||
The EventBus is a singleton that maintains Redis pub/sub connections.
|
||||
It's lazily initialized and connected on first access, and should be
|
||||
closed during application shutdown via close_event_bus().
|
||||
|
||||
Usage:
|
||||
@router.get("/events/stream")
|
||||
async def stream_events(
|
||||
event_bus: EventBus = Depends(get_event_bus)
|
||||
):
|
||||
...
|
||||
|
||||
Returns:
|
||||
EventBus: The global connected event bus instance
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If connection to Redis fails
|
||||
"""
|
||||
return await _get_connected_event_bus()
|
||||
@@ -15,9 +15,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole
|
||||
from app.services.organization_service import organization_service
|
||||
|
||||
|
||||
def require_superuser(current_user: User = Depends(get_current_user)) -> User:
|
||||
@@ -81,7 +81,7 @@ class OrganizationPermission:
|
||||
return current_user
|
||||
|
||||
# Get user's role in organization
|
||||
user_role = await organization_service.get_user_role_in_org(
|
||||
user_role = await organization_crud.get_user_role_in_org(
|
||||
db, user_id=current_user.id, organization_id=organization_id
|
||||
)
|
||||
|
||||
@@ -123,7 +123,7 @@ async def require_org_membership(
|
||||
if current_user.is_superuser:
|
||||
return current_user
|
||||
|
||||
user_role = await organization_service.get_user_role_in_org(
|
||||
user_role = await organization_crud.get_user_role_in_org(
|
||||
db, user_id=current_user.id, organization_id=organization_id
|
||||
)
|
||||
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
# app/api/dependencies/services.py
|
||||
"""FastAPI dependency functions for service singletons."""
|
||||
|
||||
from app.services import oauth_provider_service
|
||||
from app.services.auth_service import AuthService
|
||||
from app.services.oauth_service import OAuthService
|
||||
from app.services.organization_service import OrganizationService, organization_service
|
||||
from app.services.session_service import SessionService, session_service
|
||||
from app.services.user_service import UserService, user_service
|
||||
|
||||
|
||||
def get_auth_service() -> AuthService:
|
||||
"""Return the AuthService singleton for dependency injection."""
|
||||
from app.services.auth_service import AuthService as _AuthService
|
||||
|
||||
return _AuthService()
|
||||
|
||||
|
||||
def get_user_service() -> UserService:
|
||||
"""Return the UserService singleton for dependency injection."""
|
||||
return user_service
|
||||
|
||||
|
||||
def get_organization_service() -> OrganizationService:
|
||||
"""Return the OrganizationService singleton for dependency injection."""
|
||||
return organization_service
|
||||
|
||||
|
||||
def get_session_service() -> SessionService:
|
||||
"""Return the SessionService singleton for dependency injection."""
|
||||
return session_service
|
||||
|
||||
|
||||
def get_oauth_service() -> OAuthService:
|
||||
"""Return OAuthService for dependency injection."""
|
||||
return OAuthService()
|
||||
|
||||
|
||||
def get_oauth_provider_service():
|
||||
"""Return the oauth_provider_service module for dependency injection."""
|
||||
return oauth_provider_service
|
||||
@@ -3,6 +3,7 @@ from fastapi import APIRouter
|
||||
from app.api.routes import (
|
||||
admin,
|
||||
auth,
|
||||
events,
|
||||
oauth,
|
||||
oauth_provider,
|
||||
organizations,
|
||||
@@ -22,3 +23,5 @@ api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
|
||||
api_router.include_router(
|
||||
organizations.router, prefix="/organizations", tags=["Organizations"]
|
||||
)
|
||||
# SSE events router - no prefix, routes define full paths
|
||||
api_router.include_router(events.router, tags=["Events"])
|
||||
|
||||
@@ -14,6 +14,7 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
@@ -24,9 +25,12 @@ from app.core.exceptions import (
|
||||
ErrorCode,
|
||||
NotFoundError,
|
||||
)
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.crud.session import session as session_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
from app.schemas.common import (
|
||||
MessageResponse,
|
||||
PaginatedResponse,
|
||||
@@ -42,9 +46,6 @@ from app.schemas.organizations import (
|
||||
)
|
||||
from app.schemas.sessions import AdminSessionResponse
|
||||
from app.schemas.users import UserCreate, UserResponse, UserUpdate
|
||||
from app.services.organization_service import organization_service
|
||||
from app.services.session_service import session_service
|
||||
from app.services.user_service import user_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -65,7 +66,7 @@ class BulkUserAction(BaseModel):
|
||||
|
||||
action: BulkAction = Field(..., description="Action to perform on selected users")
|
||||
user_ids: list[UUID] = Field(
|
||||
..., min_length=1, max_length=100, description="List of user IDs (max 100)"
|
||||
..., min_items=1, max_items=100, description="List of user IDs (max 100)"
|
||||
)
|
||||
|
||||
|
||||
@@ -177,29 +178,38 @@ async def admin_get_stats(
|
||||
"""Get admin dashboard statistics with real data from database."""
|
||||
from app.core.config import settings
|
||||
|
||||
stats = await user_service.get_stats(db)
|
||||
total_users = stats["total_users"]
|
||||
active_count = stats["active_count"]
|
||||
inactive_count = stats["inactive_count"]
|
||||
all_users = stats["all_users"]
|
||||
# Check if we have any data
|
||||
total_users_query = select(func.count()).select_from(User)
|
||||
total_users = (await db.execute(total_users_query)).scalar() or 0
|
||||
|
||||
# If database is essentially empty (only admin user), return demo data
|
||||
if total_users <= 1 and settings.DEMO_MODE: # pragma: no cover
|
||||
logger.info("Returning demo stats data (empty database in demo mode)")
|
||||
return _generate_demo_stats()
|
||||
|
||||
# 1. User Growth (Last 30 days)
|
||||
# 1. User Growth (Last 30 days) - Improved calculation
|
||||
datetime.now(UTC) - timedelta(days=30)
|
||||
|
||||
# Get all users with their creation dates
|
||||
all_users_query = select(User).order_by(User.created_at)
|
||||
result = await db.execute(all_users_query)
|
||||
all_users = result.scalars().all()
|
||||
|
||||
# Build cumulative counts per day
|
||||
user_growth = []
|
||||
for i in range(29, -1, -1):
|
||||
date = datetime.now(UTC) - timedelta(days=i)
|
||||
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
|
||||
date_end = date_start + timedelta(days=1)
|
||||
|
||||
# Count all users created before end of this day
|
||||
# Make comparison timezone-aware
|
||||
total_users_on_date = sum(
|
||||
1
|
||||
for u in all_users
|
||||
if u.created_at and u.created_at.replace(tzinfo=UTC) < date_end
|
||||
)
|
||||
# Count active users created before end of this day
|
||||
active_users_on_date = sum(
|
||||
1
|
||||
for u in all_users
|
||||
@@ -217,16 +227,27 @@ async def admin_get_stats(
|
||||
)
|
||||
|
||||
# 2. Organization Distribution - Top 6 organizations by member count
|
||||
org_rows = await organization_service.get_org_distribution(db, limit=6)
|
||||
org_dist = [OrgDistributionData(name=r["name"], value=r["value"]) for r in org_rows]
|
||||
org_query = (
|
||||
select(Organization.name, func.count(UserOrganization.user_id).label("count"))
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.group_by(Organization.name)
|
||||
.order_by(func.count(UserOrganization.user_id).desc())
|
||||
.limit(6)
|
||||
)
|
||||
result = await db.execute(org_query)
|
||||
org_dist = [
|
||||
OrgDistributionData(name=row.name, value=row.count) for row in result.all()
|
||||
]
|
||||
|
||||
# 3. User Registration Activity (Last 14 days)
|
||||
# 3. User Registration Activity (Last 14 days) - NEW
|
||||
registration_activity = []
|
||||
for i in range(13, -1, -1):
|
||||
date = datetime.now(UTC) - timedelta(days=i)
|
||||
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
|
||||
date_end = date_start + timedelta(days=1)
|
||||
|
||||
# Count users created on this specific day
|
||||
# Make comparison timezone-aware
|
||||
day_registrations = sum(
|
||||
1
|
||||
for u in all_users
|
||||
@@ -242,8 +263,16 @@ async def admin_get_stats(
|
||||
)
|
||||
|
||||
# 4. User Status - Active vs Inactive
|
||||
active_query = select(func.count()).select_from(User).where(User.is_active)
|
||||
inactive_query = (
|
||||
select(func.count()).select_from(User).where(User.is_active.is_(False))
|
||||
)
|
||||
|
||||
active_count = (await db.execute(active_query)).scalar() or 0
|
||||
inactive_count = (await db.execute(inactive_query)).scalar() or 0
|
||||
|
||||
logger.info(
|
||||
"User status counts - Active: %s, Inactive: %s", active_count, inactive_count
|
||||
f"User status counts - Active: {active_count}, Inactive: {inactive_count}"
|
||||
)
|
||||
|
||||
user_status = [
|
||||
@@ -292,7 +321,7 @@ async def admin_list_users(
|
||||
filters["is_superuser"] = is_superuser
|
||||
|
||||
# Get users with search
|
||||
users, total = await user_service.list_users(
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
db,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
@@ -312,7 +341,7 @@ async def admin_list_users(
|
||||
return PaginatedResponse(data=users, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error listing users (admin): %s", e)
|
||||
logger.error(f"Error listing users (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -335,14 +364,14 @@ async def admin_create_user(
|
||||
Allows setting is_superuser and other fields.
|
||||
"""
|
||||
try:
|
||||
user = await user_service.create_user(db, user_in)
|
||||
logger.info("Admin %s created user %s", admin.email, user.email)
|
||||
user = await user_crud.create(db, obj_in=user_in)
|
||||
logger.info(f"Admin {admin.email} created user {user.email}")
|
||||
return user
|
||||
except DuplicateEntryError as e:
|
||||
logger.warning("Failed to create user: %s", e)
|
||||
raise DuplicateError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to create user: {e!s}")
|
||||
raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
|
||||
except Exception as e:
|
||||
logger.exception("Error creating user (admin): %s", e)
|
||||
logger.error(f"Error creating user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -359,7 +388,11 @@ async def admin_get_user(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Get detailed information about a specific user."""
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@@ -378,13 +411,20 @@ async def admin_update_user(
|
||||
) -> Any:
|
||||
"""Update user information with admin privileges."""
|
||||
try:
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
updated_user = await user_service.update_user(db, user=user, obj_in=user_in)
|
||||
logger.info("Admin %s updated user %s", admin.email, updated_user.email)
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
|
||||
logger.info(f"Admin {admin.email} updated user {updated_user.email}")
|
||||
return updated_user
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error updating user (admin): %s", e)
|
||||
logger.error(f"Error updating user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -402,7 +442,11 @@ async def admin_delete_user(
|
||||
) -> Any:
|
||||
"""Soft delete a user (sets deleted_at timestamp)."""
|
||||
try:
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
# Prevent deleting yourself
|
||||
if user.id == admin.id:
|
||||
@@ -412,15 +456,17 @@ async def admin_delete_user(
|
||||
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
||||
)
|
||||
|
||||
await user_service.soft_delete_user(db, str(user_id))
|
||||
logger.info("Admin %s deleted user %s", admin.email, user.email)
|
||||
await user_crud.soft_delete(db, id=user_id)
|
||||
logger.info(f"Admin {admin.email} deleted user {user.email}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True, message=f"User {user.email} has been deleted"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error deleting user (admin): %s", e)
|
||||
logger.error(f"Error deleting user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -438,16 +484,23 @@ async def admin_activate_user(
|
||||
) -> Any:
|
||||
"""Activate a user account."""
|
||||
try:
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
await user_service.update_user(db, user=user, obj_in={"is_active": True})
|
||||
logger.info("Admin %s activated user %s", admin.email, user.email)
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
|
||||
logger.info(f"Admin {admin.email} activated user {user.email}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True, message=f"User {user.email} has been activated"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error activating user (admin): %s", e)
|
||||
logger.error(f"Error activating user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -465,7 +518,11 @@ async def admin_deactivate_user(
|
||||
) -> Any:
|
||||
"""Deactivate a user account."""
|
||||
try:
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
# Prevent deactivating yourself
|
||||
if user.id == admin.id:
|
||||
@@ -475,15 +532,17 @@ async def admin_deactivate_user(
|
||||
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
||||
)
|
||||
|
||||
await user_service.update_user(db, user=user, obj_in={"is_active": False})
|
||||
logger.info("Admin %s deactivated user %s", admin.email, user.email)
|
||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
|
||||
logger.info(f"Admin {admin.email} deactivated user {user.email}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True, message=f"User {user.email} has been deactivated"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error deactivating user (admin): %s", e)
|
||||
logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -508,16 +567,16 @@ async def admin_bulk_user_action(
|
||||
try:
|
||||
# Use efficient bulk operations instead of loop
|
||||
if bulk_action.action == BulkAction.ACTIVATE:
|
||||
affected_count = await user_service.bulk_update_status(
|
||||
affected_count = await user_crud.bulk_update_status(
|
||||
db, user_ids=bulk_action.user_ids, is_active=True
|
||||
)
|
||||
elif bulk_action.action == BulkAction.DEACTIVATE:
|
||||
affected_count = await user_service.bulk_update_status(
|
||||
affected_count = await user_crud.bulk_update_status(
|
||||
db, user_ids=bulk_action.user_ids, is_active=False
|
||||
)
|
||||
elif bulk_action.action == BulkAction.DELETE:
|
||||
# bulk_soft_delete automatically excludes the admin user
|
||||
affected_count = await user_service.bulk_soft_delete(
|
||||
affected_count = await user_crud.bulk_soft_delete(
|
||||
db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
|
||||
)
|
||||
else: # pragma: no cover
|
||||
@@ -528,11 +587,8 @@ async def admin_bulk_user_action(
|
||||
failed_count = requested_count - affected_count
|
||||
|
||||
logger.info(
|
||||
"Admin %s performed bulk %s on %s users (%s skipped/failed)",
|
||||
admin.email,
|
||||
bulk_action.action.value,
|
||||
affected_count,
|
||||
failed_count,
|
||||
f"Admin {admin.email} performed bulk {bulk_action.action.value} "
|
||||
f"on {affected_count} users ({failed_count} skipped/failed)"
|
||||
)
|
||||
|
||||
return BulkActionResult(
|
||||
@@ -544,7 +600,7 @@ async def admin_bulk_user_action(
|
||||
)
|
||||
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.exception("Error in bulk user action: %s", e)
|
||||
logger.error(f"Error in bulk user action: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -568,7 +624,7 @@ async def admin_list_organizations(
|
||||
"""List all organizations with filtering and search."""
|
||||
try:
|
||||
# Use optimized method that gets member counts in single query (no N+1)
|
||||
orgs_with_data, total = await organization_service.get_multi_with_member_counts(
|
||||
orgs_with_data, total = await organization_crud.get_multi_with_member_counts(
|
||||
db,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
@@ -605,7 +661,7 @@ async def admin_list_organizations(
|
||||
return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error listing organizations (admin): %s", e)
|
||||
logger.error(f"Error listing organizations (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -624,8 +680,8 @@ async def admin_create_organization(
|
||||
) -> Any:
|
||||
"""Create a new organization."""
|
||||
try:
|
||||
org = await organization_service.create_organization(db, obj_in=org_in)
|
||||
logger.info("Admin %s created organization %s", admin.email, org.name)
|
||||
org = await organization_crud.create(db, obj_in=org_in)
|
||||
logger.info(f"Admin {admin.email} created organization {org.name}")
|
||||
|
||||
# Add member count
|
||||
org_dict = {
|
||||
@@ -641,11 +697,11 @@ async def admin_create_organization(
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except DuplicateEntryError as e:
|
||||
logger.warning("Failed to create organization: %s", e)
|
||||
raise DuplicateError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to create organization: {e!s}")
|
||||
raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
|
||||
except Exception as e:
|
||||
logger.exception("Error creating organization (admin): %s", e)
|
||||
logger.error(f"Error creating organization (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -662,7 +718,12 @@ async def admin_get_organization(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Get detailed information about a specific organization."""
|
||||
org = await organization_service.get_organization(db, str(org_id))
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
@@ -672,7 +733,7 @@ async def admin_get_organization(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": await organization_service.get_member_count(
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
db, organization_id=org.id
|
||||
),
|
||||
}
|
||||
@@ -694,11 +755,15 @@ async def admin_update_organization(
|
||||
) -> Any:
|
||||
"""Update organization information."""
|
||||
try:
|
||||
org = await organization_service.get_organization(db, str(org_id))
|
||||
updated_org = await organization_service.update_organization(
|
||||
db, org=org, obj_in=org_in
|
||||
)
|
||||
logger.info("Admin %s updated organization %s", admin.email, updated_org.name)
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
||||
logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
|
||||
|
||||
org_dict = {
|
||||
"id": updated_org.id,
|
||||
@@ -709,14 +774,16 @@ async def admin_update_organization(
|
||||
"settings": updated_org.settings,
|
||||
"created_at": updated_org.created_at,
|
||||
"updated_at": updated_org.updated_at,
|
||||
"member_count": await organization_service.get_member_count(
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
db, organization_id=updated_org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error updating organization (admin): %s", e)
|
||||
logger.error(f"Error updating organization (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -734,16 +801,24 @@ async def admin_delete_organization(
|
||||
) -> Any:
|
||||
"""Delete an organization and all its relationships."""
|
||||
try:
|
||||
org = await organization_service.get_organization(db, str(org_id))
|
||||
await organization_service.remove_organization(db, str(org_id))
|
||||
logger.info("Admin %s deleted organization %s", admin.email, org.name)
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
await organization_crud.remove(db, id=org_id)
|
||||
logger.info(f"Admin {admin.email} deleted organization {org.name}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True, message=f"Organization {org.name} has been deleted"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error deleting organization (admin): %s", e)
|
||||
logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -763,8 +838,14 @@ async def admin_list_organization_members(
|
||||
) -> Any:
|
||||
"""List all members of an organization."""
|
||||
try:
|
||||
await organization_service.get_organization(db, str(org_id)) # validates exists
|
||||
members, total = await organization_service.get_organization_members(
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
members, total = await organization_crud.get_organization_members(
|
||||
db,
|
||||
organization_id=org_id,
|
||||
skip=pagination.offset,
|
||||
@@ -787,7 +868,9 @@ async def admin_list_organization_members(
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error listing organization members (admin): %s", e)
|
||||
logger.error(
|
||||
f"Error listing organization members (admin): {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -815,32 +898,45 @@ async def admin_add_organization_member(
|
||||
) -> Any:
|
||||
"""Add a user to an organization."""
|
||||
try:
|
||||
org = await organization_service.get_organization(db, str(org_id))
|
||||
user = await user_service.get_user(db, str(request.user_id))
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
await organization_service.add_member(
|
||||
user = await user_crud.get(db, id=request.user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {request.user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
await organization_crud.add_user(
|
||||
db, organization_id=org_id, user_id=request.user_id, role=request.role
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Admin %s added user %s to organization %s with role %s",
|
||||
admin.email,
|
||||
user.email,
|
||||
org.name,
|
||||
request.role.value,
|
||||
f"Admin {admin.email} added user {user.email} to organization {org.name} "
|
||||
f"with role {request.role.value}"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True, message=f"User {user.email} added to organization {org.name}"
|
||||
)
|
||||
|
||||
except DuplicateEntryError as e:
|
||||
logger.warning("Failed to add user to organization: %s", e)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to add user to organization: {e!s}")
|
||||
# Use DuplicateError for "already exists" scenarios
|
||||
raise DuplicateError(
|
||||
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
|
||||
)
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error adding member to organization (admin): %s", e)
|
||||
logger.error(
|
||||
f"Error adding member to organization (admin): {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -859,10 +955,20 @@ async def admin_remove_organization_member(
|
||||
) -> Any:
|
||||
"""Remove a user from an organization."""
|
||||
try:
|
||||
org = await organization_service.get_organization(db, str(org_id))
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
success = await organization_service.remove_member(
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
success = await organization_crud.remove_user(
|
||||
db, organization_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
@@ -873,10 +979,7 @@ async def admin_remove_organization_member(
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Admin %s removed user %s from organization %s",
|
||||
admin.email,
|
||||
user.email,
|
||||
org.name,
|
||||
f"Admin {admin.email} removed user {user.email} from organization {org.name}"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
@@ -887,7 +990,9 @@ async def admin_remove_organization_member(
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.exception("Error removing member from organization (admin): %s", e)
|
||||
logger.error(
|
||||
f"Error removing member from organization (admin): {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -917,7 +1022,7 @@ async def admin_list_sessions(
|
||||
"""List all sessions across all users with filtering and pagination."""
|
||||
try:
|
||||
# Get sessions with user info (eager loaded to prevent N+1)
|
||||
sessions, total = await session_service.get_all_sessions(
|
||||
sessions, total = await session_crud.get_all_sessions(
|
||||
db,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
@@ -956,10 +1061,7 @@ async def admin_list_sessions(
|
||||
session_responses.append(session_response)
|
||||
|
||||
logger.info(
|
||||
"Admin %s listed %s sessions (total: %s)",
|
||||
admin.email,
|
||||
len(session_responses),
|
||||
total,
|
||||
f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})"
|
||||
)
|
||||
|
||||
pagination_meta = create_pagination_meta(
|
||||
@@ -972,5 +1074,5 @@ async def admin_list_sessions(
|
||||
return PaginatedResponse(data=session_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.exception("Error listing sessions (admin): %s", e)
|
||||
logger.error(f"Error listing sessions (admin): {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -15,14 +15,16 @@ from app.core.auth import (
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
decode_token,
|
||||
get_password_hash,
|
||||
)
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
AuthenticationError as AuthError,
|
||||
DatabaseError,
|
||||
DuplicateError,
|
||||
ErrorCode,
|
||||
)
|
||||
from app.crud.session import session as session_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import LogoutRequest, SessionCreate
|
||||
@@ -37,8 +39,6 @@ from app.schemas.users import (
|
||||
)
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
from app.services.email_service import email_service
|
||||
from app.services.session_service import session_service
|
||||
from app.services.user_service import user_service
|
||||
from app.utils.device import extract_device_info
|
||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||
|
||||
@@ -91,18 +91,17 @@ async def _create_login_session(
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
await session_service.create_session(db, obj_in=session_data)
|
||||
await session_crud.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(
|
||||
"%s successful: %s from %s (IP: %s)",
|
||||
login_type.capitalize(),
|
||||
user.email,
|
||||
device_info.device_name,
|
||||
device_info.ip_address,
|
||||
f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} "
|
||||
f"(IP: {device_info.ip_address})"
|
||||
)
|
||||
except Exception as session_err:
|
||||
# Log but don't fail login if session creation fails
|
||||
logger.exception("Failed to create session for %s: %s", user.email, session_err)
|
||||
logger.error(
|
||||
f"Failed to create session for {user.email}: {session_err!s}", exc_info=True
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -124,21 +123,15 @@ async def register_user(
|
||||
try:
|
||||
user = await AuthService.create_user(db, user_data)
|
||||
return user
|
||||
except DuplicateError:
|
||||
except AuthenticationError as e:
|
||||
# SECURITY: Don't reveal if email exists - generic error message
|
||||
logger.warning("Registration failed: duplicate email %s", user_data.email)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Registration failed. Please check your information and try again.",
|
||||
)
|
||||
except AuthError as e:
|
||||
logger.warning("Registration failed: %s", e)
|
||||
logger.warning(f"Registration failed: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Registration failed. Please check your information and try again.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error during registration: %s", e)
|
||||
logger.error(f"Unexpected error during registration: {e!s}", exc_info=True)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
@@ -166,7 +159,7 @@ async def login(
|
||||
|
||||
# Explicitly check for None result and raise correct exception
|
||||
if user is None:
|
||||
logger.warning("Invalid login attempt for: %s", login_data.email)
|
||||
logger.warning(f"Invalid login attempt for: {login_data.email}")
|
||||
raise AuthError(
|
||||
message="Invalid email or password",
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS,
|
||||
@@ -182,11 +175,14 @@ async def login(
|
||||
|
||||
except AuthenticationError as e:
|
||||
# Handle specific authentication errors like inactive accounts
|
||||
logger.warning("Authentication failed: %s", e)
|
||||
logger.warning(f"Authentication failed: {e!s}")
|
||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||
except AuthError:
|
||||
# Re-raise custom auth exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
# Handle unexpected errors
|
||||
logger.exception("Unexpected error during login: %s", e)
|
||||
logger.error(f"Unexpected error during login: {e!s}", exc_info=True)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
@@ -228,10 +224,13 @@ async def login_oauth(
|
||||
# Return full token response with user data
|
||||
return tokens
|
||||
except AuthenticationError as e:
|
||||
logger.warning("OAuth authentication failed: %s", e)
|
||||
logger.warning(f"OAuth authentication failed: {e!s}")
|
||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||
except AuthError:
|
||||
# Re-raise custom auth exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error during OAuth login: %s", e)
|
||||
logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
@@ -260,12 +259,11 @@ async def refresh_token(
|
||||
)
|
||||
|
||||
# Check if session exists and is active
|
||||
session = await session_service.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if not session:
|
||||
logger.warning(
|
||||
"Refresh token used for inactive or non-existent session: %s",
|
||||
refresh_payload.jti,
|
||||
f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -281,14 +279,16 @@ async def refresh_token(
|
||||
|
||||
# Update session with new refresh token JTI and expiration
|
||||
try:
|
||||
await session_service.update_refresh_token(
|
||||
await session_crud.update_refresh_token(
|
||||
db,
|
||||
session=session,
|
||||
new_jti=new_refresh_payload.jti,
|
||||
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
|
||||
)
|
||||
except Exception as session_err:
|
||||
logger.exception("Failed to update session %s: %s", session.id, session_err)
|
||||
logger.error(
|
||||
f"Failed to update session {session.id}: {session_err!s}", exc_info=True
|
||||
)
|
||||
# Continue anyway - tokens are already issued
|
||||
|
||||
return tokens
|
||||
@@ -311,7 +311,7 @@ async def refresh_token(
|
||||
# Re-raise HTTP exceptions (like session revoked)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error during token refresh: %s", e)
|
||||
logger.error(f"Unexpected error during token refresh: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later.",
|
||||
@@ -347,7 +347,7 @@ async def request_password_reset(
|
||||
"""
|
||||
try:
|
||||
# Look up user by email
|
||||
user = await user_service.get_by_email(db, email=reset_request.email)
|
||||
user = await user_crud.get_by_email(db, email=reset_request.email)
|
||||
|
||||
# Only send email if user exists and is active
|
||||
if user and user.is_active:
|
||||
@@ -358,12 +358,11 @@ async def request_password_reset(
|
||||
await email_service.send_password_reset_email(
|
||||
to_email=user.email, reset_token=reset_token, user_name=user.first_name
|
||||
)
|
||||
logger.info("Password reset requested for %s", user.email)
|
||||
logger.info(f"Password reset requested for {user.email}")
|
||||
else:
|
||||
# Log attempt but don't reveal if email exists
|
||||
logger.warning(
|
||||
"Password reset requested for non-existent or inactive email: %s",
|
||||
reset_request.email,
|
||||
f"Password reset requested for non-existent or inactive email: {reset_request.email}"
|
||||
)
|
||||
|
||||
# Always return success to prevent email enumeration
|
||||
@@ -372,7 +371,7 @@ async def request_password_reset(
|
||||
message="If your email is registered, you will receive a password reset link shortly",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error processing password reset request: %s", e)
|
||||
logger.error(f"Error processing password reset request: {e!s}", exc_info=True)
|
||||
# Still return success to prevent information leakage
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
@@ -413,34 +412,40 @@ async def confirm_password_reset(
|
||||
detail="Invalid or expired password reset token",
|
||||
)
|
||||
|
||||
# Reset password via service (validates user exists and is active)
|
||||
try:
|
||||
user = await AuthService.reset_password(
|
||||
db, email=email, new_password=reset_confirm.new_password
|
||||
# Look up user
|
||||
user = await user_crud.get_by_email(db, email=email)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
except AuthenticationError as e:
|
||||
err_msg = str(e)
|
||||
if "inactive" in err_msg.lower():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=err_msg
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=err_msg)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User account is inactive",
|
||||
)
|
||||
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(reset_confirm.new_password)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
|
||||
# SECURITY: Invalidate all existing sessions after password reset
|
||||
# This prevents stolen sessions from being used after password change
|
||||
from app.crud.session import session as session_crud
|
||||
|
||||
try:
|
||||
deactivated_count = await session_service.deactivate_all_user_sessions(
|
||||
deactivated_count = await session_crud.deactivate_all_user_sessions(
|
||||
db, user_id=str(user.id)
|
||||
)
|
||||
logger.info(
|
||||
"Password reset successful for %s, invalidated %s sessions",
|
||||
user.email,
|
||||
deactivated_count,
|
||||
f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions"
|
||||
)
|
||||
except Exception as session_error:
|
||||
# Log but don't fail password reset if session invalidation fails
|
||||
logger.error(
|
||||
"Failed to invalidate sessions after password reset: %s", session_error
|
||||
f"Failed to invalidate sessions after password reset: {session_error!s}"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
@@ -451,7 +456,7 @@ async def confirm_password_reset(
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error confirming password reset: %s", e)
|
||||
logger.error(f"Error confirming password reset: {e!s}", exc_info=True)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
@@ -501,21 +506,19 @@ async def logout(
|
||||
)
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
# Even if token is expired/invalid, try to deactivate session
|
||||
logger.warning("Logout with invalid/expired token: %s", e)
|
||||
logger.warning(f"Logout with invalid/expired token: {e!s}")
|
||||
# Don't fail - return success anyway
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
# Find the session by JTI
|
||||
session = await session_service.get_by_jti(db, jti=refresh_payload.jti)
|
||||
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if session:
|
||||
# Verify session belongs to current user (security check)
|
||||
if str(session.user_id) != str(current_user.id):
|
||||
logger.warning(
|
||||
"User %s attempted to logout session %s belonging to user %s",
|
||||
current_user.id,
|
||||
session.id,
|
||||
session.user_id,
|
||||
f"User {current_user.id} attempted to logout session {session.id} "
|
||||
f"belonging to user {session.user_id}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -523,20 +526,17 @@ async def logout(
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
await session_service.deactivate(db, session_id=str(session.id))
|
||||
await session_crud.deactivate(db, session_id=str(session.id))
|
||||
|
||||
logger.info(
|
||||
"User %s logged out from %s (session %s)",
|
||||
current_user.id,
|
||||
session.device_name,
|
||||
session.id,
|
||||
f"User {current_user.id} logged out from {session.device_name} "
|
||||
f"(session {session.id})"
|
||||
)
|
||||
else:
|
||||
# Session not found - maybe already deleted or never existed
|
||||
# Return success anyway (idempotent)
|
||||
logger.info(
|
||||
"Logout requested for non-existent session (JTI: %s)",
|
||||
refresh_payload.jti,
|
||||
f"Logout requested for non-existent session (JTI: {refresh_payload.jti})"
|
||||
)
|
||||
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
@@ -544,7 +544,9 @@ async def logout(
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error during logout for user %s: %s", current_user.id, e)
|
||||
logger.error(
|
||||
f"Error during logout for user {current_user.id}: {e!s}", exc_info=True
|
||||
)
|
||||
# Don't expose error details
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
@@ -582,12 +584,12 @@ async def logout_all(
|
||||
"""
|
||||
try:
|
||||
# Deactivate all sessions for this user
|
||||
count = await session_service.deactivate_all_user_sessions(
|
||||
count = await session_crud.deactivate_all_user_sessions(
|
||||
db, user_id=str(current_user.id)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"User %s logged out from all devices (%s sessions)", current_user.id, count
|
||||
f"User {current_user.id} logged out from all devices ({count} sessions)"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
@@ -596,7 +598,9 @@ async def logout_all(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error during logout-all for user %s: %s", current_user.id, e)
|
||||
logger.error(
|
||||
f"Error during logout-all for user {current_user.id}: {e!s}", exc_info=True
|
||||
)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
||||
283
backend/app/api/routes/events.py
Normal file
283
backend/app/api/routes/events.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
SSE endpoint for real-time project event streaming.
|
||||
|
||||
This module provides Server-Sent Events (SSE) endpoints for streaming
|
||||
project events to connected clients. Events are scoped to projects,
|
||||
with authorization checks to ensure clients only receive events
|
||||
for projects they have access to.
|
||||
|
||||
Features:
|
||||
- Real-time event streaming via SSE
|
||||
- Project-scoped authorization
|
||||
- Automatic reconnection support (Last-Event-ID)
|
||||
- Keepalive messages every 30 seconds
|
||||
- Graceful connection cleanup
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Request
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.api.dependencies.event_bus import get_event_bus
|
||||
from app.core.exceptions import AuthorizationError
|
||||
from app.models.user import User
|
||||
from app.schemas.errors import ErrorCode
|
||||
from app.schemas.events import EventType
|
||||
from app.services.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Keepalive interval in seconds
|
||||
KEEPALIVE_INTERVAL = 30
|
||||
|
||||
|
||||
async def check_project_access(
|
||||
project_id: UUID,
|
||||
user: User,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a user has access to a project's events.
|
||||
|
||||
This is a placeholder implementation that will be replaced
|
||||
with actual project authorization logic once the Project model
|
||||
is implemented. Currently allows access for all authenticated users.
|
||||
|
||||
Args:
|
||||
project_id: The project to check access for
|
||||
user: The authenticated user
|
||||
|
||||
Returns:
|
||||
bool: True if user has access, False otherwise
|
||||
|
||||
TODO: Implement actual project authorization
|
||||
- Check if user owns the project
|
||||
- Check if user is a member of the project
|
||||
- Check project visibility settings
|
||||
"""
|
||||
# Placeholder: Allow all authenticated users for now
|
||||
# This will be replaced with actual project ownership/membership check
|
||||
logger.debug(
|
||||
f"Project access check for user {user.id} on project {project_id} "
|
||||
"(placeholder: allowing all authenticated users)"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
async def event_generator(
|
||||
project_id: UUID,
|
||||
event_bus: EventBus,
|
||||
last_event_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Generate SSE events for a project.
|
||||
|
||||
This async generator yields SSE-formatted events from the event bus,
|
||||
including keepalive comments to maintain the connection.
|
||||
|
||||
Args:
|
||||
project_id: The project to stream events for
|
||||
event_bus: The EventBus instance
|
||||
last_event_id: Optional last received event ID for reconnection
|
||||
|
||||
Yields:
|
||||
dict: SSE event data with 'event', 'data', and optional 'id' fields
|
||||
"""
|
||||
try:
|
||||
async for event_data in event_bus.subscribe_sse(
|
||||
project_id=project_id,
|
||||
last_event_id=last_event_id,
|
||||
keepalive_interval=KEEPALIVE_INTERVAL,
|
||||
):
|
||||
if event_data == "":
|
||||
# Keepalive - yield SSE comment
|
||||
yield {"comment": "keepalive"}
|
||||
else:
|
||||
# Parse event to extract type and id
|
||||
try:
|
||||
event_dict = json.loads(event_data)
|
||||
event_type = event_dict.get("type", "message")
|
||||
event_id = event_dict.get("id")
|
||||
|
||||
yield {
|
||||
"event": event_type,
|
||||
"data": event_data,
|
||||
"id": event_id,
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# If we can't parse, send as generic message
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": event_data,
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Event stream cancelled for project {project_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in event stream for project {project_id}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/events/stream",
|
||||
summary="Stream Project Events",
|
||||
description="""
|
||||
Stream real-time events for a project via Server-Sent Events (SSE).
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**: Must have access to the project
|
||||
|
||||
**SSE Event Format**:
|
||||
```
|
||||
event: agent.status_changed
|
||||
id: 550e8400-e29b-41d4-a716-446655440000
|
||||
data: {"id": "...", "type": "agent.status_changed", "project_id": "...", ...}
|
||||
|
||||
: keepalive
|
||||
|
||||
event: issue.created
|
||||
id: 550e8400-e29b-41d4-a716-446655440001
|
||||
data: {...}
|
||||
```
|
||||
|
||||
**Reconnection**: Include the `Last-Event-ID` header with the last received
|
||||
event ID to resume from where you left off.
|
||||
|
||||
**Keepalive**: The server sends a comment (`: keepalive`) every 30 seconds
|
||||
to keep the connection alive.
|
||||
|
||||
**Rate Limit**: 10 connections/minute per IP
|
||||
""",
|
||||
response_class=EventSourceResponse,
|
||||
responses={
|
||||
200: {
|
||||
"description": "SSE stream established",
|
||||
"content": {"text/event-stream": {}},
|
||||
},
|
||||
401: {"description": "Not authenticated"},
|
||||
403: {"description": "Not authorized to access this project"},
|
||||
404: {"description": "Project not found"},
|
||||
},
|
||||
operation_id="stream_project_events",
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
async def stream_project_events(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
event_bus: EventBus = Depends(get_event_bus),
|
||||
last_event_id: str | None = Header(None, alias="Last-Event-ID"),
|
||||
):
|
||||
"""
|
||||
Stream real-time events for a project via SSE.
|
||||
|
||||
This endpoint establishes a persistent SSE connection that streams
|
||||
project events to the client in real-time. The connection includes:
|
||||
|
||||
- Event streaming: All project events (agent updates, issues, etc.)
|
||||
- Keepalive: Comment every 30 seconds to maintain connection
|
||||
- Reconnection: Use Last-Event-ID header to resume after disconnect
|
||||
|
||||
The connection is automatically cleaned up when the client disconnects.
|
||||
"""
|
||||
logger.info(
|
||||
f"SSE connection request for project {project_id} "
|
||||
f"by user {current_user.id} "
|
||||
f"(last_event_id={last_event_id})"
|
||||
)
|
||||
|
||||
# Check project access
|
||||
has_access = await check_project_access(project_id, current_user)
|
||||
if not has_access:
|
||||
raise AuthorizationError(
|
||||
message=f"You don't have access to project {project_id}",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Return SSE response
|
||||
return EventSourceResponse(
|
||||
event_generator(
|
||||
project_id=project_id,
|
||||
event_bus=event_bus,
|
||||
last_event_id=last_event_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/projects/{project_id}/events/test",
|
||||
summary="Send Test Event (Development Only)",
|
||||
description="""
|
||||
Send a test event to a project's event stream. This endpoint is
|
||||
intended for development and testing purposes.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**: Must have access to the project
|
||||
|
||||
**Note**: This endpoint should be disabled or restricted in production.
|
||||
""",
|
||||
response_model=dict,
|
||||
responses={
|
||||
200: {"description": "Test event sent"},
|
||||
401: {"description": "Not authenticated"},
|
||||
403: {"description": "Not authorized to access this project"},
|
||||
},
|
||||
operation_id="send_test_event",
|
||||
)
|
||||
async def send_test_event(
|
||||
project_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
event_bus: EventBus = Depends(get_event_bus),
|
||||
):
|
||||
"""
|
||||
Send a test event to the project's event stream.
|
||||
|
||||
This is useful for testing SSE connections during development.
|
||||
"""
|
||||
# Check project access
|
||||
has_access = await check_project_access(project_id, current_user)
|
||||
if not has_access:
|
||||
raise AuthorizationError(
|
||||
message=f"You don't have access to project {project_id}",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Create and publish test event using the Event schema
|
||||
event = EventBus.create_event(
|
||||
event_type=EventType.AGENT_MESSAGE,
|
||||
project_id=project_id,
|
||||
actor_type="user",
|
||||
actor_id=current_user.id,
|
||||
payload={
|
||||
"message": "Test event from SSE endpoint",
|
||||
"message_type": "info",
|
||||
},
|
||||
)
|
||||
|
||||
channel = event_bus.get_project_channel(project_id)
|
||||
await event_bus.publish(channel, event)
|
||||
|
||||
logger.info(f"Test event sent to project {project_id}: {event.id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"event_id": event.id,
|
||||
"event_type": event.type.value,
|
||||
"message": "Test event sent successfully",
|
||||
}
|
||||
@@ -25,6 +25,8 @@ from app.core.auth import decode_token
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import AuthenticationError as AuthError
|
||||
from app.crud import oauth_account
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.oauth import (
|
||||
OAuthAccountsListResponse,
|
||||
@@ -36,7 +38,6 @@ from app.schemas.oauth import (
|
||||
from app.schemas.sessions import SessionCreate
|
||||
from app.schemas.users import Token
|
||||
from app.services.oauth_service import OAuthService
|
||||
from app.services.session_service import session_service
|
||||
from app.utils.device import extract_device_info
|
||||
|
||||
router = APIRouter()
|
||||
@@ -81,19 +82,17 @@ async def _create_oauth_login_session(
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
await session_service.create_session(db, obj_in=session_data)
|
||||
await session_crud.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(
|
||||
"OAuth login successful: %s via %s from %s (IP: %s)",
|
||||
user.email,
|
||||
provider,
|
||||
device_info.device_name,
|
||||
device_info.ip_address,
|
||||
f"OAuth login successful: {user.email} via {provider} "
|
||||
f"from {device_info.device_name} (IP: {device_info.ip_address})"
|
||||
)
|
||||
except Exception as session_err:
|
||||
# Log but don't fail login if session creation fails
|
||||
logger.exception(
|
||||
"Failed to create session for OAuth login %s: %s", user.email, session_err
|
||||
logger.error(
|
||||
f"Failed to create session for OAuth login {user.email}: {session_err!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -178,13 +177,13 @@ async def get_authorization_url(
|
||||
}
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning("OAuth authorization failed: %s", e)
|
||||
logger.warning(f"OAuth authorization failed: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("OAuth authorization error: %s", e)
|
||||
logger.error(f"OAuth authorization error: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create authorization URL",
|
||||
@@ -252,13 +251,13 @@ async def handle_callback(
|
||||
return result
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning("OAuth callback failed: %s", e)
|
||||
logger.warning(f"OAuth callback failed: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("OAuth callback error: %s", e)
|
||||
logger.error(f"OAuth callback error: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth authentication failed",
|
||||
@@ -290,7 +289,7 @@ async def list_accounts(
|
||||
Returns:
|
||||
List of linked OAuth accounts
|
||||
"""
|
||||
accounts = await OAuthService.get_user_accounts(db, user_id=current_user.id)
|
||||
accounts = await oauth_account.get_user_accounts(db, user_id=current_user.id)
|
||||
return OAuthAccountsListResponse(accounts=accounts)
|
||||
|
||||
|
||||
@@ -339,13 +338,13 @@ async def unlink_account(
|
||||
)
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning("OAuth unlink failed for %s: %s", current_user.email, e)
|
||||
logger.warning(f"OAuth unlink failed for {current_user.email}: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("OAuth unlink error: %s", e)
|
||||
logger.error(f"OAuth unlink error: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to unlink OAuth account",
|
||||
@@ -398,7 +397,7 @@ async def start_link(
|
||||
)
|
||||
|
||||
# Check if user already has this provider linked
|
||||
existing = await OAuthService.get_user_account_by_provider(
|
||||
existing = await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=current_user.id, provider=provider
|
||||
)
|
||||
if existing:
|
||||
@@ -421,13 +420,13 @@ async def start_link(
|
||||
}
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning("OAuth link authorization failed: %s", e)
|
||||
logger.warning(f"OAuth link authorization failed: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("OAuth link error: %s", e)
|
||||
logger.error(f"OAuth link error: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create authorization URL",
|
||||
|
||||
@@ -34,6 +34,7 @@ from app.api.dependencies.auth import (
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.crud import oauth_client as oauth_client_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.oauth import (
|
||||
OAuthClientCreate,
|
||||
@@ -452,7 +453,7 @@ async def token(
|
||||
except Exception as e:
|
||||
# Log malformed Basic auth for security monitoring
|
||||
logger.warning(
|
||||
"Malformed Basic auth header in token request: %s", type(e).__name__
|
||||
f"Malformed Basic auth header in token request: {type(e).__name__}"
|
||||
)
|
||||
# Fall back to form body
|
||||
|
||||
@@ -563,8 +564,7 @@ async def revoke(
|
||||
except Exception as e:
|
||||
# Log malformed Basic auth for security monitoring
|
||||
logger.warning(
|
||||
"Malformed Basic auth header in revoke request: %s",
|
||||
type(e).__name__,
|
||||
f"Malformed Basic auth header in revoke request: {type(e).__name__}"
|
||||
)
|
||||
# Fall back to form body
|
||||
|
||||
@@ -586,7 +586,7 @@ async def revoke(
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't expose errors per RFC 7009
|
||||
logger.warning("Token revocation error: %s", e)
|
||||
logger.warning(f"Token revocation error: {e}")
|
||||
|
||||
# Always return 200 OK per RFC 7009
|
||||
return {"status": "ok"}
|
||||
@@ -635,8 +635,7 @@ async def introspect(
|
||||
except Exception as e:
|
||||
# Log malformed Basic auth for security monitoring
|
||||
logger.warning(
|
||||
"Malformed Basic auth header in introspect request: %s",
|
||||
type(e).__name__,
|
||||
f"Malformed Basic auth header in introspect request: {type(e).__name__}"
|
||||
)
|
||||
# Fall back to form body
|
||||
|
||||
@@ -656,8 +655,8 @@ async def introspect(
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Token introspection error: %s", e)
|
||||
return OAuthTokenIntrospectionResponse(active=False) # pyright: ignore[reportCallIssue]
|
||||
logger.warning(f"Token introspection error: {e}")
|
||||
return OAuthTokenIntrospectionResponse(active=False)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -713,7 +712,7 @@ async def register_client(
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
client, secret = await provider_service.register_client(db, client_data)
|
||||
client, secret = await oauth_client_crud.create_client(db, obj_in=client_data)
|
||||
|
||||
# Update MCP server URL if provided
|
||||
if mcp_server_url:
|
||||
@@ -751,7 +750,7 @@ async def list_clients(
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
) -> list[OAuthClientResponse]:
|
||||
"""List all OAuth clients."""
|
||||
clients = await provider_service.list_clients(db)
|
||||
clients = await oauth_client_crud.get_all_clients(db)
|
||||
return [OAuthClientResponse.model_validate(c) for c in clients]
|
||||
|
||||
|
||||
@@ -777,7 +776,7 @@ async def delete_client(
|
||||
detail="Client not found",
|
||||
)
|
||||
|
||||
await provider_service.delete_client_by_id(db, client_id=client_id)
|
||||
await oauth_client_crud.delete_client(db, client_id=client_id)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -798,7 +797,30 @@ async def list_my_consents(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> list[dict]:
|
||||
"""List applications the user has authorized."""
|
||||
return await provider_service.list_user_consents(db, user_id=current_user.id)
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.oauth_provider_token import OAuthConsent
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthConsent, OAuthClient)
|
||||
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
|
||||
.where(OAuthConsent.user_id == current_user.id)
|
||||
)
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{
|
||||
"client_id": consent.client_id,
|
||||
"client_name": client.client_name,
|
||||
"client_description": client.client_description,
|
||||
"granted_scopes": consent.granted_scopes.split()
|
||||
if consent.granted_scopes
|
||||
else [],
|
||||
"granted_at": consent.created_at.isoformat(),
|
||||
}
|
||||
for consent, client in rows
|
||||
]
|
||||
|
||||
|
||||
@router.delete(
|
||||
|
||||
@@ -15,6 +15,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.api.dependencies.permissions import require_org_admin, require_org_membership
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import ErrorCode, NotFoundError
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
PaginatedResponse,
|
||||
@@ -26,7 +28,6 @@ from app.schemas.organizations import (
|
||||
OrganizationResponse,
|
||||
OrganizationUpdate,
|
||||
)
|
||||
from app.services.organization_service import organization_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -53,7 +54,7 @@ async def get_my_organizations(
|
||||
"""
|
||||
try:
|
||||
# Get all org data in single query with JOIN and subquery
|
||||
orgs_data = await organization_service.get_user_organizations_with_details(
|
||||
orgs_data = await organization_crud.get_user_organizations_with_details(
|
||||
db, user_id=current_user.id, is_active=is_active
|
||||
)
|
||||
|
||||
@@ -77,7 +78,7 @@ async def get_my_organizations(
|
||||
return orgs_with_data
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting user organizations: %s", e)
|
||||
logger.error(f"Error getting user organizations: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -99,7 +100,13 @@ async def get_organization(
|
||||
User must be a member of the organization.
|
||||
"""
|
||||
try:
|
||||
org = await organization_service.get_organization(db, str(organization_id))
|
||||
org = await organization_crud.get(db, id=organization_id)
|
||||
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {organization_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
@@ -109,14 +116,16 @@ async def get_organization(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": await organization_service.get_member_count(
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
db, organization_id=org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError: # pragma: no cover - See above
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error getting organization: %s", e)
|
||||
logger.error(f"Error getting organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -140,7 +149,7 @@ async def get_organization_members(
|
||||
User must be a member of the organization to view members.
|
||||
"""
|
||||
try:
|
||||
members, total = await organization_service.get_organization_members(
|
||||
members, total = await organization_crud.get_organization_members(
|
||||
db,
|
||||
organization_id=organization_id,
|
||||
skip=pagination.offset,
|
||||
@@ -160,7 +169,7 @@ async def get_organization_members(
|
||||
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting organization members: %s", e)
|
||||
logger.error(f"Error getting organization members: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -183,12 +192,16 @@ async def update_organization(
|
||||
Requires owner or admin role in the organization.
|
||||
"""
|
||||
try:
|
||||
org = await organization_service.get_organization(db, str(organization_id))
|
||||
updated_org = await organization_service.update_organization(
|
||||
db, org=org, obj_in=org_in
|
||||
)
|
||||
org = await organization_crud.get(db, id=organization_id)
|
||||
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {organization_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
||||
logger.info(
|
||||
"User %s updated organization %s", current_user.email, updated_org.name
|
||||
f"User {current_user.email} updated organization {updated_org.name}"
|
||||
)
|
||||
|
||||
org_dict = {
|
||||
@@ -200,12 +213,14 @@ async def update_organization(
|
||||
"settings": updated_org.settings,
|
||||
"created_at": updated_org.created_at,
|
||||
"updated_at": updated_org.updated_at,
|
||||
"member_count": await organization_service.get_member_count(
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
db, organization_id=updated_org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error updating organization: %s", e)
|
||||
except NotFoundError: # pragma: no cover - See above
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -17,10 +17,10 @@ from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import decode_token
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import SessionListResponse, SessionResponse
|
||||
from app.services.session_service import session_service
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -60,7 +60,7 @@ async def list_my_sessions(
|
||||
"""
|
||||
try:
|
||||
# Get all active sessions for user
|
||||
sessions = await session_service.get_user_sessions(
|
||||
sessions = await session_crud.get_user_sessions(
|
||||
db, user_id=str(current_user.id), active_only=True
|
||||
)
|
||||
|
||||
@@ -74,7 +74,9 @@ async def list_my_sessions(
|
||||
# For now, we'll mark current based on most recent activity
|
||||
except Exception as e:
|
||||
# Optional token parsing - silently ignore failures
|
||||
logger.debug("Failed to decode access token for session marking: %s", e)
|
||||
logger.debug(
|
||||
f"Failed to decode access token for session marking: {e!s}"
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
session_responses = []
|
||||
@@ -96,7 +98,7 @@ async def list_my_sessions(
|
||||
session_responses.append(session_response)
|
||||
|
||||
logger.info(
|
||||
"User %s listed %s active sessions", current_user.id, len(session_responses)
|
||||
f"User {current_user.id} listed {len(session_responses)} active sessions"
|
||||
)
|
||||
|
||||
return SessionListResponse(
|
||||
@@ -104,7 +106,9 @@ async def list_my_sessions(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error listing sessions for user %s: %s", current_user.id, e)
|
||||
logger.error(
|
||||
f"Error listing sessions for user {current_user.id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve sessions",
|
||||
@@ -146,7 +150,7 @@ async def revoke_session(
|
||||
"""
|
||||
try:
|
||||
# Get the session
|
||||
session = await session_service.get_session(db, str(session_id))
|
||||
session = await session_crud.get(db, id=str(session_id))
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
@@ -157,10 +161,8 @@ async def revoke_session(
|
||||
# Verify session belongs to current user
|
||||
if str(session.user_id) != str(current_user.id):
|
||||
logger.warning(
|
||||
"User %s attempted to revoke session %s belonging to user %s",
|
||||
current_user.id,
|
||||
session_id,
|
||||
session.user_id,
|
||||
f"User {current_user.id} attempted to revoke session {session_id} "
|
||||
f"belonging to user {session.user_id}"
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="You can only revoke your own sessions",
|
||||
@@ -168,13 +170,11 @@ async def revoke_session(
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
await session_service.deactivate(db, session_id=str(session_id))
|
||||
await session_crud.deactivate(db, session_id=str(session_id))
|
||||
|
||||
logger.info(
|
||||
"User %s revoked session %s (%s)",
|
||||
current_user.id,
|
||||
session_id,
|
||||
session.device_name,
|
||||
f"User {current_user.id} revoked session {session_id} "
|
||||
f"({session.device_name})"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
@@ -185,7 +185,7 @@ async def revoke_session(
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error revoking session %s: %s", session_id, e)
|
||||
logger.error(f"Error revoking session {session_id}: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to revoke session",
|
||||
@@ -224,12 +224,12 @@ async def cleanup_expired_sessions(
|
||||
"""
|
||||
try:
|
||||
# Use optimized bulk DELETE instead of N individual deletes
|
||||
deleted_count = await session_service.cleanup_expired_for_user(
|
||||
deleted_count = await session_crud.cleanup_expired_for_user(
|
||||
db, user_id=str(current_user.id)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"User %s cleaned up %s expired sessions", current_user.id, deleted_count
|
||||
f"User {current_user.id} cleaned up {deleted_count} expired sessions"
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
@@ -237,8 +237,9 @@ async def cleanup_expired_sessions(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Error cleaning up sessions for user %s: %s", current_user.id, e
|
||||
logger.error(
|
||||
f"Error cleaning up sessions for user {current_user.id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
User management endpoints for database operations.
|
||||
User management endpoints for CRUD operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -13,7 +13,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_superuser, get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import AuthorizationError, ErrorCode
|
||||
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
MessageResponse,
|
||||
@@ -24,7 +25,6 @@ from app.schemas.common import (
|
||||
)
|
||||
from app.schemas.users import PasswordChange, UserResponse, UserUpdate
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
from app.services.user_service import user_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -71,7 +71,7 @@ async def list_users(
|
||||
filters["is_superuser"] = is_superuser
|
||||
|
||||
# Get paginated users with total count
|
||||
users, total = await user_service.list_users(
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
db,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
@@ -90,7 +90,7 @@ async def list_users(
|
||||
|
||||
return PaginatedResponse(data=users, pagination=pagination_meta)
|
||||
except Exception as e:
|
||||
logger.exception("Error listing users: %s", e)
|
||||
logger.error(f"Error listing users: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -107,9 +107,7 @@ async def list_users(
|
||||
""",
|
||||
operation_id="get_current_user_profile",
|
||||
)
|
||||
async def get_current_user_profile(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any:
|
||||
"""Get current user's profile."""
|
||||
return current_user
|
||||
|
||||
@@ -140,16 +138,18 @@ async def update_current_user(
|
||||
Users cannot elevate their own permissions (protected by UserUpdate schema validator).
|
||||
"""
|
||||
try:
|
||||
updated_user = await user_service.update_user(
|
||||
db, user=current_user, obj_in=user_update
|
||||
updated_user = await user_crud.update(
|
||||
db, db_obj=current_user, obj_in=user_update
|
||||
)
|
||||
logger.info("User %s updated their profile", current_user.id)
|
||||
logger.info(f"User {current_user.id} updated their profile")
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error("Error updating user %s: %s", current_user.id, e)
|
||||
logger.error(f"Error updating user {current_user.id}: {e!s}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error updating user %s: %s", current_user.id, e)
|
||||
logger.error(
|
||||
f"Unexpected error updating user {current_user.id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -182,9 +182,7 @@ async def get_user_by_id(
|
||||
# Check permissions
|
||||
if str(user_id) != str(current_user.id) and not current_user.is_superuser:
|
||||
logger.warning(
|
||||
"User %s attempted to access user %s without permission",
|
||||
current_user.id,
|
||||
user_id,
|
||||
f"User {current_user.id} attempted to access user {user_id} without permission"
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to view this user",
|
||||
@@ -192,7 +190,13 @@ async def get_user_by_id(
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
user = await user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@@ -229,9 +233,7 @@ async def update_user(
|
||||
|
||||
if not is_own_profile and not current_user.is_superuser:
|
||||
logger.warning(
|
||||
"User %s attempted to update user %s without permission",
|
||||
current_user.id,
|
||||
user_id,
|
||||
f"User {current_user.id} attempted to update user {user_id} without permission"
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to update this user",
|
||||
@@ -239,17 +241,22 @@ async def update_user(
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
user = await user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
try:
|
||||
updated_user = await user_service.update_user(db, user=user, obj_in=user_update)
|
||||
logger.info("User %s updated by %s", user_id, current_user.id)
|
||||
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update)
|
||||
logger.info(f"User {user_id} updated by {current_user.id}")
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error("Error updating user %s: %s", user_id, e)
|
||||
logger.error(f"Error updating user {user_id}: {e!s}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error updating user %s: %s", user_id, e)
|
||||
logger.error(f"Unexpected error updating user {user_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@@ -289,19 +296,19 @@ async def change_current_user_password(
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("User %s changed their password", current_user.id)
|
||||
logger.info(f"User {current_user.id} changed their password")
|
||||
return MessageResponse(
|
||||
success=True, message="Password changed successfully"
|
||||
)
|
||||
except AuthenticationError as e:
|
||||
logger.warning(
|
||||
"Failed password change attempt for user %s: %s", current_user.id, e
|
||||
f"Failed password change attempt for user {current_user.id}: {e!s}"
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error changing password for user %s: %s", current_user.id, e)
|
||||
logger.error(f"Error changing password for user {current_user.id}: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -339,19 +346,24 @@ async def delete_user(
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user (raises NotFoundError if not found)
|
||||
await user_service.get_user(db, str(user_id))
|
||||
# Get user
|
||||
user = await user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
try:
|
||||
# Use soft delete instead of hard delete
|
||||
await user_service.soft_delete_user(db, str(user_id))
|
||||
logger.info("User %s soft-deleted by %s", user_id, current_user.id)
|
||||
await user_crud.soft_delete(db, id=str(user_id))
|
||||
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
|
||||
return MessageResponse(
|
||||
success=True, message=f"User {user_id} deleted successfully"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error("Error deleting user %s: %s", user_id, e)
|
||||
logger.error(f"Error deleting user {user_id}: {e!s}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error deleting user %s: %s", user_id, e)
|
||||
logger.error(f"Unexpected error deleting user {user_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
110
backend/app/celery_app.py
Normal file
110
backend/app/celery_app.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# app/celery_app.py
|
||||
"""
|
||||
Celery application configuration for Syndarix.
|
||||
|
||||
This module configures the Celery app for background task processing:
|
||||
- Agent execution tasks (LLM calls, tool execution)
|
||||
- Git operations (clone, commit, push, PR creation)
|
||||
- Issue synchronization with external trackers
|
||||
- Workflow state management
|
||||
- Cost tracking and budget monitoring
|
||||
|
||||
Architecture:
|
||||
- Redis as message broker and result backend
|
||||
- Queue routing for task isolation
|
||||
- JSON serialization for cross-language compatibility
|
||||
- Beat scheduler for periodic tasks
|
||||
"""
|
||||
|
||||
from celery import Celery
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Create Celery application instance
|
||||
celery_app = Celery(
|
||||
"syndarix",
|
||||
broker=settings.celery_broker_url,
|
||||
backend=settings.celery_result_backend,
|
||||
)
|
||||
|
||||
# Define task queues with their own exchanges and routing keys
|
||||
TASK_QUEUES = {
|
||||
"agent": {"exchange": "agent", "routing_key": "agent"},
|
||||
"git": {"exchange": "git", "routing_key": "git"},
|
||||
"sync": {"exchange": "sync", "routing_key": "sync"},
|
||||
"default": {"exchange": "default", "routing_key": "default"},
|
||||
}
|
||||
|
||||
# Configure Celery
|
||||
celery_app.conf.update(
|
||||
# Serialization
|
||||
task_serializer="json",
|
||||
accept_content=["json"],
|
||||
result_serializer="json",
|
||||
# Timezone
|
||||
timezone="UTC",
|
||||
enable_utc=True,
|
||||
# Task imports for auto-discovery
|
||||
imports=("app.tasks",),
|
||||
# Default queue
|
||||
task_default_queue="default",
|
||||
# Task queues configuration
|
||||
task_queues=TASK_QUEUES,
|
||||
# Task routing - route tasks to appropriate queues
|
||||
task_routes={
|
||||
"app.tasks.agent.*": {"queue": "agent"},
|
||||
"app.tasks.git.*": {"queue": "git"},
|
||||
"app.tasks.sync.*": {"queue": "sync"},
|
||||
"app.tasks.*": {"queue": "default"},
|
||||
},
|
||||
# Time limits per ADR-003
|
||||
task_soft_time_limit=300, # 5 minutes soft limit
|
||||
task_time_limit=600, # 10 minutes hard limit
|
||||
# Result expiration - 24 hours
|
||||
result_expires=86400,
|
||||
# Broker connection retry
|
||||
broker_connection_retry_on_startup=True,
|
||||
# Beat schedule for periodic tasks
|
||||
beat_schedule={
|
||||
# Cost aggregation every hour per ADR-012
|
||||
"aggregate-daily-costs": {
|
||||
"task": "app.tasks.cost.aggregate_daily_costs",
|
||||
"schedule": 3600.0, # 1 hour in seconds
|
||||
},
|
||||
# Reset daily budget counters at midnight UTC
|
||||
"reset-daily-budget-counters": {
|
||||
"task": "app.tasks.cost.reset_daily_budget_counters",
|
||||
"schedule": 86400.0, # 24 hours in seconds
|
||||
},
|
||||
# Check for stale workflows every 5 minutes
|
||||
"recover-stale-workflows": {
|
||||
"task": "app.tasks.workflow.recover_stale_workflows",
|
||||
"schedule": 300.0, # 5 minutes in seconds
|
||||
},
|
||||
# Incremental issue sync every minute per ADR-011
|
||||
"sync-issues-incremental": {
|
||||
"task": "app.tasks.sync.sync_issues_incremental",
|
||||
"schedule": 60.0, # 1 minute in seconds
|
||||
},
|
||||
# Full issue reconciliation every 15 minutes per ADR-011
|
||||
"sync-issues-full": {
|
||||
"task": "app.tasks.sync.sync_issues_full",
|
||||
"schedule": 900.0, # 15 minutes in seconds
|
||||
},
|
||||
},
|
||||
# Task execution settings
|
||||
task_acks_late=True, # Acknowledge tasks after execution
|
||||
task_reject_on_worker_lost=True, # Reject tasks if worker dies
|
||||
worker_prefetch_multiplier=1, # Fair task distribution
|
||||
)
|
||||
|
||||
# Auto-discover tasks from task modules
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"app.tasks.agent",
|
||||
"app.tasks.git",
|
||||
"app.tasks.sync",
|
||||
"app.tasks.workflow",
|
||||
"app.tasks.cost",
|
||||
]
|
||||
)
|
||||
@@ -1,21 +1,23 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import bcrypt
|
||||
import jwt
|
||||
from jwt.exceptions import (
|
||||
ExpiredSignatureError,
|
||||
InvalidTokenError,
|
||||
MissingRequiredClaimError,
|
||||
)
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.users import TokenData, TokenPayload
|
||||
|
||||
# Suppress passlib bcrypt warnings about ident
|
||||
logging.getLogger("passlib").setLevel(logging.ERROR)
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
# Custom exceptions for auth
|
||||
class AuthError(Exception):
|
||||
@@ -35,16 +37,13 @@ class TokenMissingClaimError(AuthError):
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against a bcrypt hash."""
|
||||
return bcrypt.checkpw(
|
||||
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
|
||||
)
|
||||
"""Verify a password against a hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate a bcrypt password hash."""
|
||||
salt = bcrypt.gensalt()
|
||||
return bcrypt.hashpw(password.encode("utf-8"), salt).decode("utf-8")
|
||||
"""Generate a password hash."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||
@@ -61,9 +60,9 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
|
||||
Returns:
|
||||
True if password matches, False otherwise
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None, partial(verify_password, plain_password, hashed_password)
|
||||
None, partial(pwd_context.verify, plain_password, hashed_password)
|
||||
)
|
||||
|
||||
|
||||
@@ -81,8 +80,8 @@ async def get_password_hash_async(password: str) -> str:
|
||||
Returns:
|
||||
Hashed password string
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, get_password_hash, password)
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, pwd_context.hash, password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
@@ -122,7 +121,11 @@ def create_access_token(
|
||||
to_encode.update(claims)
|
||||
|
||||
# Create the JWT
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
@@ -151,7 +154,11 @@ def create_refresh_token(
|
||||
"type": "refresh",
|
||||
}
|
||||
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
@@ -191,7 +198,7 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
|
||||
# Reject weak or unexpected algorithms
|
||||
# NOTE: These are defensive checks that provide defense-in-depth.
|
||||
# PyJWT rejects these tokens BEFORE we reach here,
|
||||
# The python-jose library rejects these tokens BEFORE we reach here,
|
||||
# but we keep these checks in case the library changes or is misconfigured.
|
||||
# Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
|
||||
if token_algorithm == "NONE": # pragma: no cover
|
||||
@@ -212,11 +219,10 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
token_data = TokenPayload(**payload)
|
||||
return token_data
|
||||
|
||||
except ExpiredSignatureError:
|
||||
raise TokenExpiredError("Token has expired")
|
||||
except MissingRequiredClaimError as e:
|
||||
raise TokenMissingClaimError(f"Token missing required claim: {e}")
|
||||
except InvalidTokenError:
|
||||
except JWTError as e:
|
||||
# Check if the error is due to an expired token
|
||||
if "expired" in str(e).lower():
|
||||
raise TokenExpiredError("Token has expired")
|
||||
raise TokenInvalidError("Invalid authentication token")
|
||||
except ValidationError:
|
||||
raise TokenInvalidError("Invalid token payload")
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "PragmaStack"
|
||||
PROJECT_NAME: str = "Syndarix"
|
||||
VERSION: str = "1.0.0"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
@@ -39,6 +39,32 @@ class Settings(BaseSettings):
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
|
||||
|
||||
# Redis configuration (Syndarix: cache, pub/sub, Celery broker)
|
||||
REDIS_URL: str = Field(
|
||||
default="redis://localhost:6379/0",
|
||||
description="Redis URL for cache, pub/sub, and Celery broker",
|
||||
)
|
||||
|
||||
# Celery configuration (Syndarix: background task processing)
|
||||
CELERY_BROKER_URL: str | None = Field(
|
||||
default=None,
|
||||
description="Celery broker URL (defaults to REDIS_URL if not set)",
|
||||
)
|
||||
CELERY_RESULT_BACKEND: str | None = Field(
|
||||
default=None,
|
||||
description="Celery result backend URL (defaults to REDIS_URL if not set)",
|
||||
)
|
||||
|
||||
@property
|
||||
def celery_broker_url(self) -> str:
|
||||
"""Get Celery broker URL, defaulting to Redis."""
|
||||
return self.CELERY_BROKER_URL or self.REDIS_URL
|
||||
|
||||
@property
|
||||
def celery_result_backend(self) -> str:
|
||||
"""Get Celery result backend URL, defaulting to Redis."""
|
||||
return self.CELERY_RESULT_BACKEND or self.REDIS_URL
|
||||
|
||||
# SQL debugging (disable in production)
|
||||
sql_echo: bool = False # Log SQL statements
|
||||
sql_echo_pool: bool = False # Log connection pool events
|
||||
|
||||
@@ -128,8 +128,8 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
|
||||
Usage:
|
||||
async with async_transaction_scope() as db:
|
||||
user = await user_repo.create(db, obj_in=user_create)
|
||||
profile = await profile_repo.create(db, obj_in=profile_create)
|
||||
user = await user_crud.create(db, obj_in=user_create)
|
||||
profile = await profile_crud.create(db, obj_in=profile_create)
|
||||
# Both operations committed together
|
||||
"""
|
||||
async with SessionLocal() as session:
|
||||
@@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
logger.debug("Async transaction committed successfully")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error("Async transaction failed, rolling back: %s", e)
|
||||
logger.error(f"Async transaction failed, rolling back: {e!s}")
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
@@ -155,7 +155,7 @@ async def check_async_database_health() -> bool:
|
||||
await db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Async database health check failed: %s", e)
|
||||
logger.error(f"Async database health check failed: {e!s}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -143,11 +143,8 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
|
||||
Returns a standardized error response with error code and message.
|
||||
"""
|
||||
logger.warning(
|
||||
"API exception: %s - %s (status: %s, path: %s)",
|
||||
exc.error_code,
|
||||
exc.message,
|
||||
exc.status_code,
|
||||
request.url.path,
|
||||
f"API exception: {exc.error_code} - {exc.message} "
|
||||
f"(status: {exc.status_code}, path: {request.url.path})"
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
@@ -189,9 +186,7 @@ async def validation_exception_handler(
|
||||
)
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
"Validation error: %s errors (path: %s)", len(errors), request.url.path
|
||||
)
|
||||
logger.warning(f"Validation error: {len(errors)} errors (path: {request.url.path})")
|
||||
|
||||
error_response = ErrorResponse(errors=errors)
|
||||
|
||||
@@ -223,14 +218,11 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
"HTTP exception: %s - %s (path: %s)",
|
||||
exc.status_code,
|
||||
exc.detail,
|
||||
request.url.path,
|
||||
f"HTTP exception: {exc.status_code} - {exc.detail} (path: {request.url.path})"
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(code=error_code, message=str(exc.detail), field=None)]
|
||||
errors=[ErrorDetail(code=error_code, message=str(exc.detail))]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -247,11 +239,10 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
|
||||
Logs the full exception and returns a generic error response to avoid
|
||||
leaking sensitive information in production.
|
||||
"""
|
||||
logger.exception(
|
||||
"Unhandled exception: %s - %s (path: %s)",
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
request.url.path,
|
||||
logger.error(
|
||||
f"Unhandled exception: {type(exc).__name__} - {exc!s} "
|
||||
f"(path: {request.url.path})",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# In production, don't expose internal error details
|
||||
@@ -263,7 +254,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
|
||||
message = f"{type(exc).__name__}: {exc!s}"
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message, field=None)]
|
||||
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
|
||||
476
backend/app/core/redis.py
Normal file
476
backend/app/core/redis.py
Normal file
@@ -0,0 +1,476 @@
|
||||
# app/core/redis.py
|
||||
"""
|
||||
Redis client configuration for caching and pub/sub.
|
||||
|
||||
This module provides async Redis connectivity with connection pooling
|
||||
for FastAPI endpoints and background tasks.
|
||||
|
||||
Features:
|
||||
- Connection pooling for efficient resource usage
|
||||
- Cache operations (get, set, delete, expire)
|
||||
- Pub/sub operations (publish, subscribe)
|
||||
- Health check for monitoring
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from redis.asyncio import ConnectionPool, Redis
|
||||
from redis.asyncio.client import PubSub
|
||||
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default TTL for cache entries (1 hour)
|
||||
DEFAULT_CACHE_TTL = 3600
|
||||
|
||||
# Connection pool settings
|
||||
POOL_MAX_CONNECTIONS = 50
|
||||
POOL_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
class RedisClient:
|
||||
"""
|
||||
Async Redis client with connection pooling.
|
||||
|
||||
Provides high-level operations for caching and pub/sub
|
||||
with proper error handling and connection management.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str | None = None) -> None:
|
||||
"""
|
||||
Initialize Redis client.
|
||||
|
||||
Args:
|
||||
url: Redis connection URL. Defaults to settings.REDIS_URL.
|
||||
"""
|
||||
self._url = url or settings.REDIS_URL
|
||||
self._pool: ConnectionPool | None = None
|
||||
self._client: Redis | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _ensure_pool(self) -> ConnectionPool:
|
||||
"""Ensure connection pool is initialized (thread-safe)."""
|
||||
if self._pool is None:
|
||||
async with self._lock:
|
||||
# Double-check after acquiring lock
|
||||
if self._pool is None:
|
||||
self._pool = ConnectionPool.from_url(
|
||||
self._url,
|
||||
max_connections=POOL_MAX_CONNECTIONS,
|
||||
socket_timeout=POOL_TIMEOUT,
|
||||
socket_connect_timeout=POOL_TIMEOUT,
|
||||
decode_responses=True,
|
||||
health_check_interval=30,
|
||||
)
|
||||
logger.info("Redis connection pool initialized")
|
||||
return self._pool
|
||||
|
||||
async def _get_client(self) -> Redis:
|
||||
"""Get Redis client instance from pool."""
|
||||
pool = await self._ensure_pool()
|
||||
if self._client is None:
|
||||
self._client = Redis(connection_pool=pool)
|
||||
return self._client
|
||||
|
||||
# =========================================================================
|
||||
# Cache Operations
|
||||
# =========================================================================
|
||||
|
||||
async def cache_get(self, key: str) -> str | None:
|
||||
"""
|
||||
Get a value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
value = await client.get(key)
|
||||
if value is not None:
|
||||
logger.debug(f"Cache hit for key: {key}")
|
||||
else:
|
||||
logger.debug(f"Cache miss for key: {key}")
|
||||
return value
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_get failed for key '{key}': {e}")
|
||||
return None
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_get for key '{key}': {e}")
|
||||
return None
|
||||
|
||||
async def cache_get_json(self, key: str) -> Any | None:
|
||||
"""
|
||||
Get a JSON-serialized value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
Deserialized value or None if not found.
|
||||
"""
|
||||
value = await self.cache_get(key)
|
||||
if value is not None:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to decode JSON for key '{key}': {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
async def cache_set(
|
||||
self,
|
||||
key: str,
|
||||
value: str,
|
||||
ttl: int | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Set a value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
value: Value to cache.
|
||||
ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
ttl = ttl if ttl is not None else DEFAULT_CACHE_TTL
|
||||
await client.set(key, value, ex=ttl)
|
||||
logger.debug(f"Cache set for key: {key} (TTL: {ttl}s)")
|
||||
return True
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_set failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_set for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_set_json(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl: int | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Set a JSON-serialized value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
value: Value to serialize and cache.
|
||||
ttl: Time-to-live in seconds.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
serialized = json.dumps(value)
|
||||
return await self.cache_set(key, serialized, ttl)
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.error(f"Failed to serialize value for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a key from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key to delete.
|
||||
|
||||
Returns:
|
||||
True if key was deleted, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.delete(key)
|
||||
logger.debug(f"Cache delete for key: {key} (deleted: {result > 0})")
|
||||
return result > 0
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_delete failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_delete for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_delete_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
Delete all keys matching a pattern.
|
||||
|
||||
Args:
|
||||
pattern: Glob-style pattern (e.g., "user:*").
|
||||
|
||||
Returns:
|
||||
Number of keys deleted.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
deleted = 0
|
||||
async for key in client.scan_iter(pattern):
|
||||
await client.delete(key)
|
||||
deleted += 1
|
||||
logger.debug(f"Cache delete pattern '{pattern}': {deleted} keys deleted")
|
||||
return deleted
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_delete_pattern failed for '{pattern}': {e}")
|
||||
return 0
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_delete_pattern for '{pattern}': {e}")
|
||||
return 0
|
||||
|
||||
async def cache_expire(self, key: str, ttl: int) -> bool:
|
||||
"""
|
||||
Set or update TTL for a key.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
ttl: New TTL in seconds.
|
||||
|
||||
Returns:
|
||||
True if TTL was set, False if key doesn't exist.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.expire(key, ttl)
|
||||
logger.debug(f"Cache expire for key: {key} (TTL: {ttl}s, success: {result})")
|
||||
return result
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_expire failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_expire for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if a key exists in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
True if key exists, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.exists(key)
|
||||
return result > 0
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_exists failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_exists for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_ttl(self, key: str) -> int:
|
||||
"""
|
||||
Get remaining TTL for a key.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
TTL in seconds, -1 if no TTL, -2 if key doesn't exist.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
return await client.ttl(key)
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_ttl failed for key '{key}': {e}")
|
||||
return -2
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_ttl for key '{key}': {e}")
|
||||
return -2
|
||||
|
||||
# =========================================================================
|
||||
# Pub/Sub Operations
|
||||
# =========================================================================
|
||||
|
||||
async def publish(self, channel: str, message: str | dict) -> int:
|
||||
"""
|
||||
Publish a message to a channel.
|
||||
|
||||
Args:
|
||||
channel: Channel name.
|
||||
message: Message to publish (string or dict for JSON serialization).
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
if isinstance(message, dict):
|
||||
message = json.dumps(message)
|
||||
result = await client.publish(channel, message)
|
||||
logger.debug(f"Published to channel '{channel}': {result} subscribers")
|
||||
return result
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis publish failed for channel '{channel}': {e}")
|
||||
return 0
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in publish for channel '{channel}': {e}")
|
||||
return 0
|
||||
|
||||
@asynccontextmanager
|
||||
async def subscribe(
|
||||
self, *channels: str
|
||||
) -> AsyncGenerator[PubSub, None]:
|
||||
"""
|
||||
Subscribe to one or more channels.
|
||||
|
||||
Usage:
|
||||
async with redis_client.subscribe("channel1", "channel2") as pubsub:
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
print(message["data"])
|
||||
|
||||
Args:
|
||||
channels: Channel names to subscribe to.
|
||||
|
||||
Yields:
|
||||
PubSub instance for receiving messages.
|
||||
"""
|
||||
client = await self._get_client()
|
||||
pubsub = client.pubsub()
|
||||
try:
|
||||
await pubsub.subscribe(*channels)
|
||||
logger.debug(f"Subscribed to channels: {channels}")
|
||||
yield pubsub
|
||||
finally:
|
||||
await pubsub.unsubscribe(*channels)
|
||||
await pubsub.close()
|
||||
logger.debug(f"Unsubscribed from channels: {channels}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def psubscribe(
|
||||
self, *patterns: str
|
||||
) -> AsyncGenerator[PubSub, None]:
|
||||
"""
|
||||
Subscribe to channels matching patterns.
|
||||
|
||||
Usage:
|
||||
async with redis_client.psubscribe("user:*") as pubsub:
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "pmessage":
|
||||
print(message["pattern"], message["channel"], message["data"])
|
||||
|
||||
Args:
|
||||
patterns: Glob-style patterns to subscribe to.
|
||||
|
||||
Yields:
|
||||
PubSub instance for receiving messages.
|
||||
"""
|
||||
client = await self._get_client()
|
||||
pubsub = client.pubsub()
|
||||
try:
|
||||
await pubsub.psubscribe(*patterns)
|
||||
logger.debug(f"Pattern subscribed: {patterns}")
|
||||
yield pubsub
|
||||
finally:
|
||||
await pubsub.punsubscribe(*patterns)
|
||||
await pubsub.close()
|
||||
logger.debug(f"Pattern unsubscribed: {patterns}")
|
||||
|
||||
# =========================================================================
|
||||
# Health & Connection Management
|
||||
# =========================================================================
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
Check if Redis connection is healthy.
|
||||
|
||||
Returns:
|
||||
True if connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.ping()
|
||||
return result is True
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis health check failed: {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis health check error: {e}")
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close Redis connections and cleanup resources.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
logger.debug("Redis client closed")
|
||||
|
||||
if self._pool:
|
||||
await self._pool.disconnect()
|
||||
self._pool = None
|
||||
logger.info("Redis connection pool closed")
|
||||
|
||||
async def get_pool_info(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get connection pool statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with pool information.
|
||||
"""
|
||||
if self._pool is None:
|
||||
return {"status": "not_initialized"}
|
||||
|
||||
return {
|
||||
"status": "active",
|
||||
"max_connections": POOL_MAX_CONNECTIONS,
|
||||
"url": self._url.split("@")[-1] if "@" in self._url else self._url,
|
||||
}
|
||||
|
||||
|
||||
# Global Redis client instance
|
||||
redis_client = RedisClient()
|
||||
|
||||
|
||||
# FastAPI dependency for Redis client
|
||||
async def get_redis() -> AsyncGenerator[RedisClient, None]:
|
||||
"""
|
||||
FastAPI dependency that provides the Redis client.
|
||||
|
||||
Usage:
|
||||
@router.get("/cached-data")
|
||||
async def get_data(redis: RedisClient = Depends(get_redis)):
|
||||
cached = await redis.cache_get("my-key")
|
||||
...
|
||||
"""
|
||||
yield redis_client
|
||||
|
||||
|
||||
# Health check function for use in /health endpoint
|
||||
async def check_redis_health() -> bool:
|
||||
"""
|
||||
Check if Redis connection is healthy.
|
||||
|
||||
Returns:
|
||||
True if connection is successful, False otherwise.
|
||||
"""
|
||||
return await redis_client.health_check()
|
||||
|
||||
|
||||
# Cleanup function for application shutdown
|
||||
async def close_redis() -> None:
|
||||
"""
|
||||
Close Redis connections.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
await redis_client.close()
|
||||
@@ -1,26 +0,0 @@
|
||||
"""
|
||||
Custom exceptions for the repository layer.
|
||||
|
||||
These exceptions allow services and routes to handle database-level errors
|
||||
with proper semantics, without leaking SQLAlchemy internals.
|
||||
"""
|
||||
|
||||
|
||||
class RepositoryError(Exception):
|
||||
"""Base for all repository-layer errors."""
|
||||
|
||||
|
||||
class DuplicateEntryError(RepositoryError):
|
||||
"""Raised on unique constraint violations. Maps to HTTP 409 Conflict."""
|
||||
|
||||
|
||||
class IntegrityConstraintError(RepositoryError):
|
||||
"""Raised on FK or check constraint violations."""
|
||||
|
||||
|
||||
class RecordNotFoundError(RepositoryError):
|
||||
"""Raised when an expected record doesn't exist."""
|
||||
|
||||
|
||||
class InvalidInputError(RepositoryError):
|
||||
"""Raised on bad pagination params, invalid UUIDs, or other invalid inputs."""
|
||||
14
backend/app/crud/__init__.py
Normal file
14
backend/app/crud/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# app/crud/__init__.py
|
||||
from .oauth import oauth_account, oauth_client, oauth_state
|
||||
from .organization import organization
|
||||
from .session import session as session_crud
|
||||
from .user import user
|
||||
|
||||
__all__ = [
|
||||
"oauth_account",
|
||||
"oauth_client",
|
||||
"oauth_state",
|
||||
"organization",
|
||||
"session_crud",
|
||||
"user",
|
||||
]
|
||||
177
backend/app/repositories/base.py → backend/app/crud/base.py
Normal file → Executable file
177
backend/app/repositories/base.py → backend/app/crud/base.py
Normal file → Executable file
@@ -1,6 +1,6 @@
|
||||
# app/repositories/base.py
|
||||
# app/crud/base_async.py
|
||||
"""
|
||||
Base repository class for async database operations using SQLAlchemy 2.0 async patterns.
|
||||
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
|
||||
|
||||
Provides reusable create, read, update, and delete operations for all models.
|
||||
"""
|
||||
@@ -18,11 +18,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Load
|
||||
|
||||
from app.core.database import Base
|
||||
from app.core.repository_exceptions import (
|
||||
DuplicateEntryError,
|
||||
IntegrityConstraintError,
|
||||
InvalidInputError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,16 +26,16 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseRepository[
|
||||
class CRUDBase[
|
||||
ModelType: Base,
|
||||
CreateSchemaType: BaseModel,
|
||||
UpdateSchemaType: BaseModel,
|
||||
]:
|
||||
"""Async repository operations for a model."""
|
||||
"""Async CRUD operations for a model."""
|
||||
|
||||
def __init__(self, model: type[ModelType]):
|
||||
"""
|
||||
Repository object with default async methods to Create, Read, Update, Delete.
|
||||
CRUD object with default async methods to Create, Read, Update, Delete.
|
||||
|
||||
Parameters:
|
||||
model: A SQLAlchemy model class
|
||||
@@ -61,19 +56,26 @@ class BaseRepository[
|
||||
|
||||
Returns:
|
||||
Model instance or None if not found
|
||||
|
||||
Example:
|
||||
# Eager load user relationship
|
||||
from sqlalchemy.orm import joinedload
|
||||
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
|
||||
"""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning("Invalid UUID format: %s - %s", id, e)
|
||||
logger.warning(f"Invalid UUID format: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
query = select(self.model).where(self.model.id == uuid_obj)
|
||||
|
||||
# Apply eager loading options if provided
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
@@ -81,9 +83,7 @@ class BaseRepository[
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error retrieving %s with id %s: %s", self.model.__name__, id, e
|
||||
)
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_multi(
|
||||
@@ -96,17 +96,28 @@ class BaseRepository[
|
||||
) -> list[ModelType]:
|
||||
"""
|
||||
Get multiple records with pagination validation and optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
options: Optional list of SQLAlchemy load options for eager loading
|
||||
|
||||
Returns:
|
||||
List of model instances
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise InvalidInputError("skip must be non-negative")
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise InvalidInputError("limit must be non-negative")
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise InvalidInputError("Maximum limit is 1000")
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)
|
||||
query = select(self.model).offset(skip).limit(limit)
|
||||
|
||||
# Apply eager loading options if provided
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
@@ -115,7 +126,7 @@ class BaseRepository[
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error retrieving multiple %s records: %s", self.model.__name__, e
|
||||
f"Error retrieving multiple {self.model.__name__} records: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -125,8 +136,9 @@ class BaseRepository[
|
||||
"""Create a new record with error handling.
|
||||
|
||||
NOTE: This method is defensive code that's never called in practice.
|
||||
All repository subclasses override this method with their own implementations.
|
||||
Marked as pragma: no cover to avoid false coverage gaps.
|
||||
All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method
|
||||
with their own implementations, so the base implementation and its exception handlers
|
||||
are never executed. Marked as pragma: no cover to avoid false coverage gaps.
|
||||
"""
|
||||
try: # pragma: no cover
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
@@ -140,24 +152,22 @@ class BaseRepository[
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(
|
||||
"Duplicate entry attempted for %s: %s",
|
||||
self.model.__name__,
|
||||
error_msg,
|
||||
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
||||
)
|
||||
raise DuplicateEntryError(
|
||||
raise ValueError(
|
||||
f"A {self.model.__name__} with this data already exists"
|
||||
)
|
||||
logger.error(
|
||||
"Integrity error creating %s: %s", self.model.__name__, error_msg
|
||||
)
|
||||
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Database error creating %s: %s", self.model.__name__, e)
|
||||
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
|
||||
logger.error(f"Database error creating {self.model.__name__}: {e!s}")
|
||||
raise ValueError(f"Database operation failed: {e!s}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.exception("Unexpected error creating %s: %s", self.model.__name__, e)
|
||||
logger.error(
|
||||
f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
@@ -188,35 +198,34 @@ class BaseRepository[
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(
|
||||
"Duplicate entry attempted for %s: %s",
|
||||
self.model.__name__,
|
||||
error_msg,
|
||||
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
||||
)
|
||||
raise DuplicateEntryError(
|
||||
raise ValueError(
|
||||
f"A {self.model.__name__} with this data already exists"
|
||||
)
|
||||
logger.error(
|
||||
"Integrity error updating %s: %s", self.model.__name__, error_msg
|
||||
)
|
||||
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
await db.rollback()
|
||||
logger.error("Database error updating %s: %s", self.model.__name__, e)
|
||||
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
|
||||
logger.error(f"Database error updating {self.model.__name__}: {e!s}")
|
||||
raise ValueError(f"Database operation failed: {e!s}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Unexpected error updating %s: %s", self.model.__name__, e)
|
||||
logger.error(
|
||||
f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""Delete a record with error handling and null check."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning("Invalid UUID format for deletion: %s - %s", id, e)
|
||||
logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -227,7 +236,7 @@ class BaseRepository[
|
||||
|
||||
if obj is None:
|
||||
logger.warning(
|
||||
"%s with id %s not found for deletion", self.model.__name__, id
|
||||
f"{self.model.__name__} with id {id} not found for deletion"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -237,16 +246,15 @@ class BaseRepository[
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(
|
||||
"Integrity error deleting %s: %s", self.model.__name__, error_msg
|
||||
)
|
||||
raise IntegrityConstraintError(
|
||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(
|
||||
f"Cannot delete {self.model.__name__}: referenced by other records"
|
||||
)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception(
|
||||
"Error deleting %s with id %s: %s", self.model.__name__, id, e
|
||||
logger.error(
|
||||
f"Error deleting {self.model.__name__} with id {id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -264,40 +272,57 @@ class BaseRepository[
|
||||
Get multiple records with total count, filtering, and sorting.
|
||||
|
||||
NOTE: This method is defensive code that's never called in practice.
|
||||
All repository subclasses override this method with their own implementations.
|
||||
All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method
|
||||
with their own implementations that include additional parameters like search.
|
||||
Marked as pragma: no cover to avoid false coverage gaps.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
sort_by: Field name to sort by (must be a valid model attribute)
|
||||
sort_order: Sort order ("asc" or "desc")
|
||||
filters: Dictionary of filters (field_name: value)
|
||||
|
||||
Returns:
|
||||
Tuple of (items, total_count)
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise InvalidInputError("skip must be non-negative")
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise InvalidInputError("limit must be non-negative")
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise InvalidInputError("Maximum limit is 1000")
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
# Build base query
|
||||
query = select(self.model)
|
||||
|
||||
# Exclude soft-deleted records by default
|
||||
if hasattr(self.model, "deleted_at"):
|
||||
query = query.where(self.model.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(self.model, field) and value is not None:
|
||||
query = query.where(getattr(self.model, field) == value)
|
||||
|
||||
# Get total count (before pagination)
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
if sort_by and hasattr(self.model, sort_by):
|
||||
sort_column = getattr(self.model, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
else:
|
||||
query = query.order_by(self.model.id)
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
items_result = await db.execute(query)
|
||||
items = list(items_result.scalars().all())
|
||||
@@ -305,7 +330,7 @@ class BaseRepository[
|
||||
return items, total
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
"Error retrieving paginated %s records: %s", self.model.__name__, e
|
||||
f"Error retrieving paginated {self.model.__name__} records: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -315,7 +340,7 @@ class BaseRepository[
|
||||
result = await db.execute(select(func.count(self.model.id)))
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error("Error counting %s records: %s", self.model.__name__, e)
|
||||
logger.error(f"Error counting {self.model.__name__} records: {e!s}")
|
||||
raise
|
||||
|
||||
async def exists(self, db: AsyncSession, id: str) -> bool:
|
||||
@@ -331,13 +356,14 @@ class BaseRepository[
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning("Invalid UUID format for soft deletion: %s - %s", id, e)
|
||||
logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -348,16 +374,18 @@ class BaseRepository[
|
||||
|
||||
if obj is None:
|
||||
logger.warning(
|
||||
"%s with id %s not found for soft deletion", self.model.__name__, id
|
||||
f"{self.model.__name__} with id {id} not found for soft deletion"
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if model supports soft deletes
|
||||
if not hasattr(self.model, "deleted_at"):
|
||||
logger.error("%s does not support soft deletes", self.model.__name__)
|
||||
raise InvalidInputError(
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(
|
||||
f"{self.model.__name__} does not have a deleted_at column"
|
||||
)
|
||||
|
||||
# Set deleted_at timestamp
|
||||
obj.deleted_at = datetime.now(UTC)
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
@@ -365,8 +393,9 @@ class BaseRepository[
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception(
|
||||
"Error soft deleting %s with id %s: %s", self.model.__name__, id, e
|
||||
logger.error(
|
||||
f"Error soft deleting {self.model.__name__} with id {id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -376,16 +405,18 @@ class BaseRepository[
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
# Validate UUID format
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning("Invalid UUID format for restoration: %s - %s", id, e)
|
||||
logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Find the soft-deleted record
|
||||
if hasattr(self.model, "deleted_at"):
|
||||
result = await db.execute(
|
||||
select(self.model).where(
|
||||
@@ -394,19 +425,18 @@ class BaseRepository[
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
else:
|
||||
logger.error("%s does not support soft deletes", self.model.__name__)
|
||||
raise InvalidInputError(
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(
|
||||
f"{self.model.__name__} does not have a deleted_at column"
|
||||
)
|
||||
|
||||
if obj is None:
|
||||
logger.warning(
|
||||
"Soft-deleted %s with id %s not found for restoration",
|
||||
self.model.__name__,
|
||||
id,
|
||||
f"Soft-deleted {self.model.__name__} with id {id} not found for restoration"
|
||||
)
|
||||
return None
|
||||
|
||||
# Clear deleted_at timestamp
|
||||
obj.deleted_at = None
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
@@ -414,7 +444,8 @@ class BaseRepository[
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception(
|
||||
"Error restoring %s with id %s: %s", self.model.__name__, id, e
|
||||
logger.error(
|
||||
f"Error restoring {self.model.__name__} with id {id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
718
backend/app/crud/oauth.py
Executable file
718
backend/app/crud/oauth.py
Executable file
@@ -0,0 +1,718 @@
|
||||
"""
|
||||
Async CRUD operations for OAuth models using SQLAlchemy 2.0 patterns.
|
||||
|
||||
Provides operations for:
|
||||
- OAuthAccount: Managing linked OAuth provider accounts
|
||||
- OAuthState: CSRF protection state during OAuth flows
|
||||
- OAuthClient: Registered OAuth clients (provider mode skeleton)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.oauth_account import OAuthAccount
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.oauth_state import OAuthState
|
||||
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Account CRUD
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for CRUD operations that don't need update schemas."""
|
||||
|
||||
|
||||
class CRUDOAuthAccount(CRUDBase[OAuthAccount, OAuthAccountCreate, EmptySchema]):
|
||||
"""CRUD operations for OAuth account links."""
|
||||
|
||||
async def get_by_provider_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
provider_user_id: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""
|
||||
Get OAuth account by provider and provider user ID.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
provider: OAuth provider name (google, github)
|
||||
provider_user_id: User ID from the OAuth provider
|
||||
|
||||
Returns:
|
||||
OAuthAccount if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_user_id == provider_user_id,
|
||||
)
|
||||
)
|
||||
.options(joinedload(OAuthAccount.user))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover # pragma: no cover
|
||||
logger.error(
|
||||
f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_provider_email(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
email: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""
|
||||
Get OAuth account by provider and email.
|
||||
|
||||
Used for auto-linking existing accounts by email.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
provider: OAuth provider name
|
||||
email: Email address from the OAuth provider
|
||||
|
||||
Returns:
|
||||
OAuthAccount if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_email == email,
|
||||
)
|
||||
)
|
||||
.options(joinedload(OAuthAccount.user))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover # pragma: no cover
|
||||
logger.error(
|
||||
f"Error getting OAuth account for {provider} email {email}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_accounts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
) -> list[OAuthAccount]:
|
||||
"""
|
||||
Get all OAuth accounts linked to a user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of OAuthAccount objects
|
||||
"""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(OAuthAccount.user_id == user_uuid)
|
||||
.order_by(OAuthAccount.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_account_by_provider(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
provider: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""
|
||||
Get a specific OAuth account for a user and provider.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
provider: OAuth provider name
|
||||
|
||||
Returns:
|
||||
OAuthAccount if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
and_(
|
||||
OAuthAccount.user_id == user_uuid,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def create_account(
|
||||
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
|
||||
) -> OAuthAccount:
|
||||
"""
|
||||
Create a new OAuth account link.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
obj_in: OAuth account creation data
|
||||
|
||||
Returns:
|
||||
Created OAuthAccount
|
||||
|
||||
Raises:
|
||||
ValueError: If account already exists or creation fails
|
||||
"""
|
||||
try:
|
||||
db_obj = OAuthAccount(
|
||||
user_id=obj_in.user_id,
|
||||
provider=obj_in.provider,
|
||||
provider_user_id=obj_in.provider_user_id,
|
||||
provider_email=obj_in.provider_email,
|
||||
access_token_encrypted=obj_in.access_token_encrypted,
|
||||
refresh_token_encrypted=obj_in.refresh_token_encrypted,
|
||||
token_expires_at=obj_in.token_expires_at,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
|
||||
)
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "uq_oauth_provider_user" in error_msg.lower():
|
||||
logger.warning(
|
||||
f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"This {obj_in.provider} account is already linked to another user"
|
||||
)
|
||||
logger.error(f"Integrity error creating OAuth account: {error_msg}")
|
||||
raise ValueError(f"Failed to create OAuth account: {error_msg}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def delete_account(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
provider: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete an OAuth account link.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
provider: OAuth provider name
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
delete(OAuthAccount).where(
|
||||
and_(
|
||||
OAuthAccount.user_id == user_uuid,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"OAuth account deleted: {provider} unlinked from user {user_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"OAuth account not found for deletion: {provider} for user {user_id}"
|
||||
)
|
||||
|
||||
return deleted
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def update_tokens(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account: OAuthAccount,
|
||||
access_token_encrypted: str | None = None,
|
||||
refresh_token_encrypted: str | None = None,
|
||||
token_expires_at: datetime | None = None,
|
||||
) -> OAuthAccount:
|
||||
"""
|
||||
Update OAuth tokens for an account.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
account: OAuthAccount to update
|
||||
access_token_encrypted: New encrypted access token
|
||||
refresh_token_encrypted: New encrypted refresh token
|
||||
token_expires_at: New token expiration time
|
||||
|
||||
Returns:
|
||||
Updated OAuthAccount
|
||||
"""
|
||||
try:
|
||||
if access_token_encrypted is not None:
|
||||
account.access_token_encrypted = access_token_encrypted
|
||||
if refresh_token_encrypted is not None:
|
||||
account.refresh_token_encrypted = refresh_token_encrypted
|
||||
if token_expires_at is not None:
|
||||
account.token_expires_at = token_expires_at
|
||||
|
||||
db.add(account)
|
||||
await db.commit()
|
||||
await db.refresh(account)
|
||||
|
||||
return account
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating OAuth tokens: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth State CRUD
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CRUDOAuthState(CRUDBase[OAuthState, OAuthStateCreate, EmptySchema]):
|
||||
"""CRUD operations for OAuth state (CSRF protection)."""
|
||||
|
||||
async def create_state(
|
||||
self, db: AsyncSession, *, obj_in: OAuthStateCreate
|
||||
) -> OAuthState:
|
||||
"""
|
||||
Create a new OAuth state for CSRF protection.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
obj_in: OAuth state creation data
|
||||
|
||||
Returns:
|
||||
Created OAuthState
|
||||
"""
|
||||
try:
|
||||
db_obj = OAuthState(
|
||||
state=obj_in.state,
|
||||
code_verifier=obj_in.code_verifier,
|
||||
nonce=obj_in.nonce,
|
||||
provider=obj_in.provider,
|
||||
redirect_uri=obj_in.redirect_uri,
|
||||
user_id=obj_in.user_id,
|
||||
expires_at=obj_in.expires_at,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.debug(f"OAuth state created for {obj_in.provider}")
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
# State collision (extremely rare with cryptographic random)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"OAuth state collision: {error_msg}")
|
||||
raise ValueError("Failed to create OAuth state, please retry")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_and_consume_state(
|
||||
self, db: AsyncSession, *, state: str
|
||||
) -> OAuthState | None:
|
||||
"""
|
||||
Get and delete OAuth state (consume it).
|
||||
|
||||
This ensures each state can only be used once (replay protection).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
state: State string to look up
|
||||
|
||||
Returns:
|
||||
OAuthState if found and valid, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Get the state
|
||||
result = await db.execute(
|
||||
select(OAuthState).where(OAuthState.state == state)
|
||||
)
|
||||
db_obj = result.scalar_one_or_none()
|
||||
|
||||
if db_obj is None:
|
||||
logger.warning(f"OAuth state not found: {state[:8]}...")
|
||||
return None
|
||||
|
||||
# Check expiration
|
||||
# Handle both timezone-aware and timezone-naive datetimes
|
||||
now = datetime.now(UTC)
|
||||
expires_at = db_obj.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
# SQLite returns naive datetimes, assume UTC
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < now:
|
||||
logger.warning(f"OAuth state expired: {state[:8]}...")
|
||||
await db.delete(db_obj)
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
# Delete it (consume)
|
||||
await db.delete(db_obj)
|
||||
await db.commit()
|
||||
|
||||
logger.debug(f"OAuth state consumed: {state[:8]}...")
|
||||
return db_obj
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error consuming OAuth state: {e!s}")
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||
"""
|
||||
Clean up expired OAuth states.
|
||||
|
||||
Should be called periodically to remove stale states.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Number of states deleted
|
||||
"""
|
||||
try:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired OAuth states")
|
||||
|
||||
return count
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error cleaning up expired OAuth states: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Client CRUD (Provider Mode - Skeleton)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
|
||||
"""
|
||||
CRUD operations for OAuth clients (provider mode).
|
||||
|
||||
This is a skeleton implementation for MCP client registration.
|
||||
Full implementation can be expanded when needed.
|
||||
"""
|
||||
|
||||
async def get_by_client_id(
|
||||
self, db: AsyncSession, *, client_id: str
|
||||
) -> OAuthClient | None:
|
||||
"""
|
||||
Get OAuth client by client_id.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_id: OAuth client ID
|
||||
|
||||
Returns:
|
||||
OAuthClient if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error getting OAuth client {client_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create_client(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
obj_in: OAuthClientCreate,
|
||||
owner_user_id: UUID | None = None,
|
||||
) -> tuple[OAuthClient, str | None]:
|
||||
"""
|
||||
Create a new OAuth client.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
obj_in: OAuth client creation data
|
||||
owner_user_id: Optional owner user ID
|
||||
|
||||
Returns:
|
||||
Tuple of (created OAuthClient, client_secret or None for public clients)
|
||||
"""
|
||||
try:
|
||||
# Generate client_id
|
||||
client_id = secrets.token_urlsafe(32)
|
||||
|
||||
# Generate client_secret for confidential clients
|
||||
client_secret = None
|
||||
client_secret_hash = None
|
||||
if obj_in.client_type == "confidential":
|
||||
client_secret = secrets.token_urlsafe(48)
|
||||
# SECURITY: Use bcrypt for secret storage (not SHA-256)
|
||||
# bcrypt is computationally expensive, making brute-force attacks infeasible
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
client_secret_hash = get_password_hash(client_secret)
|
||||
|
||||
db_obj = OAuthClient(
|
||||
client_id=client_id,
|
||||
client_secret_hash=client_secret_hash,
|
||||
client_name=obj_in.client_name,
|
||||
client_description=obj_in.client_description,
|
||||
client_type=obj_in.client_type,
|
||||
redirect_uris=obj_in.redirect_uris,
|
||||
allowed_scopes=obj_in.allowed_scopes,
|
||||
owner_user_id=owner_user_id,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
|
||||
)
|
||||
return db_obj, client_secret
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Error creating OAuth client: {error_msg}")
|
||||
raise ValueError(f"Failed to create OAuth client: {error_msg}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def deactivate_client(
|
||||
self, db: AsyncSession, *, client_id: str
|
||||
) -> OAuthClient | None:
|
||||
"""
|
||||
Deactivate an OAuth client.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_id: OAuth client ID
|
||||
|
||||
Returns:
|
||||
Deactivated OAuthClient if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
client = await self.get_by_client_id(db, client_id=client_id)
|
||||
if client is None:
|
||||
return None
|
||||
|
||||
client.is_active = False
|
||||
db.add(client)
|
||||
await db.commit()
|
||||
await db.refresh(client)
|
||||
|
||||
logger.info(f"OAuth client deactivated: {client.client_name}")
|
||||
return client
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def validate_redirect_uri(
|
||||
self, db: AsyncSession, *, client_id: str, redirect_uri: str
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that a redirect URI is allowed for a client.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_id: OAuth client ID
|
||||
redirect_uri: Redirect URI to validate
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
client = await self.get_by_client_id(db, client_id=client_id)
|
||||
if client is None:
|
||||
return False
|
||||
|
||||
return redirect_uri in (client.redirect_uris or [])
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error validating redirect URI: {e!s}")
|
||||
return False
|
||||
|
||||
async def verify_client_secret(
|
||||
self, db: AsyncSession, *, client_id: str, client_secret: str
|
||||
) -> bool:
|
||||
"""
|
||||
Verify client credentials.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_id: OAuth client ID
|
||||
client_secret: Client secret to verify
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
client = result.scalar_one_or_none()
|
||||
|
||||
if client is None or client.client_secret_hash is None:
|
||||
return False
|
||||
|
||||
# SECURITY: Verify secret using bcrypt (not SHA-256)
|
||||
# This supports both old SHA-256 hashes (for migration) and new bcrypt hashes
|
||||
from app.core.auth import verify_password
|
||||
|
||||
stored_hash: str = str(client.client_secret_hash)
|
||||
|
||||
# Check if it's a bcrypt hash (starts with $2b$) or legacy SHA-256
|
||||
if stored_hash.startswith("$2"):
|
||||
# New bcrypt format
|
||||
return verify_password(client_secret, stored_hash)
|
||||
else:
|
||||
# Legacy SHA-256 format - still support for migration
|
||||
import hashlib
|
||||
|
||||
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
return secrets.compare_digest(stored_hash, secret_hash)
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error verifying client secret: {e!s}")
|
||||
return False
|
||||
|
||||
async def get_all_clients(
|
||||
self, db: AsyncSession, *, include_inactive: bool = False
|
||||
) -> list[OAuthClient]:
|
||||
"""
|
||||
Get all OAuth clients.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
include_inactive: Whether to include inactive clients
|
||||
|
||||
Returns:
|
||||
List of OAuthClient objects
|
||||
"""
|
||||
try:
|
||||
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
|
||||
if not include_inactive:
|
||||
query = query.where(OAuthClient.is_active == True) # noqa: E712
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error getting all OAuth clients: {e!s}")
|
||||
raise
|
||||
|
||||
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
|
||||
"""
|
||||
Delete an OAuth client permanently.
|
||||
|
||||
Note: This will cascade delete related records (tokens, consents, etc.)
|
||||
due to foreign key constraints.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_id: OAuth client ID
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
delete(OAuthClient).where(OAuthClient.client_id == client_id)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info(f"OAuth client deleted: {client_id}")
|
||||
else:
|
||||
logger.warning(f"OAuth client not found for deletion: {client_id}")
|
||||
|
||||
return deleted
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Singleton instances
|
||||
# ============================================================================
|
||||
|
||||
oauth_account = CRUDOAuthAccount(OAuthAccount)
|
||||
oauth_state = CRUDOAuthState(OAuthState)
|
||||
oauth_client = CRUDOAuthClient(OAuthClient)
|
||||
128
backend/app/repositories/organization.py → backend/app/crud/organization.py
Normal file → Executable file
128
backend/app/repositories/organization.py → backend/app/crud/organization.py
Normal file → Executable file
@@ -1,5 +1,5 @@
|
||||
# app/repositories/organization.py
|
||||
"""Repository for Organization model async database operations using SQLAlchemy 2.0 patterns."""
|
||||
# app/crud/organization_async.py
|
||||
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
@@ -9,11 +9,10 @@ from sqlalchemy import and_, case, func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.organizations import (
|
||||
OrganizationCreate,
|
||||
OrganizationUpdate,
|
||||
@@ -22,10 +21,8 @@ from app.schemas.organizations import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrganizationRepository(
|
||||
BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]
|
||||
):
|
||||
"""Repository for Organization model."""
|
||||
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
|
||||
"""Async CRUD operations for Organization model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
|
||||
"""Get organization by slug."""
|
||||
@@ -35,7 +32,7 @@ class OrganizationRepository(
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Error getting organization by slug %s: %s", slug, e)
|
||||
logger.error(f"Error getting organization by slug {slug}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(
|
||||
@@ -57,20 +54,18 @@ class OrganizationRepository(
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if (
|
||||
"slug" in error_msg.lower()
|
||||
or "unique" in error_msg.lower()
|
||||
or "duplicate" in error_msg.lower()
|
||||
):
|
||||
logger.warning("Duplicate slug attempted: %s", obj_in.slug)
|
||||
raise DuplicateEntryError(
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise ValueError(
|
||||
f"Organization with slug '{obj_in.slug}' already exists"
|
||||
)
|
||||
logger.error("Integrity error creating organization: %s", error_msg)
|
||||
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||
logger.error(f"Integrity error creating organization: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Unexpected error creating organization: %s", e)
|
||||
logger.error(
|
||||
f"Unexpected error creating organization: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
@@ -84,10 +79,16 @@ class OrganizationRepository(
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[Organization], int]:
|
||||
"""Get multiple organizations with filtering, searching, and sorting."""
|
||||
"""
|
||||
Get multiple organizations with filtering, searching, and sorting.
|
||||
|
||||
Returns:
|
||||
Tuple of (organizations list, total count)
|
||||
"""
|
||||
try:
|
||||
query = select(Organization)
|
||||
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
@@ -99,23 +100,26 @@ class OrganizationRepository(
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
organizations = list(result.scalars().all())
|
||||
|
||||
return organizations, total
|
||||
except Exception as e:
|
||||
logger.error("Error getting organizations with filters: %s", e)
|
||||
logger.error(f"Error getting organizations with filters: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||
@@ -132,7 +136,7 @@ class OrganizationRepository(
|
||||
return result.scalar_one() or 0
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error getting member count for organization %s: %s", organization_id, e
|
||||
f"Error getting member count for organization {organization_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -145,8 +149,16 @@ class OrganizationRepository(
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY."""
|
||||
"""
|
||||
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
|
||||
This eliminates the N+1 query problem.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of dicts with org and member_count, total count)
|
||||
"""
|
||||
try:
|
||||
# Build base query with LEFT JOIN and GROUP BY
|
||||
# Use CASE statement to count only active members
|
||||
query = (
|
||||
select(
|
||||
Organization,
|
||||
@@ -169,10 +181,10 @@ class OrganizationRepository(
|
||||
.group_by(Organization.id)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
search_filter = None
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
@@ -181,15 +193,17 @@ class OrganizationRepository(
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count(Organization.id))
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(Organization.is_active == is_active)
|
||||
if search_filter is not None:
|
||||
if search:
|
||||
count_query = count_query.where(search_filter)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination and ordering
|
||||
query = (
|
||||
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||
)
|
||||
@@ -197,6 +211,7 @@ class OrganizationRepository(
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Convert to list of dicts
|
||||
orgs_with_counts = [
|
||||
{"organization": org, "member_count": member_count}
|
||||
for org, member_count in rows
|
||||
@@ -205,7 +220,9 @@ class OrganizationRepository(
|
||||
return orgs_with_counts, total
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting organizations with member counts: %s", e)
|
||||
logger.error(
|
||||
f"Error getting organizations with member counts: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def add_user(
|
||||
@@ -219,6 +236,7 @@ class OrganizationRepository(
|
||||
) -> UserOrganization:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
try:
|
||||
# Check if relationship already exists
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
@@ -230,6 +248,7 @@ class OrganizationRepository(
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# Reactivate if inactive, or raise error if already active
|
||||
if not existing.is_active:
|
||||
existing.is_active = True
|
||||
existing.role = role
|
||||
@@ -238,10 +257,9 @@ class OrganizationRepository(
|
||||
await db.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
raise DuplicateEntryError(
|
||||
"User is already a member of this organization"
|
||||
)
|
||||
raise ValueError("User is already a member of this organization")
|
||||
|
||||
# Create new relationship
|
||||
user_org = UserOrganization(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
@@ -255,11 +273,11 @@ class OrganizationRepository(
|
||||
return user_org
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
logger.error("Integrity error adding user to organization: %s", e)
|
||||
raise IntegrityConstraintError("Failed to add user to organization")
|
||||
logger.error(f"Integrity error adding user to organization: {e!s}")
|
||||
raise ValueError("Failed to add user to organization")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error adding user to organization: %s", e)
|
||||
logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def remove_user(
|
||||
@@ -285,7 +303,7 @@ class OrganizationRepository(
|
||||
return True
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error removing user from organization: %s", e)
|
||||
logger.error(f"Error removing user from organization: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_user_role(
|
||||
@@ -320,7 +338,7 @@ class OrganizationRepository(
|
||||
return user_org
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error updating user role: %s", e)
|
||||
logger.error(f"Error updating user role: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_organization_members(
|
||||
@@ -330,10 +348,16 @@ class OrganizationRepository(
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = True,
|
||||
is_active: bool = True,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get members of an organization with user details."""
|
||||
"""
|
||||
Get members of an organization with user details.
|
||||
|
||||
Returns:
|
||||
Tuple of (members list with user details, total count)
|
||||
"""
|
||||
try:
|
||||
# Build query with join
|
||||
query = (
|
||||
select(UserOrganization, User)
|
||||
.join(User, UserOrganization.user_id == User.id)
|
||||
@@ -343,6 +367,7 @@ class OrganizationRepository(
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(
|
||||
select(UserOrganization)
|
||||
.where(UserOrganization.organization_id == organization_id)
|
||||
@@ -356,6 +381,7 @@ class OrganizationRepository(
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply ordering and pagination
|
||||
query = (
|
||||
query.order_by(UserOrganization.created_at.desc())
|
||||
.offset(skip)
|
||||
@@ -380,11 +406,11 @@ class OrganizationRepository(
|
||||
|
||||
return members, total
|
||||
except Exception as e:
|
||||
logger.error("Error getting organization members: %s", e)
|
||||
logger.error(f"Error getting organization members: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_organizations(
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
||||
) -> list[Organization]:
|
||||
"""Get all organizations a user belongs to."""
|
||||
try:
|
||||
@@ -403,14 +429,21 @@ class OrganizationRepository(
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error("Error getting user organizations: %s", e)
|
||||
logger.error(f"Error getting user organizations: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_organizations_with_details(
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get user's organizations with role and member count in SINGLE QUERY."""
|
||||
"""
|
||||
Get user's organizations with role and member count in SINGLE QUERY.
|
||||
Eliminates N+1 problem by using subquery for member counts.
|
||||
|
||||
Returns:
|
||||
List of dicts with organization, role, and member_count
|
||||
"""
|
||||
try:
|
||||
# Subquery to get member counts for each organization
|
||||
member_count_subq = (
|
||||
select(
|
||||
UserOrganization.organization_id,
|
||||
@@ -421,6 +454,7 @@ class OrganizationRepository(
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Main query with JOIN to get org, role, and member count
|
||||
query = (
|
||||
select(
|
||||
Organization,
|
||||
@@ -452,7 +486,9 @@ class OrganizationRepository(
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting user organizations with details: %s", e)
|
||||
logger.error(
|
||||
f"Error getting user organizations with details: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_role_in_org(
|
||||
@@ -471,9 +507,9 @@ class OrganizationRepository(
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
return user_org.role if user_org else None # pyright: ignore[reportReturnType]
|
||||
return user_org.role if user_org else None
|
||||
except Exception as e:
|
||||
logger.error("Error getting user role in org: %s", e)
|
||||
logger.error(f"Error getting user role in org: {e!s}")
|
||||
raise
|
||||
|
||||
async def is_user_org_owner(
|
||||
@@ -495,5 +531,5 @@ class OrganizationRepository(
|
||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
organization_repo = OrganizationRepository(Organization)
|
||||
# Create a singleton instance for use across the application
|
||||
organization = CRUDOrganization(Organization)
|
||||
231
backend/app/repositories/session.py → backend/app/crud/session.py
Normal file → Executable file
231
backend/app/repositories/session.py → backend/app/crud/session.py
Normal file → Executable file
@@ -1,5 +1,6 @@
|
||||
# app/repositories/session.py
|
||||
"""Repository for UserSession model async database operations using SQLAlchemy 2.0 patterns."""
|
||||
"""
|
||||
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
@@ -10,32 +11,49 @@ from sqlalchemy import and_, delete, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.core.repository_exceptions import IntegrityConstraintError, InvalidInputError
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user_session import UserSession
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.sessions import SessionCreate, SessionUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""Repository for UserSession model."""
|
||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""Async CRUD operations for user sessions."""
|
||||
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||
"""Get session by refresh token JTI."""
|
||||
"""
|
||||
Get session by refresh token JTI.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
jti: Refresh token JWT ID
|
||||
|
||||
Returns:
|
||||
UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Error getting session by JTI %s: %s", jti, e)
|
||||
logger.error(f"Error getting session by JTI {jti}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_active_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> UserSession | None:
|
||||
"""Get active session by refresh token JTI."""
|
||||
"""
|
||||
Get active session by refresh token JTI.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
jti: Refresh token JWT ID
|
||||
|
||||
Returns:
|
||||
Active UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserSession).where(
|
||||
@@ -47,7 +65,7 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Error getting active session by JTI %s: %s", jti, e)
|
||||
logger.error(f"Error getting active session by JTI {jti}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_user_sessions(
|
||||
@@ -58,12 +76,25 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
active_only: bool = True,
|
||||
with_user: bool = False,
|
||||
) -> list[UserSession]:
|
||||
"""Get all sessions for a user with optional eager loading."""
|
||||
"""
|
||||
Get all sessions for a user with optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
active_only: If True, return only active sessions
|
||||
with_user: If True, eager load user relationship to prevent N+1
|
||||
|
||||
Returns:
|
||||
List of UserSession objects
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
||||
|
||||
# Add eager loading if requested to prevent N+1 queries
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
@@ -74,13 +105,25 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error("Error getting sessions for user %s: %s", user_id, e)
|
||||
logger.error(f"Error getting sessions for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create_session(
|
||||
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""Create a new user session."""
|
||||
"""
|
||||
Create a new user session.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
obj_in: SessionCreate schema with session data
|
||||
|
||||
Returns:
|
||||
Created UserSession
|
||||
|
||||
Raises:
|
||||
ValueError: If session creation fails
|
||||
"""
|
||||
try:
|
||||
db_obj = UserSession(
|
||||
user_id=obj_in.user_id,
|
||||
@@ -100,26 +143,33 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
"Session created for user %s from %s (IP: %s)",
|
||||
obj_in.user_id,
|
||||
obj_in.device_name,
|
||||
obj_in.ip_address,
|
||||
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
|
||||
f"(IP: {obj_in.ip_address})"
|
||||
)
|
||||
|
||||
return db_obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error creating session: %s", e)
|
||||
raise IntegrityConstraintError(f"Failed to create session: {e!s}")
|
||||
logger.error(f"Error creating session: {e!s}", exc_info=True)
|
||||
raise ValueError(f"Failed to create session: {e!s}")
|
||||
|
||||
async def deactivate(
|
||||
self, db: AsyncSession, *, session_id: str
|
||||
) -> UserSession | None:
|
||||
"""Deactivate a session (logout from device)."""
|
||||
"""
|
||||
Deactivate a session (logout from device).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session_id: Session UUID
|
||||
|
||||
Returns:
|
||||
Deactivated UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
session = await self.get(db, id=session_id)
|
||||
if not session:
|
||||
logger.warning("Session %s not found for deactivation", session_id)
|
||||
logger.warning(f"Session {session_id} not found for deactivation")
|
||||
return None
|
||||
|
||||
session.is_active = False
|
||||
@@ -128,23 +178,31 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
await db.refresh(session)
|
||||
|
||||
logger.info(
|
||||
"Session %s deactivated for user %s (%s)",
|
||||
session_id,
|
||||
session.user_id,
|
||||
session.device_name,
|
||||
f"Session {session_id} deactivated for user {session.user_id} "
|
||||
f"({session.device_name})"
|
||||
)
|
||||
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Error deactivating session %s: %s", session_id, e)
|
||||
logger.error(f"Error deactivating session {session_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def deactivate_all_user_sessions(
|
||||
self, db: AsyncSession, *, user_id: str
|
||||
) -> int:
|
||||
"""Deactivate all active sessions for a user (logout from all devices)."""
|
||||
"""
|
||||
Deactivate all active sessions for a user (logout from all devices).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Number of sessions deactivated
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
stmt = (
|
||||
@@ -158,18 +216,27 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
logger.info("Deactivated %s sessions for user %s", count, user_id)
|
||||
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Error deactivating all sessions for user %s: %s", user_id, e)
|
||||
logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def update_last_used(
|
||||
self, db: AsyncSession, *, session: UserSession
|
||||
) -> UserSession:
|
||||
"""Update the last_used_at timestamp for a session."""
|
||||
"""
|
||||
Update the last_used_at timestamp for a session.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session: UserSession object
|
||||
|
||||
Returns:
|
||||
Updated UserSession
|
||||
"""
|
||||
try:
|
||||
session.last_used_at = datetime.now(UTC)
|
||||
db.add(session)
|
||||
@@ -178,7 +245,7 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Error updating last_used for session %s: %s", session.id, e)
|
||||
logger.error(f"Error updating last_used for session {session.id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def update_refresh_token(
|
||||
@@ -189,7 +256,20 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
new_jti: str,
|
||||
new_expires_at: datetime,
|
||||
) -> UserSession:
|
||||
"""Update session with new refresh token JTI and expiration."""
|
||||
"""
|
||||
Update session with new refresh token JTI and expiration.
|
||||
|
||||
Called during token refresh.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session: UserSession object
|
||||
new_jti: New refresh token JTI
|
||||
new_expires_at: New expiration datetime
|
||||
|
||||
Returns:
|
||||
Updated UserSession
|
||||
"""
|
||||
try:
|
||||
session.refresh_token_jti = new_jti
|
||||
session.expires_at = new_expires_at
|
||||
@@ -201,16 +281,32 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
"Error updating refresh token for session %s: %s", session.id, e
|
||||
f"Error updating refresh token for session {session.id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||
"""Clean up expired sessions using optimized bulk DELETE."""
|
||||
"""
|
||||
Clean up expired sessions using optimized bulk DELETE.
|
||||
|
||||
Deletes sessions that are:
|
||||
- Expired AND inactive
|
||||
- Older than keep_days
|
||||
|
||||
Uses single DELETE query instead of N individual deletes for efficiency.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
keep_days: Keep inactive sessions for this many days (for audit)
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
try:
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.is_active == False, # noqa: E712
|
||||
@@ -225,25 +321,38 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info("Cleaned up %s expired sessions using bulk DELETE", count)
|
||||
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Error cleaning up expired sessions: %s", e)
|
||||
logger.error(f"Error cleaning up expired sessions: {e!s}")
|
||||
raise
|
||||
|
||||
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""Clean up expired and inactive sessions for a specific user."""
|
||||
"""
|
||||
Clean up expired and inactive sessions for a specific user.
|
||||
|
||||
Uses single bulk DELETE query for efficiency instead of N individual deletes.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID to cleanup sessions for
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
try:
|
||||
# Validate UUID
|
||||
try:
|
||||
uuid_obj = uuid.UUID(user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.error("Invalid UUID format: %s", user_id)
|
||||
raise InvalidInputError(f"Invalid user ID format: {user_id}")
|
||||
logger.error(f"Invalid UUID format: {user_id}")
|
||||
raise ValueError(f"Invalid user ID format: {user_id}")
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.user_id == uuid_obj,
|
||||
@@ -259,22 +368,30 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
|
||||
if count > 0:
|
||||
logger.info(
|
||||
"Cleaned up %s expired sessions for user %s using bulk DELETE",
|
||||
count,
|
||||
user_id,
|
||||
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
|
||||
)
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
"Error cleaning up expired sessions for user %s: %s", user_id, e
|
||||
f"Error cleaning up expired sessions for user {user_id}: {e!s}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""Get count of active sessions for a user."""
|
||||
"""
|
||||
Get count of active sessions for a user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Number of active sessions
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
@@ -284,7 +401,7 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
)
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error("Error counting sessions for user %s: %s", user_id, e)
|
||||
logger.error(f"Error counting sessions for user {user_id}: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_all_sessions(
|
||||
@@ -296,16 +413,31 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
active_only: bool = True,
|
||||
with_user: bool = True,
|
||||
) -> tuple[list[UserSession], int]:
|
||||
"""Get all sessions across all users with pagination (admin only)."""
|
||||
"""
|
||||
Get all sessions across all users with pagination (admin only).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
active_only: If True, return only active sessions
|
||||
with_user: If True, eager load user relationship to prevent N+1
|
||||
|
||||
Returns:
|
||||
Tuple of (list of UserSession objects, total count)
|
||||
"""
|
||||
try:
|
||||
# Build query
|
||||
query = select(UserSession)
|
||||
|
||||
# Add eager loading if requested to prevent N+1 queries
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count(UserSession.id))
|
||||
if active_only:
|
||||
count_query = count_query.where(UserSession.is_active)
|
||||
@@ -313,6 +445,7 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination and ordering
|
||||
query = (
|
||||
query.order_by(UserSession.last_used_at.desc())
|
||||
.offset(skip)
|
||||
@@ -325,9 +458,9 @@ class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate
|
||||
return sessions, total
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting all sessions: %s", e)
|
||||
logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
session_repo = SessionRepository(UserSession)
|
||||
# Create singleton instance
|
||||
session = CRUDSession(UserSession)
|
||||
20
backend/app/crud/syndarix/__init__.py
Normal file
20
backend/app/crud/syndarix/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# app/crud/syndarix/__init__.py
|
||||
"""
|
||||
Syndarix CRUD operations.
|
||||
|
||||
This package contains CRUD operations for all Syndarix domain entities.
|
||||
"""
|
||||
|
||||
from .agent_instance import agent_instance
|
||||
from .agent_type import agent_type
|
||||
from .issue import issue
|
||||
from .project import project
|
||||
from .sprint import sprint
|
||||
|
||||
__all__ = [
|
||||
"agent_instance",
|
||||
"agent_type",
|
||||
"issue",
|
||||
"project",
|
||||
"sprint",
|
||||
]
|
||||
346
backend/app/crud/syndarix/agent_instance.py
Normal file
346
backend/app/crud/syndarix/agent_instance.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# app/crud/syndarix/agent_instance.py
|
||||
"""Async CRUD operations for AgentInstance model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import AgentInstance, Issue
|
||||
from app.models.syndarix.enums import AgentStatus
|
||||
from app.schemas.syndarix import AgentInstanceCreate, AgentInstanceUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDAgentInstance(CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstanceUpdate]):
|
||||
"""Async CRUD operations for AgentInstance model."""
|
||||
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: AgentInstanceCreate
|
||||
) -> AgentInstance:
|
||||
"""Create a new agent instance with error handling."""
|
||||
try:
|
||||
db_obj = AgentInstance(
|
||||
agent_type_id=obj_in.agent_type_id,
|
||||
project_id=obj_in.project_id,
|
||||
status=obj_in.status,
|
||||
current_task=obj_in.current_task,
|
||||
short_term_memory=obj_in.short_term_memory,
|
||||
long_term_memory_ref=obj_in.long_term_memory_ref,
|
||||
session_id=obj_in.session_id,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Integrity error creating agent instance: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Unexpected error creating agent instance: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
instance_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get an agent instance with full details including related entities.
|
||||
|
||||
Returns:
|
||||
Dictionary with instance and related entity details
|
||||
"""
|
||||
try:
|
||||
# Get instance with joined relationships
|
||||
result = await db.execute(
|
||||
select(AgentInstance)
|
||||
.options(
|
||||
joinedload(AgentInstance.agent_type),
|
||||
joinedload(AgentInstance.project),
|
||||
)
|
||||
.where(AgentInstance.id == instance_id)
|
||||
)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
# Get assigned issues count
|
||||
issues_count_result = await db.execute(
|
||||
select(func.count(Issue.id)).where(
|
||||
Issue.assigned_agent_id == instance_id
|
||||
)
|
||||
)
|
||||
assigned_issues_count = issues_count_result.scalar_one()
|
||||
|
||||
return {
|
||||
"instance": instance,
|
||||
"agent_type_name": instance.agent_type.name if instance.agent_type else None,
|
||||
"agent_type_slug": instance.agent_type.slug if instance.agent_type else None,
|
||||
"project_name": instance.project.name if instance.project else None,
|
||||
"project_slug": instance.project.slug if instance.project else None,
|
||||
"assigned_issues_count": assigned_issues_count,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting agent instance with details {instance_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
status: AgentStatus | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> tuple[list[AgentInstance], int]:
|
||||
"""Get agent instances for a specific project."""
|
||||
try:
|
||||
query = select(AgentInstance).where(
|
||||
AgentInstance.project_id == project_id
|
||||
)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(AgentInstance.status == status)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination
|
||||
query = query.order_by(AgentInstance.created_at.desc())
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
instances = list(result.scalars().all())
|
||||
|
||||
return instances, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting instances by project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_agent_type(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
agent_type_id: UUID,
|
||||
status: AgentStatus | None = None,
|
||||
) -> list[AgentInstance]:
|
||||
"""Get all instances of a specific agent type."""
|
||||
try:
|
||||
query = select(AgentInstance).where(
|
||||
AgentInstance.agent_type_id == agent_type_id
|
||||
)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(AgentInstance.status == status)
|
||||
|
||||
query = query.order_by(AgentInstance.created_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting instances by agent type {agent_type_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def update_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
instance_id: UUID,
|
||||
status: AgentStatus,
|
||||
current_task: str | None = None,
|
||||
) -> AgentInstance | None:
|
||||
"""Update the status of an agent instance."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentInstance).where(AgentInstance.id == instance_id)
|
||||
)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
instance.status = status
|
||||
instance.last_activity_at = datetime.now(UTC)
|
||||
if current_task is not None:
|
||||
instance.current_task = current_task
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(instance)
|
||||
return instance
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error updating instance status {instance_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def terminate(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
instance_id: UUID,
|
||||
) -> AgentInstance | None:
|
||||
"""Terminate an agent instance."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentInstance).where(AgentInstance.id == instance_id)
|
||||
)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
instance.status = AgentStatus.TERMINATED
|
||||
instance.terminated_at = datetime.now(UTC)
|
||||
instance.current_task = None
|
||||
instance.session_id = None
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(instance)
|
||||
return instance
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error terminating instance {instance_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def record_task_completion(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
instance_id: UUID,
|
||||
tokens_used: int,
|
||||
cost_incurred: Decimal,
|
||||
) -> AgentInstance | None:
|
||||
"""Record a completed task and update metrics."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentInstance).where(AgentInstance.id == instance_id)
|
||||
)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
instance.tasks_completed += 1
|
||||
instance.tokens_used += tokens_used
|
||||
instance.cost_incurred += cost_incurred
|
||||
instance.last_activity_at = datetime.now(UTC)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(instance)
|
||||
return instance
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error recording task completion {instance_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_project_metrics(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> dict[str, Any]:
|
||||
"""Get aggregated metrics for all agents in a project."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(
|
||||
func.count(AgentInstance.id).label("total_instances"),
|
||||
func.count(AgentInstance.id)
|
||||
.filter(AgentInstance.status == AgentStatus.WORKING)
|
||||
.label("active_instances"),
|
||||
func.count(AgentInstance.id)
|
||||
.filter(AgentInstance.status == AgentStatus.IDLE)
|
||||
.label("idle_instances"),
|
||||
func.sum(AgentInstance.tasks_completed).label("total_tasks"),
|
||||
func.sum(AgentInstance.tokens_used).label("total_tokens"),
|
||||
func.sum(AgentInstance.cost_incurred).label("total_cost"),
|
||||
).where(AgentInstance.project_id == project_id)
|
||||
)
|
||||
row = result.one()
|
||||
|
||||
return {
|
||||
"total_instances": row.total_instances or 0,
|
||||
"active_instances": row.active_instances or 0,
|
||||
"idle_instances": row.idle_instances or 0,
|
||||
"total_tasks_completed": row.total_tasks or 0,
|
||||
"total_tokens_used": row.total_tokens or 0,
|
||||
"total_cost_incurred": row.total_cost or Decimal("0.0000"),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting project metrics {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def bulk_terminate_by_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> int:
|
||||
"""Terminate all active instances in a project."""
|
||||
try:
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(AgentInstance)
|
||||
.where(
|
||||
AgentInstance.project_id == project_id,
|
||||
AgentInstance.status != AgentStatus.TERMINATED,
|
||||
)
|
||||
.values(
|
||||
status=AgentStatus.TERMINATED,
|
||||
terminated_at=now,
|
||||
current_task=None,
|
||||
session_id=None,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
terminated_count = result.rowcount
|
||||
logger.info(
|
||||
f"Bulk terminated {terminated_count} instances in project {project_id}"
|
||||
)
|
||||
return terminated_count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error bulk terminating instances for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
agent_instance = CRUDAgentInstance(AgentInstance)
|
||||
275
backend/app/crud/syndarix/agent_type.py
Normal file
275
backend/app/crud/syndarix/agent_type.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# app/crud/syndarix/agent_type.py
|
||||
"""Async CRUD operations for AgentType model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import AgentInstance, AgentType
|
||||
from app.schemas.syndarix import AgentTypeCreate, AgentTypeUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
"""Async CRUD operations for AgentType model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> AgentType | None:
|
||||
"""Get agent type by slug."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentType).where(AgentType.slug == slug)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent type by slug {slug}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: AgentTypeCreate
|
||||
) -> AgentType:
|
||||
"""Create a new agent type with error handling."""
|
||||
try:
|
||||
db_obj = AgentType(
|
||||
name=obj_in.name,
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
expertise=obj_in.expertise,
|
||||
personality_prompt=obj_in.personality_prompt,
|
||||
primary_model=obj_in.primary_model,
|
||||
fallback_models=obj_in.fallback_models,
|
||||
model_params=obj_in.model_params,
|
||||
mcp_servers=obj_in.mcp_servers,
|
||||
tool_permissions=obj_in.tool_permissions,
|
||||
is_active=obj_in.is_active,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise ValueError(
|
||||
f"Agent type with slug '{obj_in.slug}' already exists"
|
||||
)
|
||||
logger.error(f"Integrity error creating agent type: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Unexpected error creating agent type: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[AgentType], int]:
|
||||
"""
|
||||
Get multiple agent types with filtering, searching, and sorting.
|
||||
|
||||
Returns:
|
||||
Tuple of (agent types list, total count)
|
||||
"""
|
||||
try:
|
||||
query = select(AgentType)
|
||||
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.where(AgentType.is_active == is_active)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
AgentType.name.ilike(f"%{search}%"),
|
||||
AgentType.slug.ilike(f"%{search}%"),
|
||||
AgentType.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(AgentType, sort_by, AgentType.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
agent_types = list(result.scalars().all())
|
||||
|
||||
return agent_types, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent types with filters: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_with_instance_count(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
agent_type_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a single agent type with its instance count.
|
||||
|
||||
Returns:
|
||||
Dictionary with agent_type and instance_count
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentType).where(AgentType.id == agent_type_id)
|
||||
)
|
||||
agent_type = result.scalar_one_or_none()
|
||||
|
||||
if not agent_type:
|
||||
return None
|
||||
|
||||
# Get instance count
|
||||
count_result = await db.execute(
|
||||
select(func.count(AgentInstance.id)).where(
|
||||
AgentInstance.agent_type_id == agent_type_id
|
||||
)
|
||||
)
|
||||
instance_count = count_result.scalar_one()
|
||||
|
||||
return {
|
||||
"agent_type": agent_type,
|
||||
"instance_count": instance_count,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting agent type with count {agent_type_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_instance_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Get agent types with instance counts in optimized queries.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of dicts with agent_type and instance_count, total count)
|
||||
"""
|
||||
try:
|
||||
# Get filtered agent types
|
||||
agent_types, total = await self.get_multi_with_filters(
|
||||
db,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
search=search,
|
||||
)
|
||||
|
||||
if not agent_types:
|
||||
return [], 0
|
||||
|
||||
agent_type_ids = [at.id for at in agent_types]
|
||||
|
||||
# Get instance counts in bulk
|
||||
counts_result = await db.execute(
|
||||
select(
|
||||
AgentInstance.agent_type_id,
|
||||
func.count(AgentInstance.id).label("count"),
|
||||
)
|
||||
.where(AgentInstance.agent_type_id.in_(agent_type_ids))
|
||||
.group_by(AgentInstance.agent_type_id)
|
||||
)
|
||||
counts = {row.agent_type_id: row.count for row in counts_result}
|
||||
|
||||
# Combine results
|
||||
results = [
|
||||
{
|
||||
"agent_type": agent_type,
|
||||
"instance_count": counts.get(agent_type.id, 0),
|
||||
}
|
||||
for agent_type in agent_types
|
||||
]
|
||||
|
||||
return results, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting agent types with counts: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_expertise(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
expertise: str,
|
||||
is_active: bool = True,
|
||||
) -> list[AgentType]:
|
||||
"""Get agent types that have a specific expertise."""
|
||||
try:
|
||||
# Use PostgreSQL JSONB contains operator
|
||||
query = select(AgentType).where(
|
||||
AgentType.expertise.contains([expertise.lower()]),
|
||||
AgentType.is_active == is_active,
|
||||
)
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting agent types by expertise {expertise}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def deactivate(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
agent_type_id: UUID,
|
||||
) -> AgentType | None:
|
||||
"""Deactivate an agent type (soft delete)."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentType).where(AgentType.id == agent_type_id)
|
||||
)
|
||||
agent_type = result.scalar_one_or_none()
|
||||
|
||||
if not agent_type:
|
||||
return None
|
||||
|
||||
agent_type.is_active = False
|
||||
await db.commit()
|
||||
await db.refresh(agent_type)
|
||||
return agent_type
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error deactivating agent type {agent_type_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
agent_type = CRUDAgentType(AgentType)
|
||||
437
backend/app/crud/syndarix/issue.py
Normal file
437
backend/app/crud/syndarix/issue.py
Normal file
@@ -0,0 +1,437 @@
|
||||
# app/crud/syndarix/issue.py
|
||||
"""Async CRUD operations for Issue model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import AgentInstance, Issue
|
||||
from app.models.syndarix.enums import IssuePriority, IssueStatus, SyncStatus
|
||||
from app.schemas.syndarix import IssueCreate, IssueUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDIssue(CRUDBase[Issue, IssueCreate, IssueUpdate]):
|
||||
"""Async CRUD operations for Issue model."""
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: IssueCreate) -> Issue:
|
||||
"""Create a new issue with error handling."""
|
||||
try:
|
||||
db_obj = Issue(
|
||||
project_id=obj_in.project_id,
|
||||
title=obj_in.title,
|
||||
body=obj_in.body,
|
||||
status=obj_in.status,
|
||||
priority=obj_in.priority,
|
||||
labels=obj_in.labels,
|
||||
assigned_agent_id=obj_in.assigned_agent_id,
|
||||
human_assignee=obj_in.human_assignee,
|
||||
sprint_id=obj_in.sprint_id,
|
||||
story_points=obj_in.story_points,
|
||||
external_tracker=obj_in.external_tracker,
|
||||
external_id=obj_in.external_id,
|
||||
external_url=obj_in.external_url,
|
||||
external_number=obj_in.external_number,
|
||||
sync_status=SyncStatus.SYNCED,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Integrity error creating issue: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating issue: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get an issue with full details including related entity names.
|
||||
|
||||
Returns:
|
||||
Dictionary with issue and related entity details
|
||||
"""
|
||||
try:
|
||||
# Get issue with joined relationships
|
||||
result = await db.execute(
|
||||
select(Issue)
|
||||
.options(
|
||||
joinedload(Issue.project),
|
||||
joinedload(Issue.sprint),
|
||||
joinedload(Issue.assigned_agent).joinedload(AgentInstance.agent_type),
|
||||
)
|
||||
.where(Issue.id == issue_id)
|
||||
)
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
return {
|
||||
"issue": issue,
|
||||
"project_name": issue.project.name if issue.project else None,
|
||||
"project_slug": issue.project.slug if issue.project else None,
|
||||
"sprint_name": issue.sprint.name if issue.sprint else None,
|
||||
"assigned_agent_type_name": (
|
||||
issue.assigned_agent.agent_type.name
|
||||
if issue.assigned_agent and issue.assigned_agent.agent_type
|
||||
else None
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issue with details {issue_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
status: IssueStatus | None = None,
|
||||
priority: IssuePriority | None = None,
|
||||
sprint_id: UUID | None = None,
|
||||
assigned_agent_id: UUID | None = None,
|
||||
labels: list[str] | None = None,
|
||||
search: str | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[Issue], int]:
|
||||
"""Get issues for a specific project with filters."""
|
||||
try:
|
||||
query = select(Issue).where(Issue.project_id == project_id)
|
||||
|
||||
# Apply filters
|
||||
if status is not None:
|
||||
query = query.where(Issue.status == status)
|
||||
|
||||
if priority is not None:
|
||||
query = query.where(Issue.priority == priority)
|
||||
|
||||
if sprint_id is not None:
|
||||
query = query.where(Issue.sprint_id == sprint_id)
|
||||
|
||||
if assigned_agent_id is not None:
|
||||
query = query.where(Issue.assigned_agent_id == assigned_agent_id)
|
||||
|
||||
if labels:
|
||||
# Match any of the provided labels
|
||||
for label in labels:
|
||||
query = query.where(Issue.labels.contains([label.lower()]))
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Issue.title.ilike(f"%{search}%"),
|
||||
Issue.body.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(Issue, sort_by, Issue.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
issues = list(result.scalars().all())
|
||||
|
||||
return issues, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issues by project {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
status: IssueStatus | None = None,
|
||||
) -> list[Issue]:
|
||||
"""Get all issues in a sprint."""
|
||||
try:
|
||||
query = select(Issue).where(Issue.sprint_id == sprint_id)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(Issue.status == status)
|
||||
|
||||
query = query.order_by(Issue.priority.desc(), Issue.created_at.asc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issues by sprint {sprint_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def assign_to_agent(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
agent_id: UUID | None,
|
||||
) -> Issue | None:
|
||||
"""Assign an issue to an agent (or unassign if agent_id is None)."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.assigned_agent_id = agent_id
|
||||
issue.human_assignee = None # Clear human assignee when assigning to agent
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error assigning issue {issue_id} to agent {agent_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def assign_to_human(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
human_assignee: str | None,
|
||||
) -> Issue | None:
|
||||
"""Assign an issue to a human (or unassign if human_assignee is None)."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.human_assignee = human_assignee
|
||||
issue.assigned_agent_id = None # Clear agent when assigning to human
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error assigning issue {issue_id} to human {human_assignee}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def close_issue(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
) -> Issue | None:
|
||||
"""Close an issue by setting status and closed_at timestamp."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.status = IssueStatus.CLOSED
|
||||
issue.closed_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error closing issue {issue_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def reopen_issue(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
) -> Issue | None:
|
||||
"""Reopen a closed issue."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.status = IssueStatus.OPEN
|
||||
issue.closed_at = None
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error reopening issue {issue_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_sync_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
sync_status: SyncStatus,
|
||||
last_synced_at: datetime | None = None,
|
||||
external_updated_at: datetime | None = None,
|
||||
) -> Issue | None:
|
||||
"""Update the sync status of an issue."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.sync_status = sync_status
|
||||
if last_synced_at:
|
||||
issue.last_synced_at = last_synced_at
|
||||
if external_updated_at:
|
||||
issue.external_updated_at = external_updated_at
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error updating sync status for issue {issue_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_project_stats(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> dict[str, Any]:
|
||||
"""Get issue statistics for a project."""
|
||||
try:
|
||||
# Get counts by status
|
||||
status_counts = await db.execute(
|
||||
select(Issue.status, func.count(Issue.id).label("count"))
|
||||
.where(Issue.project_id == project_id)
|
||||
.group_by(Issue.status)
|
||||
)
|
||||
by_status = {row.status.value: row.count for row in status_counts}
|
||||
|
||||
# Get counts by priority
|
||||
priority_counts = await db.execute(
|
||||
select(Issue.priority, func.count(Issue.id).label("count"))
|
||||
.where(Issue.project_id == project_id)
|
||||
.group_by(Issue.priority)
|
||||
)
|
||||
by_priority = {row.priority.value: row.count for row in priority_counts}
|
||||
|
||||
# Get story points
|
||||
points_result = await db.execute(
|
||||
select(
|
||||
func.sum(Issue.story_points).label("total"),
|
||||
func.sum(Issue.story_points)
|
||||
.filter(Issue.status == IssueStatus.CLOSED)
|
||||
.label("completed"),
|
||||
).where(Issue.project_id == project_id)
|
||||
)
|
||||
points_row = points_result.one()
|
||||
|
||||
total_issues = sum(by_status.values())
|
||||
|
||||
return {
|
||||
"total": total_issues,
|
||||
"open": by_status.get("open", 0),
|
||||
"in_progress": by_status.get("in_progress", 0),
|
||||
"in_review": by_status.get("in_review", 0),
|
||||
"blocked": by_status.get("blocked", 0),
|
||||
"closed": by_status.get("closed", 0),
|
||||
"by_priority": by_priority,
|
||||
"total_story_points": points_row.total,
|
||||
"completed_story_points": points_row.completed,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issue stats for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_external_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
external_tracker: str,
|
||||
external_id: str,
|
||||
) -> Issue | None:
|
||||
"""Get an issue by its external tracker ID."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Issue).where(
|
||||
Issue.external_tracker == external_tracker,
|
||||
Issue.external_id == external_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issue by external ID {external_tracker}:{external_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_pending_sync(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[Issue]:
|
||||
"""Get issues that need to be synced with external tracker."""
|
||||
try:
|
||||
query = select(Issue).where(
|
||||
Issue.external_tracker.isnot(None),
|
||||
Issue.sync_status.in_([SyncStatus.PENDING, SyncStatus.ERROR]),
|
||||
)
|
||||
|
||||
if project_id:
|
||||
query = query.where(Issue.project_id == project_id)
|
||||
|
||||
query = query.order_by(Issue.updated_at.asc()).limit(limit)
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pending sync issues: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
issue = CRUDIssue(Issue)
|
||||
309
backend/app/crud/syndarix/project.py
Normal file
309
backend/app/crud/syndarix/project.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# app/crud/syndarix/project.py
|
||||
"""Async CRUD operations for Project model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import AgentInstance, Issue, Project, Sprint
|
||||
from app.models.syndarix.enums import ProjectStatus, SprintStatus
|
||||
from app.schemas.syndarix import ProjectCreate, ProjectUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDProject(CRUDBase[Project, ProjectCreate, ProjectUpdate]):
|
||||
"""Async CRUD operations for Project model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Project | None:
|
||||
"""Get project by slug."""
|
||||
try:
|
||||
result = await db.execute(select(Project).where(Project.slug == slug))
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting project by slug {slug}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: ProjectCreate) -> Project:
|
||||
"""Create a new project with error handling."""
|
||||
try:
|
||||
db_obj = Project(
|
||||
name=obj_in.name,
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
autonomy_level=obj_in.autonomy_level,
|
||||
status=obj_in.status,
|
||||
settings=obj_in.settings or {},
|
||||
owner_id=obj_in.owner_id,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise ValueError(f"Project with slug '{obj_in.slug}' already exists")
|
||||
logger.error(f"Integrity error creating project: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating project: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
status: ProjectStatus | None = None,
|
||||
owner_id: UUID | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[Project], int]:
|
||||
"""
|
||||
Get multiple projects with filtering, searching, and sorting.
|
||||
|
||||
Returns:
|
||||
Tuple of (projects list, total count)
|
||||
"""
|
||||
try:
|
||||
query = select(Project)
|
||||
|
||||
# Apply filters
|
||||
if status is not None:
|
||||
query = query.where(Project.status == status)
|
||||
|
||||
if owner_id is not None:
|
||||
query = query.where(Project.owner_id == owner_id)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Project.name.ilike(f"%{search}%"),
|
||||
Project.slug.ilike(f"%{search}%"),
|
||||
Project.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(Project, sort_by, Project.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
projects = list(result.scalars().all())
|
||||
|
||||
return projects, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting projects with filters: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_with_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a single project with agent and issue counts.
|
||||
|
||||
Returns:
|
||||
Dictionary with project, agent_count, issue_count, active_sprint_name
|
||||
"""
|
||||
try:
|
||||
# Get project
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
return None
|
||||
|
||||
# Get agent count
|
||||
agent_count_result = await db.execute(
|
||||
select(func.count(AgentInstance.id)).where(
|
||||
AgentInstance.project_id == project_id
|
||||
)
|
||||
)
|
||||
agent_count = agent_count_result.scalar_one()
|
||||
|
||||
# Get issue count
|
||||
issue_count_result = await db.execute(
|
||||
select(func.count(Issue.id)).where(Issue.project_id == project_id)
|
||||
)
|
||||
issue_count = issue_count_result.scalar_one()
|
||||
|
||||
# Get active sprint name
|
||||
active_sprint_result = await db.execute(
|
||||
select(Sprint.name).where(
|
||||
Sprint.project_id == project_id,
|
||||
Sprint.status == SprintStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
active_sprint_name = active_sprint_result.scalar_one_or_none()
|
||||
|
||||
return {
|
||||
"project": project,
|
||||
"agent_count": agent_count,
|
||||
"issue_count": issue_count,
|
||||
"active_sprint_name": active_sprint_name,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting project with counts {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
status: ProjectStatus | None = None,
|
||||
owner_id: UUID | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Get projects with agent/issue counts in optimized queries.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of dicts with project and counts, total count)
|
||||
"""
|
||||
try:
|
||||
# Get filtered projects
|
||||
projects, total = await self.get_multi_with_filters(
|
||||
db,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
status=status,
|
||||
owner_id=owner_id,
|
||||
search=search,
|
||||
)
|
||||
|
||||
if not projects:
|
||||
return [], 0
|
||||
|
||||
project_ids = [p.id for p in projects]
|
||||
|
||||
# Get agent counts in bulk
|
||||
agent_counts_result = await db.execute(
|
||||
select(
|
||||
AgentInstance.project_id,
|
||||
func.count(AgentInstance.id).label("count"),
|
||||
)
|
||||
.where(AgentInstance.project_id.in_(project_ids))
|
||||
.group_by(AgentInstance.project_id)
|
||||
)
|
||||
agent_counts = {row.project_id: row.count for row in agent_counts_result}
|
||||
|
||||
# Get issue counts in bulk
|
||||
issue_counts_result = await db.execute(
|
||||
select(
|
||||
Issue.project_id,
|
||||
func.count(Issue.id).label("count"),
|
||||
)
|
||||
.where(Issue.project_id.in_(project_ids))
|
||||
.group_by(Issue.project_id)
|
||||
)
|
||||
issue_counts = {row.project_id: row.count for row in issue_counts_result}
|
||||
|
||||
# Get active sprint names
|
||||
active_sprints_result = await db.execute(
|
||||
select(Sprint.project_id, Sprint.name).where(
|
||||
Sprint.project_id.in_(project_ids),
|
||||
Sprint.status == SprintStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
active_sprints = {
|
||||
row.project_id: row.name for row in active_sprints_result
|
||||
}
|
||||
|
||||
# Combine results
|
||||
results = [
|
||||
{
|
||||
"project": project,
|
||||
"agent_count": agent_counts.get(project.id, 0),
|
||||
"issue_count": issue_counts.get(project.id, 0),
|
||||
"active_sprint_name": active_sprints.get(project.id),
|
||||
}
|
||||
for project in projects
|
||||
]
|
||||
|
||||
return results, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting projects with counts: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_projects_by_owner(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
status: ProjectStatus | None = None,
|
||||
) -> list[Project]:
|
||||
"""Get all projects owned by a specific user."""
|
||||
try:
|
||||
query = select(Project).where(Project.owner_id == owner_id)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(Project.status == status)
|
||||
|
||||
query = query.order_by(Project.created_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting projects by owner {owner_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def archive_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> Project | None:
|
||||
"""Archive a project by setting status to ARCHIVED."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
return None
|
||||
|
||||
project.status = ProjectStatus.ARCHIVED
|
||||
await db.commit()
|
||||
await db.refresh(project)
|
||||
return project
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error archiving project {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
project = CRUDProject(Project)
|
||||
406
backend/app/crud/syndarix/sprint.py
Normal file
406
backend/app/crud/syndarix/sprint.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# app/crud/syndarix/sprint.py
|
||||
"""Async CRUD operations for Sprint model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import Issue, Sprint
|
||||
from app.models.syndarix.enums import IssueStatus, SprintStatus
|
||||
from app.schemas.syndarix import SprintCreate, SprintUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDSprint(CRUDBase[Sprint, SprintCreate, SprintUpdate]):
|
||||
"""Async CRUD operations for Sprint model."""
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: SprintCreate) -> Sprint:
|
||||
"""Create a new sprint with error handling."""
|
||||
try:
|
||||
db_obj = Sprint(
|
||||
project_id=obj_in.project_id,
|
||||
name=obj_in.name,
|
||||
number=obj_in.number,
|
||||
goal=obj_in.goal,
|
||||
start_date=obj_in.start_date,
|
||||
end_date=obj_in.end_date,
|
||||
status=obj_in.status,
|
||||
planned_points=obj_in.planned_points,
|
||||
completed_points=obj_in.completed_points,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Integrity error creating sprint: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating sprint: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a sprint with full details including issue counts.
|
||||
|
||||
Returns:
|
||||
Dictionary with sprint and related details
|
||||
"""
|
||||
try:
|
||||
# Get sprint with joined project
|
||||
result = await db.execute(
|
||||
select(Sprint)
|
||||
.options(joinedload(Sprint.project))
|
||||
.where(Sprint.id == sprint_id)
|
||||
)
|
||||
sprint = result.scalar_one_or_none()
|
||||
|
||||
if not sprint:
|
||||
return None
|
||||
|
||||
# Get issue counts
|
||||
issue_counts = await db.execute(
|
||||
select(
|
||||
func.count(Issue.id).label("total"),
|
||||
func.count(Issue.id)
|
||||
.filter(Issue.status == IssueStatus.OPEN)
|
||||
.label("open"),
|
||||
func.count(Issue.id)
|
||||
.filter(Issue.status == IssueStatus.CLOSED)
|
||||
.label("completed"),
|
||||
).where(Issue.sprint_id == sprint_id)
|
||||
)
|
||||
counts = issue_counts.one()
|
||||
|
||||
return {
|
||||
"sprint": sprint,
|
||||
"project_name": sprint.project.name if sprint.project else None,
|
||||
"project_slug": sprint.project.slug if sprint.project else None,
|
||||
"issue_count": counts.total,
|
||||
"open_issues": counts.open,
|
||||
"completed_issues": counts.completed,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting sprint with details {sprint_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
status: SprintStatus | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> tuple[list[Sprint], int]:
|
||||
"""Get sprints for a specific project."""
|
||||
try:
|
||||
query = select(Sprint).where(Sprint.project_id == project_id)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(Sprint.status == status)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting (by number descending - newest first)
|
||||
query = query.order_by(Sprint.number.desc())
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
sprints = list(result.scalars().all())
|
||||
|
||||
return sprints, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting sprints by project {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_active_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> Sprint | None:
|
||||
"""Get the currently active sprint for a project."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Sprint).where(
|
||||
Sprint.project_id == project_id,
|
||||
Sprint.status == SprintStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting active sprint for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_next_sprint_number(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> int:
|
||||
"""Get the next sprint number for a project."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(func.max(Sprint.number)).where(Sprint.project_id == project_id)
|
||||
)
|
||||
max_number = result.scalar_one_or_none()
|
||||
return (max_number or 0) + 1
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting next sprint number for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def start_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
start_date: date | None = None,
|
||||
) -> Sprint | None:
|
||||
"""Start a planned sprint."""
|
||||
try:
|
||||
result = await db.execute(select(Sprint).where(Sprint.id == sprint_id))
|
||||
sprint = result.scalar_one_or_none()
|
||||
|
||||
if not sprint:
|
||||
return None
|
||||
|
||||
if sprint.status != SprintStatus.PLANNED:
|
||||
raise ValueError(
|
||||
f"Cannot start sprint with status {sprint.status.value}"
|
||||
)
|
||||
|
||||
# Check for existing active sprint in project
|
||||
active_sprint = await self.get_active_sprint(db, project_id=sprint.project_id)
|
||||
if active_sprint:
|
||||
raise ValueError(
|
||||
f"Project already has an active sprint: {active_sprint.name}"
|
||||
)
|
||||
|
||||
sprint.status = SprintStatus.ACTIVE
|
||||
if start_date:
|
||||
sprint.start_date = start_date
|
||||
|
||||
# Calculate planned points from issues
|
||||
points_result = await db.execute(
|
||||
select(func.sum(Issue.story_points)).where(Issue.sprint_id == sprint_id)
|
||||
)
|
||||
sprint.planned_points = points_result.scalar_one_or_none() or 0
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(sprint)
|
||||
return sprint
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error starting sprint {sprint_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def complete_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
) -> Sprint | None:
|
||||
"""Complete an active sprint and calculate completed points."""
|
||||
try:
|
||||
result = await db.execute(select(Sprint).where(Sprint.id == sprint_id))
|
||||
sprint = result.scalar_one_or_none()
|
||||
|
||||
if not sprint:
|
||||
return None
|
||||
|
||||
if sprint.status != SprintStatus.ACTIVE:
|
||||
raise ValueError(
|
||||
f"Cannot complete sprint with status {sprint.status.value}"
|
||||
)
|
||||
|
||||
sprint.status = SprintStatus.COMPLETED
|
||||
|
||||
# Calculate completed points from closed issues
|
||||
points_result = await db.execute(
|
||||
select(func.sum(Issue.story_points)).where(
|
||||
Issue.sprint_id == sprint_id,
|
||||
Issue.status == IssueStatus.CLOSED,
|
||||
)
|
||||
)
|
||||
sprint.completed_points = points_result.scalar_one_or_none() or 0
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(sprint)
|
||||
return sprint
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error completing sprint {sprint_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def cancel_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
) -> Sprint | None:
|
||||
"""Cancel a sprint (only PLANNED or ACTIVE sprints can be cancelled)."""
|
||||
try:
|
||||
result = await db.execute(select(Sprint).where(Sprint.id == sprint_id))
|
||||
sprint = result.scalar_one_or_none()
|
||||
|
||||
if not sprint:
|
||||
return None
|
||||
|
||||
if sprint.status not in [SprintStatus.PLANNED, SprintStatus.ACTIVE]:
|
||||
raise ValueError(
|
||||
f"Cannot cancel sprint with status {sprint.status.value}"
|
||||
)
|
||||
|
||||
sprint.status = SprintStatus.CANCELLED
|
||||
await db.commit()
|
||||
await db.refresh(sprint)
|
||||
return sprint
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error cancelling sprint {sprint_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_velocity(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
limit: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get velocity data for completed sprints."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Sprint)
|
||||
.where(
|
||||
Sprint.project_id == project_id,
|
||||
Sprint.status == SprintStatus.COMPLETED,
|
||||
)
|
||||
.order_by(Sprint.number.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
sprints = list(result.scalars().all())
|
||||
|
||||
velocity_data = []
|
||||
for sprint in reversed(sprints): # Return in chronological order
|
||||
velocity = None
|
||||
if sprint.planned_points and sprint.planned_points > 0:
|
||||
velocity = (sprint.completed_points or 0) / sprint.planned_points
|
||||
velocity_data.append(
|
||||
{
|
||||
"sprint_number": sprint.number,
|
||||
"sprint_name": sprint.name,
|
||||
"planned_points": sprint.planned_points,
|
||||
"completed_points": sprint.completed_points,
|
||||
"velocity": velocity,
|
||||
}
|
||||
)
|
||||
|
||||
return velocity_data
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting velocity for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_sprints_with_issue_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get sprints with issue counts in optimized queries."""
|
||||
try:
|
||||
# Get sprints
|
||||
sprints, total = await self.get_by_project(
|
||||
db, project_id=project_id, skip=skip, limit=limit
|
||||
)
|
||||
|
||||
if not sprints:
|
||||
return [], 0
|
||||
|
||||
sprint_ids = [s.id for s in sprints]
|
||||
|
||||
# Get issue counts in bulk
|
||||
issue_counts = await db.execute(
|
||||
select(
|
||||
Issue.sprint_id,
|
||||
func.count(Issue.id).label("total"),
|
||||
func.count(Issue.id)
|
||||
.filter(Issue.status == IssueStatus.OPEN)
|
||||
.label("open"),
|
||||
func.count(Issue.id)
|
||||
.filter(Issue.status == IssueStatus.CLOSED)
|
||||
.label("completed"),
|
||||
)
|
||||
.where(Issue.sprint_id.in_(sprint_ids))
|
||||
.group_by(Issue.sprint_id)
|
||||
)
|
||||
counts_map = {
|
||||
row.sprint_id: {
|
||||
"issue_count": row.total,
|
||||
"open_issues": row.open,
|
||||
"completed_issues": row.completed,
|
||||
}
|
||||
for row in issue_counts
|
||||
}
|
||||
|
||||
# Combine results
|
||||
results = [
|
||||
{
|
||||
"sprint": sprint,
|
||||
**counts_map.get(
|
||||
sprint.id, {"issue_count": 0, "open_issues": 0, "completed_issues": 0}
|
||||
),
|
||||
}
|
||||
for sprint in sprints
|
||||
]
|
||||
|
||||
return results, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting sprints with counts for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
sprint = CRUDSprint(Sprint)
|
||||
155
backend/app/repositories/user.py → backend/app/crud/user.py
Normal file → Executable file
155
backend/app/repositories/user.py → backend/app/crud/user.py
Normal file → Executable file
@@ -1,5 +1,5 @@
|
||||
# app/repositories/user.py
|
||||
"""Repository for User model async database operations using SQLAlchemy 2.0 patterns."""
|
||||
# app/crud/user_async.py
|
||||
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
@@ -11,16 +11,15 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import get_password_hash_async
|
||||
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user import User
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
"""Repository for User model."""
|
||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
"""Async CRUD operations for User model."""
|
||||
|
||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
|
||||
"""Get user by email address."""
|
||||
@@ -28,12 +27,13 @@ class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Error getting user by email %s: %s", email, e)
|
||||
logger.error(f"Error getting user by email {email}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||
"""Create a new user with async password hashing and error handling."""
|
||||
try:
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
password_hash = await get_password_hash_async(obj_in.password)
|
||||
|
||||
db_obj = User(
|
||||
@@ -57,49 +57,13 @@ class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning("Duplicate email attempted: %s", obj_in.email)
|
||||
raise DuplicateEntryError(
|
||||
f"User with email {obj_in.email} already exists"
|
||||
)
|
||||
logger.error("Integrity error creating user: %s", error_msg)
|
||||
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
|
||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
||||
raise ValueError(f"User with email {obj_in.email} already exists")
|
||||
logger.error(f"Integrity error creating user: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Unexpected error creating user: %s", e)
|
||||
raise
|
||||
|
||||
async def create_oauth_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
email: str,
|
||||
first_name: str = "User",
|
||||
last_name: str | None = None,
|
||||
) -> User:
|
||||
"""Create a new passwordless user for OAuth sign-in."""
|
||||
try:
|
||||
db_obj = User(
|
||||
email=email,
|
||||
password_hash=None, # OAuth-only user
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.flush() # Get user.id without committing
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning("Duplicate email attempted: %s", email)
|
||||
raise DuplicateEntryError(f"User with email {email} already exists")
|
||||
logger.error("Integrity error creating OAuth user: %s", error_msg)
|
||||
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Unexpected error creating OAuth user: %s", e)
|
||||
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
@@ -111,6 +75,8 @@ class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
# Handle password separately if it exists in update data
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = await get_password_hash_async(
|
||||
update_data["password"]
|
||||
@@ -119,15 +85,6 @@ class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
|
||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
|
||||
async def update_password(
|
||||
self, db: AsyncSession, *, user: User, password_hash: str
|
||||
) -> User:
|
||||
"""Set a new password hash on a user and commit."""
|
||||
user.password_hash = password_hash
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
async def get_multi_with_total(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
@@ -139,23 +96,43 @@ class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
filters: dict[str, Any] | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[User], int]:
|
||||
"""Get multiple users with total count, filtering, sorting, and search."""
|
||||
"""
|
||||
Get multiple users with total count, filtering, sorting, and search.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
sort_by: Field name to sort by
|
||||
sort_order: Sort order ("asc" or "desc")
|
||||
filters: Dictionary of filters (field_name: value)
|
||||
search: Search term to match against email, first_name, last_name
|
||||
|
||||
Returns:
|
||||
Tuple of (users list, total count)
|
||||
"""
|
||||
# Validate pagination
|
||||
if skip < 0:
|
||||
raise InvalidInputError("skip must be non-negative")
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise InvalidInputError("limit must be non-negative")
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise InvalidInputError("Maximum limit is 1000")
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
# Build base query
|
||||
query = select(User)
|
||||
|
||||
# Exclude soft-deleted users
|
||||
query = query.where(User.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(User, field) and value is not None:
|
||||
query = query.where(getattr(User, field) == value)
|
||||
|
||||
# Apply search
|
||||
if search:
|
||||
search_filter = or_(
|
||||
User.email.ilike(f"%{search}%"),
|
||||
@@ -164,12 +141,14 @@ class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
from sqlalchemy import func
|
||||
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
if sort_by and hasattr(User, sort_by):
|
||||
sort_column = getattr(User, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
@@ -177,6 +156,7 @@ class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
users = list(result.scalars().all())
|
||||
@@ -184,21 +164,32 @@ class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
return users, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error retrieving paginated users: %s", e)
|
||||
logger.error(f"Error retrieving paginated users: {e!s}")
|
||||
raise
|
||||
|
||||
async def bulk_update_status(
|
||||
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||
) -> int:
|
||||
"""Bulk update is_active status for multiple users."""
|
||||
"""
|
||||
Bulk update is_active status for multiple users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_ids: List of user IDs to update
|
||||
is_active: New active status
|
||||
|
||||
Returns:
|
||||
Number of users updated
|
||||
"""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
# Use UPDATE with WHERE IN for efficiency
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(user_ids))
|
||||
.where(User.deleted_at.is_(None))
|
||||
.where(User.deleted_at.is_(None)) # Don't update deleted users
|
||||
.values(is_active=is_active, updated_at=datetime.now(UTC))
|
||||
)
|
||||
|
||||
@@ -206,14 +197,12 @@ class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
await db.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
logger.info(
|
||||
"Bulk updated %s users to is_active=%s", updated_count, is_active
|
||||
)
|
||||
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error bulk updating user status: %s", e)
|
||||
logger.error(f"Error bulk updating user status: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def bulk_soft_delete(
|
||||
@@ -223,20 +212,34 @@ class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
user_ids: list[UUID],
|
||||
exclude_user_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""Bulk soft delete multiple users."""
|
||||
"""
|
||||
Bulk soft delete multiple users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_ids: List of user IDs to delete
|
||||
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
|
||||
|
||||
Returns:
|
||||
Number of users deleted
|
||||
"""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
# Remove excluded user from list
|
||||
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
||||
|
||||
if not filtered_ids:
|
||||
return 0
|
||||
|
||||
# Use UPDATE with WHERE IN for efficiency
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(filtered_ids))
|
||||
.where(User.deleted_at.is_(None))
|
||||
.where(
|
||||
User.deleted_at.is_(None)
|
||||
) # Don't re-delete already deleted users
|
||||
.values(
|
||||
deleted_at=datetime.now(UTC),
|
||||
is_active=False,
|
||||
@@ -248,22 +251,22 @@ class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
await db.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info("Bulk soft deleted %s users", deleted_count)
|
||||
logger.info(f"Bulk soft deleted {deleted_count} users")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error bulk deleting users: %s", e)
|
||||
logger.error(f"Error bulk deleting users: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
"""Check if user is active."""
|
||||
return bool(user.is_active)
|
||||
return user.is_active
|
||||
|
||||
def is_superuser(self, user: User) -> bool:
|
||||
"""Check if user is a superuser."""
|
||||
return bool(user.is_superuser)
|
||||
return user.is_superuser
|
||||
|
||||
|
||||
# Singleton instance
|
||||
user_repo = UserRepository(User)
|
||||
# Create a singleton instance for use across the application
|
||||
user = CRUDUser(User)
|
||||
@@ -16,10 +16,10 @@ from sqlalchemy import select, text
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import SessionLocal, engine
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization
|
||||
from app.repositories.user import user_repo as user_repo
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,17 +44,16 @@ async def init_db() -> User | None:
|
||||
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
|
||||
logger.warning(
|
||||
"First superuser credentials not configured in settings. "
|
||||
"Using defaults: %s",
|
||||
superuser_email,
|
||||
f"Using defaults: {superuser_email}"
|
||||
)
|
||||
|
||||
async with SessionLocal() as session:
|
||||
try:
|
||||
# Check if superuser already exists
|
||||
existing_user = await user_repo.get_by_email(session, email=superuser_email)
|
||||
existing_user = await user_crud.get_by_email(session, email=superuser_email)
|
||||
|
||||
if existing_user:
|
||||
logger.info("Superuser already exists: %s", existing_user.email)
|
||||
logger.info(f"Superuser already exists: {existing_user.email}")
|
||||
return existing_user
|
||||
|
||||
# Create superuser if doesn't exist
|
||||
@@ -66,11 +65,11 @@ async def init_db() -> User | None:
|
||||
is_superuser=True,
|
||||
)
|
||||
|
||||
user = await user_repo.create(session, obj_in=user_in)
|
||||
user = await user_crud.create(session, obj_in=user_in)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
logger.info("Created first superuser: %s", user.email)
|
||||
logger.info(f"Created first superuser: {user.email}")
|
||||
|
||||
# Create demo data if in demo mode
|
||||
if settings.DEMO_MODE:
|
||||
@@ -80,7 +79,7 @@ async def init_db() -> User | None:
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error("Error initializing database: %s", e)
|
||||
logger.error(f"Error initializing database: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -93,7 +92,7 @@ async def load_demo_data(session):
|
||||
"""Load demo data from JSON file."""
|
||||
demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
|
||||
if not demo_data_path.exists():
|
||||
logger.warning("Demo data file not found: %s", demo_data_path)
|
||||
logger.warning(f"Demo data file not found: {demo_data_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -120,7 +119,7 @@ async def load_demo_data(session):
|
||||
session.add(org)
|
||||
await session.flush() # Flush to get ID
|
||||
org_map[org.slug] = org
|
||||
logger.info("Created demo organization: %s", org.name)
|
||||
logger.info(f"Created demo organization: {org.name}")
|
||||
else:
|
||||
# We can't easily get the ORM object from raw SQL result for map without querying again or mapping
|
||||
# So let's just query it properly if we need it for relationships
|
||||
@@ -136,7 +135,7 @@ async def load_demo_data(session):
|
||||
|
||||
# Create Users
|
||||
for user_data in data.get("users", []):
|
||||
existing_user = await user_repo.get_by_email(
|
||||
existing_user = await user_crud.get_by_email(
|
||||
session, email=user_data["email"]
|
||||
)
|
||||
if not existing_user:
|
||||
@@ -149,7 +148,7 @@ async def load_demo_data(session):
|
||||
is_superuser=user_data["is_superuser"],
|
||||
is_active=user_data.get("is_active", True),
|
||||
)
|
||||
user = await user_repo.create(session, obj_in=user_in)
|
||||
user = await user_crud.create(session, obj_in=user_in)
|
||||
|
||||
# Randomize created_at for demo data (last 30 days)
|
||||
# This makes the charts look more realistic
|
||||
@@ -175,10 +174,7 @@ async def load_demo_data(session):
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Created demo user: %s (created %s days ago, active=%s)",
|
||||
user.email,
|
||||
days_ago,
|
||||
user_data.get("is_active", True),
|
||||
f"Created demo user: {user.email} (created {days_ago} days ago, active={user_data.get('is_active', True)})"
|
||||
)
|
||||
|
||||
# Add to organization if specified
|
||||
@@ -191,15 +187,15 @@ async def load_demo_data(session):
|
||||
user_id=user.id, organization_id=org.id, role=role
|
||||
)
|
||||
session.add(member)
|
||||
logger.info("Added %s to %s as %s", user.email, org.name, role)
|
||||
logger.info(f"Added {user.email} to {org.name} as {role}")
|
||||
else:
|
||||
logger.info("Demo user already exists: %s", existing_user.email)
|
||||
logger.info(f"Demo user already exists: {existing_user.email}")
|
||||
|
||||
await session.commit()
|
||||
logger.info("Demo data loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error loading demo data: %s", e)
|
||||
logger.error(f"Error loading demo data: {e}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
@@ -16,7 +16,7 @@ from slowapi.util import get_remote_address
|
||||
from app.api.main import api_router
|
||||
from app.api.routes.oauth_provider import wellknown_router as oauth_wellknown_router
|
||||
from app.core.config import settings
|
||||
from app.core.database import check_database_health, close_async_db
|
||||
from app.core.database import check_database_health
|
||||
from app.core.exceptions import (
|
||||
APIException,
|
||||
api_exception_handler,
|
||||
@@ -72,7 +72,6 @@ async def lifespan(app: FastAPI):
|
||||
if os.getenv("IS_TEST", "False") != "True":
|
||||
scheduler.shutdown()
|
||||
logger.info("Scheduled jobs stopped")
|
||||
await close_async_db()
|
||||
|
||||
|
||||
logger.info("Starting app!!!")
|
||||
@@ -295,7 +294,7 @@ async def health_check() -> JSONResponse:
|
||||
"""
|
||||
health_status: dict[str, Any] = {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
|
||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||
"version": settings.VERSION,
|
||||
"environment": settings.ENVIRONMENT,
|
||||
"checks": {},
|
||||
@@ -320,7 +319,7 @@ async def health_check() -> JSONResponse:
|
||||
"message": f"Database connection failed: {e!s}",
|
||||
}
|
||||
response_status = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
logger.error("Health check failed - database error: %s", e)
|
||||
logger.error(f"Health check failed - database error: {e}")
|
||||
|
||||
return JSONResponse(status_code=response_status, content=health_status)
|
||||
|
||||
|
||||
@@ -23,6 +23,15 @@ from .user import User
|
||||
from .user_organization import OrganizationRole, UserOrganization
|
||||
from .user_session import UserSession
|
||||
|
||||
# Syndarix domain models
|
||||
from .syndarix import (
|
||||
AgentInstance,
|
||||
AgentType,
|
||||
Issue,
|
||||
Project,
|
||||
Sprint,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"OAuthAccount",
|
||||
@@ -38,4 +47,10 @@ __all__ = [
|
||||
"User",
|
||||
"UserOrganization",
|
||||
"UserSession",
|
||||
# Syndarix models
|
||||
"AgentInstance",
|
||||
"AgentType",
|
||||
"Issue",
|
||||
"Project",
|
||||
"Sprint",
|
||||
]
|
||||
|
||||
@@ -36,9 +36,9 @@ class OAuthAccount(Base, UUIDMixin, TimestampMixin):
|
||||
) # Email from provider (for reference)
|
||||
|
||||
# Optional: store provider tokens for API access
|
||||
# TODO: Encrypt these at rest in production (requires key management infrastructure)
|
||||
access_token = Column(String(2048), nullable=True)
|
||||
refresh_token = Column(String(2048), nullable=True)
|
||||
# These should be encrypted at rest in production
|
||||
access_token_encrypted = Column(String(2048), nullable=True)
|
||||
refresh_token_encrypted = Column(String(2048), nullable=True)
|
||||
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationship
|
||||
|
||||
@@ -92,7 +92,7 @@ class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
|
||||
# Handle both timezone-aware and naive datetimes from DB
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return bool(now > expires_at)
|
||||
return now > expires_at
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
|
||||
@@ -99,7 +99,7 @@ class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
|
||||
# Handle both timezone-aware and naive datetimes from DB
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return bool(now > expires_at)
|
||||
return now > expires_at
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
|
||||
41
backend/app/models/syndarix/__init__.py
Normal file
41
backend/app/models/syndarix/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# app/models/syndarix/__init__.py
|
||||
"""
|
||||
Syndarix domain models.
|
||||
|
||||
This package contains all the core entities for the Syndarix AI consulting platform:
|
||||
- Project: Client engagements with autonomy settings
|
||||
- AgentType: Templates for AI agent capabilities
|
||||
- AgentInstance: Spawned agents working on projects
|
||||
- Issue: Units of work with external tracker sync
|
||||
- Sprint: Time-boxed iterations for organizing work
|
||||
"""
|
||||
|
||||
from .agent_instance import AgentInstance
|
||||
from .agent_type import AgentType
|
||||
from .enums import (
|
||||
AgentStatus,
|
||||
AutonomyLevel,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
ProjectStatus,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
from .issue import Issue
|
||||
from .project import Project
|
||||
from .sprint import Sprint
|
||||
|
||||
__all__ = [
|
||||
"AgentInstance",
|
||||
"AgentStatus",
|
||||
"AgentType",
|
||||
"AutonomyLevel",
|
||||
"Issue",
|
||||
"IssuePriority",
|
||||
"IssueStatus",
|
||||
"Project",
|
||||
"ProjectStatus",
|
||||
"Sprint",
|
||||
"SprintStatus",
|
||||
"SyncStatus",
|
||||
]
|
||||
108
backend/app/models/syndarix/agent_instance.py
Normal file
108
backend/app/models/syndarix/agent_instance.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# app/models/syndarix/agent_instance.py
|
||||
"""
|
||||
AgentInstance model for Syndarix AI consulting platform.
|
||||
|
||||
An AgentInstance is a spawned instance of an AgentType, assigned to a
|
||||
specific project to perform work.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import AgentStatus
|
||||
|
||||
|
||||
class AgentInstance(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
AgentInstance model representing a spawned agent working on a project.
|
||||
|
||||
Tracks:
|
||||
- Current status and task
|
||||
- Memory (short-term in DB, long-term reference to vector store)
|
||||
- Session information for MCP connections
|
||||
- Usage metrics (tasks completed, tokens, cost)
|
||||
"""
|
||||
|
||||
__tablename__ = "agent_instances"
|
||||
|
||||
# Foreign keys
|
||||
agent_type_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_types.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Status tracking
|
||||
status: Column[AgentStatus] = Column(
|
||||
Enum(AgentStatus),
|
||||
default=AgentStatus.IDLE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Current task description (brief summary of what agent is doing)
|
||||
current_task = Column(Text, nullable=True)
|
||||
|
||||
# Short-term memory stored in database (conversation context, recent decisions)
|
||||
short_term_memory = Column(JSONB, default=dict, nullable=False)
|
||||
|
||||
# Reference to long-term memory in vector store (e.g., "project-123/agent-456")
|
||||
long_term_memory_ref = Column(String(500), nullable=True)
|
||||
|
||||
# Session ID for active MCP connections
|
||||
session_id = Column(String(255), nullable=True, index=True)
|
||||
|
||||
# Activity tracking
|
||||
last_activity_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
terminated_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Usage metrics
|
||||
tasks_completed = Column(Integer, default=0, nullable=False)
|
||||
tokens_used = Column(BigInteger, default=0, nullable=False)
|
||||
cost_incurred = Column(Numeric(precision=10, scale=4), default=0, nullable=False)
|
||||
|
||||
# Relationships
|
||||
agent_type = relationship("AgentType", back_populates="instances")
|
||||
project = relationship("Project", back_populates="agent_instances")
|
||||
assigned_issues = relationship(
|
||||
"Issue",
|
||||
back_populates="assigned_agent",
|
||||
foreign_keys="Issue.assigned_agent_id",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_agent_instances_project_status", "project_id", "status"),
|
||||
Index("ix_agent_instances_type_status", "agent_type_id", "status"),
|
||||
Index("ix_agent_instances_project_type", "project_id", "agent_type_id"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<AgentInstance {self.id} type={self.agent_type_id} "
|
||||
f"project={self.project_id} status={self.status.value}>"
|
||||
)
|
||||
72
backend/app/models/syndarix/agent_type.py
Normal file
72
backend/app/models/syndarix/agent_type.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# app/models/syndarix/agent_type.py
|
||||
"""
|
||||
AgentType model for Syndarix AI consulting platform.
|
||||
|
||||
An AgentType is a template that defines the capabilities, personality,
|
||||
and model configuration for agent instances.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Boolean, Column, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class AgentType(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
AgentType model representing a template for agent instances.
|
||||
|
||||
Each agent type defines:
|
||||
- Expertise areas and personality prompt
|
||||
- Model configuration (primary, fallback, parameters)
|
||||
- MCP server access and tool permissions
|
||||
|
||||
Examples: ProductOwner, Architect, BackendEngineer, QAEngineer
|
||||
"""
|
||||
|
||||
__tablename__ = "agent_types"
|
||||
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
slug = Column(String(255), unique=True, nullable=False, index=True)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Areas of expertise for this agent type (e.g., ["python", "fastapi", "databases"])
|
||||
expertise = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# System prompt defining the agent's personality and behavior
|
||||
personality_prompt = Column(Text, nullable=False)
|
||||
|
||||
# Primary LLM model to use (e.g., "claude-opus-4-5-20251101")
|
||||
primary_model = Column(String(100), nullable=False)
|
||||
|
||||
# Fallback models in order of preference
|
||||
fallback_models = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Model parameters (temperature, max_tokens, etc.)
|
||||
model_params = Column(JSONB, default=dict, nullable=False)
|
||||
|
||||
# List of MCP servers this agent can connect to
|
||||
mcp_servers = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Tool permissions configuration
|
||||
# Structure: {"allowed": ["*"], "denied": [], "require_approval": ["gitea:create_pr"]}
|
||||
tool_permissions = Column(JSONB, default=dict, nullable=False)
|
||||
|
||||
# Whether this agent type is available for new instances
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Relationships
|
||||
instances = relationship(
|
||||
"AgentInstance",
|
||||
back_populates="agent_type",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_agent_types_slug_active", "slug", "is_active"),
|
||||
Index("ix_agent_types_name_active", "name", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AgentType {self.name} ({self.slug}) active={self.is_active}>"
|
||||
123
backend/app/models/syndarix/enums.py
Normal file
123
backend/app/models/syndarix/enums.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# app/models/syndarix/enums.py
|
||||
"""
|
||||
Enums for Syndarix domain models.
|
||||
|
||||
These enums represent the core state machines and categorizations
|
||||
used throughout the Syndarix AI consulting platform.
|
||||
"""
|
||||
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
|
||||
class AutonomyLevel(str, PyEnum):
|
||||
"""
|
||||
Defines how much control the human has over agent actions.
|
||||
|
||||
FULL_CONTROL: Human must approve every agent action
|
||||
MILESTONE: Human approves at sprint boundaries and major decisions
|
||||
AUTONOMOUS: Agents work independently, only escalating critical issues
|
||||
"""
|
||||
|
||||
FULL_CONTROL = "full_control"
|
||||
MILESTONE = "milestone"
|
||||
AUTONOMOUS = "autonomous"
|
||||
|
||||
|
||||
class ProjectStatus(str, PyEnum):
|
||||
"""
|
||||
Project lifecycle status.
|
||||
|
||||
ACTIVE: Project is actively being worked on
|
||||
PAUSED: Project is temporarily on hold
|
||||
COMPLETED: Project has been delivered successfully
|
||||
ARCHIVED: Project is no longer accessible for work
|
||||
"""
|
||||
|
||||
ACTIVE = "active"
|
||||
PAUSED = "paused"
|
||||
COMPLETED = "completed"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class AgentStatus(str, PyEnum):
|
||||
"""
|
||||
Current operational status of an agent instance.
|
||||
|
||||
IDLE: Agent is available but not currently working
|
||||
WORKING: Agent is actively processing a task
|
||||
WAITING: Agent is waiting for external input or approval
|
||||
PAUSED: Agent has been manually paused
|
||||
TERMINATED: Agent instance has been shut down
|
||||
"""
|
||||
|
||||
IDLE = "idle"
|
||||
WORKING = "working"
|
||||
WAITING = "waiting"
|
||||
PAUSED = "paused"
|
||||
TERMINATED = "terminated"
|
||||
|
||||
|
||||
class IssueStatus(str, PyEnum):
|
||||
"""
|
||||
Issue workflow status.
|
||||
|
||||
OPEN: Issue is ready to be worked on
|
||||
IN_PROGRESS: Agent or human is actively working on the issue
|
||||
IN_REVIEW: Work is complete, awaiting review
|
||||
BLOCKED: Issue cannot proceed due to dependencies or blockers
|
||||
CLOSED: Issue has been completed or cancelled
|
||||
"""
|
||||
|
||||
OPEN = "open"
|
||||
IN_PROGRESS = "in_progress"
|
||||
IN_REVIEW = "in_review"
|
||||
BLOCKED = "blocked"
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class IssuePriority(str, PyEnum):
|
||||
"""
|
||||
Issue priority levels.
|
||||
|
||||
LOW: Nice to have, can be deferred
|
||||
MEDIUM: Standard priority, should be done
|
||||
HIGH: Important, should be prioritized
|
||||
CRITICAL: Must be done immediately, blocking other work
|
||||
"""
|
||||
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class SyncStatus(str, PyEnum):
|
||||
"""
|
||||
External issue tracker synchronization status.
|
||||
|
||||
SYNCED: Local and remote are in sync
|
||||
PENDING: Local changes waiting to be pushed
|
||||
CONFLICT: Merge conflict between local and remote
|
||||
ERROR: Synchronization failed due to an error
|
||||
"""
|
||||
|
||||
SYNCED = "synced"
|
||||
PENDING = "pending"
|
||||
CONFLICT = "conflict"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class SprintStatus(str, PyEnum):
|
||||
"""
|
||||
Sprint lifecycle status.
|
||||
|
||||
PLANNED: Sprint has been created but not started
|
||||
ACTIVE: Sprint is currently in progress
|
||||
COMPLETED: Sprint has been finished successfully
|
||||
CANCELLED: Sprint was cancelled before completion
|
||||
"""
|
||||
|
||||
PLANNED = "planned"
|
||||
ACTIVE = "active"
|
||||
COMPLETED = "completed"
|
||||
CANCELLED = "cancelled"
|
||||
133
backend/app/models/syndarix/issue.py
Normal file
133
backend/app/models/syndarix/issue.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# app/models/syndarix/issue.py
|
||||
"""
|
||||
Issue model for Syndarix AI consulting platform.
|
||||
|
||||
An Issue represents a unit of work that can be assigned to agents or humans,
|
||||
with optional synchronization to external issue trackers (Gitea, GitHub, GitLab).
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Index, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import IssuePriority, IssueStatus, SyncStatus
|
||||
|
||||
|
||||
class Issue(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Issue model representing a unit of work in a project.
|
||||
|
||||
Features:
|
||||
- Standard issue fields (title, body, status, priority)
|
||||
- Assignment to agent instances or human assignees
|
||||
- Sprint association for backlog management
|
||||
- External tracker synchronization (Gitea, GitHub, GitLab)
|
||||
"""
|
||||
|
||||
__tablename__ = "issues"
|
||||
|
||||
# Foreign key to project
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Issue content
|
||||
title = Column(String(500), nullable=False)
|
||||
body = Column(Text, nullable=False, default="")
|
||||
|
||||
# Status and priority
|
||||
status: Column[IssueStatus] = Column(
|
||||
Enum(IssueStatus),
|
||||
default=IssueStatus.OPEN,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
priority: Column[IssuePriority] = Column(
|
||||
Enum(IssuePriority),
|
||||
default=IssuePriority.MEDIUM,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Labels for categorization (e.g., ["bug", "frontend", "urgent"])
|
||||
labels = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Assignment - either to an agent or a human (mutually exclusive)
|
||||
assigned_agent_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_instances.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Human assignee (username or email, not a FK to allow external users)
|
||||
human_assignee = Column(String(255), nullable=True, index=True)
|
||||
|
||||
# Sprint association
|
||||
sprint_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("sprints.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Story points for estimation
|
||||
story_points = Column(Integer, nullable=True)
|
||||
|
||||
# External tracker integration
|
||||
external_tracker = Column(
|
||||
String(50),
|
||||
nullable=True,
|
||||
index=True,
|
||||
) # 'gitea', 'github', 'gitlab'
|
||||
|
||||
external_id = Column(String(255), nullable=True) # External system's ID
|
||||
external_url = Column(String(1000), nullable=True) # Link to external issue
|
||||
external_number = Column(Integer, nullable=True) # Issue number (e.g., #123)
|
||||
|
||||
# Sync status with external tracker
|
||||
sync_status: Column[SyncStatus] = Column(
|
||||
Enum(SyncStatus),
|
||||
default=SyncStatus.SYNCED,
|
||||
nullable=False,
|
||||
# Note: Index defined in __table_args__ as ix_issues_sync_status
|
||||
)
|
||||
|
||||
last_synced_at = Column(DateTime(timezone=True), nullable=True)
|
||||
external_updated_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Lifecycle timestamp
|
||||
closed_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="issues")
|
||||
assigned_agent = relationship(
|
||||
"AgentInstance",
|
||||
back_populates="assigned_issues",
|
||||
foreign_keys=[assigned_agent_id],
|
||||
)
|
||||
sprint = relationship("Sprint", back_populates="issues")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_issues_project_status", "project_id", "status"),
|
||||
Index("ix_issues_project_priority", "project_id", "priority"),
|
||||
Index("ix_issues_project_sprint", "project_id", "sprint_id"),
|
||||
Index("ix_issues_external_tracker_id", "external_tracker", "external_id"),
|
||||
Index("ix_issues_sync_status", "sync_status"),
|
||||
Index("ix_issues_project_agent", "project_id", "assigned_agent_id"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Issue {self.id} title='{self.title[:30]}...' "
|
||||
f"status={self.status.value} priority={self.priority.value}>"
|
||||
)
|
||||
88
backend/app/models/syndarix/project.py
Normal file
88
backend/app/models/syndarix/project.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# app/models/syndarix/project.py
|
||||
"""
|
||||
Project model for Syndarix AI consulting platform.
|
||||
|
||||
A Project represents a client engagement where AI agents collaborate
|
||||
to deliver software solutions.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Enum, ForeignKey, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import AutonomyLevel, ProjectStatus
|
||||
|
||||
|
||||
class Project(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Project model representing a client engagement.
|
||||
|
||||
A project contains:
|
||||
- Configuration for how autonomous agents should operate
|
||||
- Settings for MCP server integrations
|
||||
- Relationship to assigned agents, issues, and sprints
|
||||
"""
|
||||
|
||||
__tablename__ = "projects"
|
||||
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
slug = Column(String(255), unique=True, nullable=False, index=True)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
autonomy_level: Column[AutonomyLevel] = Column(
|
||||
Enum(AutonomyLevel),
|
||||
default=AutonomyLevel.MILESTONE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
status: Column[ProjectStatus] = Column(
|
||||
Enum(ProjectStatus),
|
||||
default=ProjectStatus.ACTIVE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# JSON field for flexible project configuration
|
||||
# Can include: mcp_servers, webhook_urls, notification_settings, etc.
|
||||
settings = Column(JSONB, default=dict, nullable=False)
|
||||
|
||||
# Foreign key to the User who owns this project
|
||||
owner_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
owner = relationship("User", foreign_keys=[owner_id])
|
||||
agent_instances = relationship(
|
||||
"AgentInstance",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
issues = relationship(
|
||||
"Issue",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
sprints = relationship(
|
||||
"Sprint",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_projects_slug_status", "slug", "status"),
|
||||
Index("ix_projects_owner_status", "owner_id", "status"),
|
||||
Index("ix_projects_autonomy_status", "autonomy_level", "status"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Project {self.name} ({self.slug}) status={self.status.value}>"
|
||||
74
backend/app/models/syndarix/sprint.py
Normal file
74
backend/app/models/syndarix/sprint.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# app/models/syndarix/sprint.py
|
||||
"""
|
||||
Sprint model for Syndarix AI consulting platform.
|
||||
|
||||
A Sprint represents a time-boxed iteration for organizing and delivering work.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Date, Enum, ForeignKey, Index, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import SprintStatus
|
||||
|
||||
|
||||
class Sprint(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Sprint model representing a time-boxed iteration.
|
||||
|
||||
Tracks:
|
||||
- Sprint metadata (name, number, goal)
|
||||
- Date range (start/end)
|
||||
- Progress metrics (planned vs completed points)
|
||||
"""
|
||||
|
||||
__tablename__ = "sprints"
|
||||
|
||||
# Foreign key to project
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Sprint identification
|
||||
name = Column(String(255), nullable=False)
|
||||
number = Column(Integer, nullable=False) # Sprint number within project
|
||||
|
||||
# Sprint goal (what we aim to achieve)
|
||||
goal = Column(Text, nullable=True)
|
||||
|
||||
# Date range
|
||||
start_date = Column(Date, nullable=False, index=True)
|
||||
end_date = Column(Date, nullable=False, index=True)
|
||||
|
||||
# Status
|
||||
status: Column[SprintStatus] = Column(
|
||||
Enum(SprintStatus),
|
||||
default=SprintStatus.PLANNED,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Progress metrics
|
||||
planned_points = Column(Integer, nullable=True) # Sum of story points at start
|
||||
completed_points = Column(Integer, nullable=True) # Sum of completed story points
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="sprints")
|
||||
issues = relationship("Issue", back_populates="sprint")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_sprints_project_status", "project_id", "status"),
|
||||
Index("ix_sprints_project_number", "project_id", "number"),
|
||||
Index("ix_sprints_date_range", "start_date", "end_date"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Sprint {self.name} (#{self.number}) "
|
||||
f"project={self.project_id} status={self.status.value}>"
|
||||
)
|
||||
@@ -76,11 +76,7 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
"""Check if session has expired."""
|
||||
from datetime import datetime
|
||||
|
||||
now = datetime.now(UTC)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return bool(expires_at < now)
|
||||
return self.expires_at < datetime.now(UTC)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert session to dictionary for serialization."""
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# app/repositories/__init__.py
|
||||
"""Repository layer — all database access goes through these classes."""
|
||||
|
||||
from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo
|
||||
from app.repositories.oauth_authorization_code import (
|
||||
OAuthAuthorizationCodeRepository,
|
||||
oauth_authorization_code_repo,
|
||||
)
|
||||
from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo
|
||||
from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo
|
||||
from app.repositories.oauth_provider_token import (
|
||||
OAuthProviderTokenRepository,
|
||||
oauth_provider_token_repo,
|
||||
)
|
||||
from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo
|
||||
from app.repositories.organization import OrganizationRepository, organization_repo
|
||||
from app.repositories.session import SessionRepository, session_repo
|
||||
from app.repositories.user import UserRepository, user_repo
|
||||
|
||||
__all__ = [
|
||||
"OAuthAccountRepository",
|
||||
"OAuthAuthorizationCodeRepository",
|
||||
"OAuthClientRepository",
|
||||
"OAuthConsentRepository",
|
||||
"OAuthProviderTokenRepository",
|
||||
"OAuthStateRepository",
|
||||
"OrganizationRepository",
|
||||
"SessionRepository",
|
||||
"UserRepository",
|
||||
"oauth_account_repo",
|
||||
"oauth_authorization_code_repo",
|
||||
"oauth_client_repo",
|
||||
"oauth_consent_repo",
|
||||
"oauth_provider_token_repo",
|
||||
"oauth_state_repo",
|
||||
"organization_repo",
|
||||
"session_repo",
|
||||
"user_repo",
|
||||
]
|
||||
@@ -1,249 +0,0 @@
|
||||
# app/repositories/oauth_account.py
|
||||
"""Repository for OAuthAccount model async database operations."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.oauth_account import OAuthAccount
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.oauth import OAuthAccountCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||
|
||||
|
||||
class OAuthAccountRepository(
|
||||
BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]
|
||||
):
|
||||
"""Repository for OAuth account links."""
|
||||
|
||||
async def get_by_provider_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
provider_user_id: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""Get OAuth account by provider and provider user ID."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_user_id == provider_user_id,
|
||||
)
|
||||
)
|
||||
.options(joinedload(OAuthAccount.user))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
"Error getting OAuth account for %s:%s: %s",
|
||||
provider,
|
||||
provider_user_id,
|
||||
e,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_provider_email(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
email: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""Get OAuth account by provider and email."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_email == email,
|
||||
)
|
||||
)
|
||||
.options(joinedload(OAuthAccount.user))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
"Error getting OAuth account for %s email %s: %s", provider, email, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_accounts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
) -> list[OAuthAccount]:
|
||||
"""Get all OAuth accounts linked to a user."""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(OAuthAccount.user_id == user_uuid)
|
||||
.order_by(OAuthAccount.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error getting OAuth accounts for user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
async def get_user_account_by_provider(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
provider: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""Get a specific OAuth account for a user and provider."""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
and_(
|
||||
OAuthAccount.user_id == user_uuid,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
"Error getting OAuth account for user %s, provider %s: %s",
|
||||
user_id,
|
||||
provider,
|
||||
e,
|
||||
)
|
||||
raise
|
||||
|
||||
async def create_account(
|
||||
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
|
||||
) -> OAuthAccount:
|
||||
"""Create a new OAuth account link."""
|
||||
try:
|
||||
db_obj = OAuthAccount(
|
||||
user_id=obj_in.user_id,
|
||||
provider=obj_in.provider,
|
||||
provider_user_id=obj_in.provider_user_id,
|
||||
provider_email=obj_in.provider_email,
|
||||
access_token=obj_in.access_token,
|
||||
refresh_token=obj_in.refresh_token,
|
||||
token_expires_at=obj_in.token_expires_at,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
"OAuth account created: %s linked to user %s",
|
||||
obj_in.provider,
|
||||
obj_in.user_id,
|
||||
)
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "uq_oauth_provider_user" in error_msg.lower():
|
||||
logger.warning(
|
||||
"OAuth account already exists: %s:%s",
|
||||
obj_in.provider,
|
||||
obj_in.provider_user_id,
|
||||
)
|
||||
raise DuplicateEntryError(
|
||||
f"This {obj_in.provider} account is already linked to another user"
|
||||
)
|
||||
logger.error("Integrity error creating OAuth account: %s", error_msg)
|
||||
raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.exception("Error creating OAuth account: %s", e)
|
||||
raise
|
||||
|
||||
async def delete_account(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
provider: str,
|
||||
) -> bool:
|
||||
"""Delete an OAuth account link."""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
delete(OAuthAccount).where(
|
||||
and_(
|
||||
OAuthAccount.user_id == user_uuid,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info(
|
||||
"OAuth account deleted: %s unlinked from user %s", provider, user_id
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"OAuth account not found for deletion: %s for user %s",
|
||||
provider,
|
||||
user_id,
|
||||
)
|
||||
|
||||
return deleted
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
"Error deleting OAuth account %s for user %s: %s", provider, user_id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def update_tokens(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account: OAuthAccount,
|
||||
access_token: str | None = None,
|
||||
refresh_token: str | None = None,
|
||||
token_expires_at: datetime | None = None,
|
||||
) -> OAuthAccount:
|
||||
"""Update OAuth tokens for an account."""
|
||||
try:
|
||||
if access_token is not None:
|
||||
account.access_token = access_token
|
||||
if refresh_token is not None:
|
||||
account.refresh_token = refresh_token
|
||||
if token_expires_at is not None:
|
||||
account.token_expires_at = token_expires_at
|
||||
|
||||
db.add(account)
|
||||
await db.commit()
|
||||
await db.refresh(account)
|
||||
|
||||
return account
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error updating OAuth tokens: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_account_repo = OAuthAccountRepository(OAuthAccount)
|
||||
@@ -1,108 +0,0 @@
|
||||
# app/repositories/oauth_authorization_code.py
|
||||
"""Repository for OAuthAuthorizationCode model."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.oauth_authorization_code import OAuthAuthorizationCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthAuthorizationCodeRepository:
|
||||
"""Repository for OAuth 2.0 authorization codes."""
|
||||
|
||||
async def create_code(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
code: str,
|
||||
client_id: str,
|
||||
user_id: UUID,
|
||||
redirect_uri: str,
|
||||
scope: str,
|
||||
expires_at: datetime,
|
||||
code_challenge: str | None = None,
|
||||
code_challenge_method: str | None = None,
|
||||
state: str | None = None,
|
||||
nonce: str | None = None,
|
||||
) -> OAuthAuthorizationCode:
|
||||
"""Create and persist a new authorization code."""
|
||||
auth_code = OAuthAuthorizationCode(
|
||||
code=code,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
expires_at=expires_at,
|
||||
used=False,
|
||||
)
|
||||
db.add(auth_code)
|
||||
await db.commit()
|
||||
return auth_code
|
||||
|
||||
async def consume_code_atomically(
|
||||
self, db: AsyncSession, *, code: str
|
||||
) -> UUID | None:
|
||||
"""
|
||||
Atomically mark a code as used and return its UUID.
|
||||
|
||||
Returns the UUID if the code was found and not yet used, None otherwise.
|
||||
This prevents race conditions per RFC 6749 Section 4.1.2.
|
||||
"""
|
||||
stmt = (
|
||||
update(OAuthAuthorizationCode)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAuthorizationCode.code == code,
|
||||
OAuthAuthorizationCode.used == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(used=True)
|
||||
.returning(OAuthAuthorizationCode.id)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
row_id = result.scalar_one_or_none()
|
||||
if row_id is not None:
|
||||
await db.commit()
|
||||
return row_id
|
||||
|
||||
async def get_by_id(
|
||||
self, db: AsyncSession, *, code_id: UUID
|
||||
) -> OAuthAuthorizationCode | None:
|
||||
"""Get authorization code by its UUID primary key."""
|
||||
result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_code(
|
||||
self, db: AsyncSession, *, code: str
|
||||
) -> OAuthAuthorizationCode | None:
|
||||
"""Get authorization code by the code string value."""
|
||||
result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||
"""Delete all expired authorization codes. Returns count deleted."""
|
||||
result = await db.execute(
|
||||
delete(OAuthAuthorizationCode).where(
|
||||
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_authorization_code_repo = OAuthAuthorizationCodeRepository()
|
||||
@@ -1,201 +0,0 @@
|
||||
# app/repositories/oauth_client.py
|
||||
"""Repository for OAuthClient model async database operations."""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||
|
||||
|
||||
class OAuthClientRepository(
|
||||
BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]
|
||||
):
|
||||
"""Repository for OAuth clients (provider mode)."""
|
||||
|
||||
async def get_by_client_id(
|
||||
self, db: AsyncSession, *, client_id: str
|
||||
) -> OAuthClient | None:
|
||||
"""Get OAuth client by client_id."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error getting OAuth client %s: %s", client_id, e)
|
||||
raise
|
||||
|
||||
async def create_client(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
obj_in: OAuthClientCreate,
|
||||
owner_user_id: UUID | None = None,
|
||||
) -> tuple[OAuthClient, str | None]:
|
||||
"""Create a new OAuth client."""
|
||||
try:
|
||||
client_id = secrets.token_urlsafe(32)
|
||||
|
||||
client_secret = None
|
||||
client_secret_hash = None
|
||||
if obj_in.client_type == "confidential":
|
||||
client_secret = secrets.token_urlsafe(48)
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
client_secret_hash = get_password_hash(client_secret)
|
||||
|
||||
db_obj = OAuthClient(
|
||||
client_id=client_id,
|
||||
client_secret_hash=client_secret_hash,
|
||||
client_name=obj_in.client_name,
|
||||
client_description=obj_in.client_description,
|
||||
client_type=obj_in.client_type,
|
||||
redirect_uris=obj_in.redirect_uris,
|
||||
allowed_scopes=obj_in.allowed_scopes,
|
||||
owner_user_id=owner_user_id,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
"OAuth client created: %s (%s...)", obj_in.client_name, client_id[:8]
|
||||
)
|
||||
return db_obj, client_secret
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error("Error creating OAuth client: %s", error_msg)
|
||||
raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.exception("Error creating OAuth client: %s", e)
|
||||
raise
|
||||
|
||||
async def deactivate_client(
|
||||
self, db: AsyncSession, *, client_id: str
|
||||
) -> OAuthClient | None:
|
||||
"""Deactivate an OAuth client."""
|
||||
try:
|
||||
client = await self.get_by_client_id(db, client_id=client_id)
|
||||
if client is None:
|
||||
return None
|
||||
|
||||
client.is_active = False
|
||||
db.add(client)
|
||||
await db.commit()
|
||||
await db.refresh(client)
|
||||
|
||||
logger.info("OAuth client deactivated: %s", client.client_name)
|
||||
return client
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error deactivating OAuth client %s: %s", client_id, e)
|
||||
raise
|
||||
|
||||
async def validate_redirect_uri(
|
||||
self, db: AsyncSession, *, client_id: str, redirect_uri: str
|
||||
) -> bool:
|
||||
"""Validate that a redirect URI is allowed for a client."""
|
||||
try:
|
||||
client = await self.get_by_client_id(db, client_id=client_id)
|
||||
if client is None:
|
||||
return False
|
||||
|
||||
return redirect_uri in (client.redirect_uris or [])
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error validating redirect URI: %s", e)
|
||||
return False
|
||||
|
||||
async def verify_client_secret(
|
||||
self, db: AsyncSession, *, client_id: str, client_secret: str
|
||||
) -> bool:
|
||||
"""Verify client credentials."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
client = result.scalar_one_or_none()
|
||||
|
||||
if client is None or client.client_secret_hash is None:
|
||||
return False
|
||||
|
||||
from app.core.auth import verify_password
|
||||
|
||||
stored_hash: str = str(client.client_secret_hash)
|
||||
|
||||
if stored_hash.startswith("$2"):
|
||||
return verify_password(client_secret, stored_hash)
|
||||
else:
|
||||
import hashlib
|
||||
|
||||
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
return secrets.compare_digest(stored_hash, secret_hash)
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error verifying client secret: %s", e)
|
||||
return False
|
||||
|
||||
async def get_all_clients(
|
||||
self, db: AsyncSession, *, include_inactive: bool = False
|
||||
) -> list[OAuthClient]:
|
||||
"""Get all OAuth clients."""
|
||||
try:
|
||||
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
|
||||
if not include_inactive:
|
||||
query = query.where(OAuthClient.is_active == True) # noqa: E712
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error getting all OAuth clients: %s", e)
|
||||
raise
|
||||
|
||||
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
|
||||
"""Delete an OAuth client permanently."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
delete(OAuthClient).where(OAuthClient.client_id == client_id)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info("OAuth client deleted: %s", client_id)
|
||||
else:
|
||||
logger.warning("OAuth client not found for deletion: %s", client_id)
|
||||
|
||||
return deleted
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error deleting OAuth client %s: %s", client_id, e)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_client_repo = OAuthClientRepository(OAuthClient)
|
||||
@@ -1,113 +0,0 @@
|
||||
# app/repositories/oauth_consent.py
|
||||
"""Repository for OAuthConsent model."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.oauth_provider_token import OAuthConsent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthConsentRepository:
|
||||
"""Repository for OAuth consent records (user grants to clients)."""
|
||||
|
||||
async def get_consent(
|
||||
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||
) -> OAuthConsent | None:
|
||||
"""Get the consent record for a user-client pair, or None if not found."""
|
||||
result = await db.execute(
|
||||
select(OAuthConsent).where(
|
||||
and_(
|
||||
OAuthConsent.user_id == user_id,
|
||||
OAuthConsent.client_id == client_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def grant_consent(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
scopes: list[str],
|
||||
) -> OAuthConsent:
|
||||
"""
|
||||
Create or update consent for a user-client pair.
|
||||
|
||||
If consent already exists, the new scopes are merged with existing ones.
|
||||
Returns the created or updated consent record.
|
||||
"""
|
||||
consent = await self.get_consent(db, user_id=user_id, client_id=client_id)
|
||||
|
||||
if consent:
|
||||
existing = (
|
||||
set(consent.granted_scopes.split()) if consent.granted_scopes else set()
|
||||
)
|
||||
merged = existing | set(scopes)
|
||||
consent.granted_scopes = " ".join(sorted(merged)) # type: ignore[assignment]
|
||||
else:
|
||||
consent = OAuthConsent(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
granted_scopes=" ".join(sorted(set(scopes))),
|
||||
)
|
||||
db.add(consent)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(consent)
|
||||
return consent
|
||||
|
||||
async def get_user_consents_with_clients(
|
||||
self, db: AsyncSession, *, user_id: UUID
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get all consent records for a user joined with client details."""
|
||||
result = await db.execute(
|
||||
select(OAuthConsent, OAuthClient)
|
||||
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
|
||||
.where(OAuthConsent.user_id == user_id)
|
||||
)
|
||||
rows = result.all()
|
||||
return [
|
||||
{
|
||||
"client_id": consent.client_id,
|
||||
"client_name": client.client_name,
|
||||
"client_description": client.client_description,
|
||||
"granted_scopes": consent.granted_scopes.split()
|
||||
if consent.granted_scopes
|
||||
else [],
|
||||
"granted_at": consent.created_at.isoformat(),
|
||||
}
|
||||
for consent, client in rows
|
||||
]
|
||||
|
||||
async def revoke_consent(
|
||||
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Delete the consent record for a user-client pair.
|
||||
|
||||
Returns True if a record was found and deleted.
|
||||
Note: Callers are responsible for also revoking associated tokens.
|
||||
"""
|
||||
result = await db.execute(
|
||||
delete(OAuthConsent).where(
|
||||
and_(
|
||||
OAuthConsent.user_id == user_id,
|
||||
OAuthConsent.client_id == client_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_consent_repo = OAuthConsentRepository()
|
||||
@@ -1,142 +0,0 @@
|
||||
# app/repositories/oauth_provider_token.py
|
||||
"""Repository for OAuthProviderRefreshToken model."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.oauth_provider_token import OAuthProviderRefreshToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthProviderTokenRepository:
|
||||
"""Repository for OAuth provider refresh tokens."""
|
||||
|
||||
async def create_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
token_hash: str,
|
||||
jti: str,
|
||||
client_id: str,
|
||||
user_id: UUID,
|
||||
scope: str,
|
||||
expires_at: datetime,
|
||||
device_info: str | None = None,
|
||||
ip_address: str | None = None,
|
||||
) -> OAuthProviderRefreshToken:
|
||||
"""Create and persist a new refresh token record."""
|
||||
token = OAuthProviderRefreshToken(
|
||||
token_hash=token_hash,
|
||||
jti=jti,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
expires_at=expires_at,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
db.add(token)
|
||||
await db.commit()
|
||||
return token
|
||||
|
||||
async def get_by_token_hash(
|
||||
self, db: AsyncSession, *, token_hash: str
|
||||
) -> OAuthProviderRefreshToken | None:
|
||||
"""Get refresh token record by SHA-256 token hash."""
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.token_hash == token_hash
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> OAuthProviderRefreshToken | None:
|
||||
"""Get refresh token record by JWT ID (JTI)."""
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.jti == jti
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def revoke(
|
||||
self, db: AsyncSession, *, token: OAuthProviderRefreshToken
|
||||
) -> None:
|
||||
"""Mark a specific token record as revoked."""
|
||||
token.revoked = True # type: ignore[assignment]
|
||||
token.last_used_at = datetime.now(UTC) # type: ignore[assignment]
|
||||
await db.commit()
|
||||
|
||||
async def revoke_all_for_user_client(
|
||||
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Revoke all active tokens for a specific user-client pair.
|
||||
|
||||
Used when security incidents are detected (e.g., authorization code reuse).
|
||||
Returns the number of tokens revoked.
|
||||
"""
|
||||
result = await db.execute(
|
||||
update(OAuthProviderRefreshToken)
|
||||
.where(
|
||||
and_(
|
||||
OAuthProviderRefreshToken.user_id == user_id,
|
||||
OAuthProviderRefreshToken.client_id == client_id,
|
||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(revoked=True)
|
||||
)
|
||||
count = result.rowcount # type: ignore[attr-defined]
|
||||
if count > 0:
|
||||
await db.commit()
|
||||
return count
|
||||
|
||||
async def revoke_all_for_user(self, db: AsyncSession, *, user_id: UUID) -> int:
|
||||
"""
|
||||
Revoke all active tokens for a user across all clients.
|
||||
|
||||
Used when user changes password or logs out everywhere.
|
||||
Returns the number of tokens revoked.
|
||||
"""
|
||||
result = await db.execute(
|
||||
update(OAuthProviderRefreshToken)
|
||||
.where(
|
||||
and_(
|
||||
OAuthProviderRefreshToken.user_id == user_id,
|
||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(revoked=True)
|
||||
)
|
||||
count = result.rowcount # type: ignore[attr-defined]
|
||||
if count > 0:
|
||||
await db.commit()
|
||||
return count
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, cutoff_days: int = 7) -> int:
|
||||
"""
|
||||
Delete expired refresh tokens older than cutoff_days.
|
||||
|
||||
Should be called periodically (e.g., daily).
|
||||
Returns the number of tokens deleted.
|
||||
"""
|
||||
cutoff = datetime.now(UTC) - timedelta(days=cutoff_days)
|
||||
result = await db.execute(
|
||||
delete(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.expires_at < cutoff
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_provider_token_repo = OAuthProviderTokenRepository()
|
||||
@@ -1,113 +0,0 @@
|
||||
# app/repositories/oauth_state.py
|
||||
"""Repository for OAuthState model async database operations."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.oauth_state import OAuthState
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.oauth import OAuthStateCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||
|
||||
|
||||
class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
|
||||
"""Repository for OAuth state (CSRF protection)."""
|
||||
|
||||
async def create_state(
|
||||
self, db: AsyncSession, *, obj_in: OAuthStateCreate
|
||||
) -> OAuthState:
|
||||
"""Create a new OAuth state for CSRF protection."""
|
||||
try:
|
||||
db_obj = OAuthState(
|
||||
state=obj_in.state,
|
||||
code_verifier=obj_in.code_verifier,
|
||||
nonce=obj_in.nonce,
|
||||
provider=obj_in.provider,
|
||||
redirect_uri=obj_in.redirect_uri,
|
||||
user_id=obj_in.user_id,
|
||||
expires_at=obj_in.expires_at,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.debug("OAuth state created for %s", obj_in.provider)
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error("OAuth state collision: %s", error_msg)
|
||||
raise DuplicateEntryError("Failed to create OAuth state, please retry")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.exception("Error creating OAuth state: %s", e)
|
||||
raise
|
||||
|
||||
async def get_and_consume_state(
|
||||
self, db: AsyncSession, *, state: str
|
||||
) -> OAuthState | None:
|
||||
"""Get and delete OAuth state (consume it)."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthState).where(OAuthState.state == state)
|
||||
)
|
||||
db_obj = result.scalar_one_or_none()
|
||||
|
||||
if db_obj is None:
|
||||
logger.warning("OAuth state not found: %s...", state[:8])
|
||||
return None
|
||||
|
||||
now = datetime.now(UTC)
|
||||
expires_at = db_obj.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < now:
|
||||
logger.warning("OAuth state expired: %s...", state[:8])
|
||||
await db.delete(db_obj)
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
await db.delete(db_obj)
|
||||
await db.commit()
|
||||
|
||||
logger.debug("OAuth state consumed: %s...", state[:8])
|
||||
return db_obj
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error consuming OAuth state: %s", e)
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||
"""Clean up expired OAuth states."""
|
||||
try:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
if count > 0:
|
||||
logger.info("Cleaned up %s expired OAuth states", count)
|
||||
|
||||
return count
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error cleaning up expired OAuth states: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_state_repo = OAuthStateRepository(OAuthState)
|
||||
275
backend/app/schemas/events.py
Normal file
275
backend/app/schemas/events.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Event schemas for the Syndarix EventBus (Redis Pub/Sub).
|
||||
|
||||
This module defines event types and payload schemas for real-time communication
|
||||
between services, agents, and the frontend.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EventType(str, Enum):
|
||||
"""
|
||||
Event types for the EventBus.
|
||||
|
||||
Naming convention: {domain}.{action}
|
||||
"""
|
||||
|
||||
# Agent Events
|
||||
AGENT_SPAWNED = "agent.spawned"
|
||||
AGENT_STATUS_CHANGED = "agent.status_changed"
|
||||
AGENT_MESSAGE = "agent.message"
|
||||
AGENT_TERMINATED = "agent.terminated"
|
||||
|
||||
# Issue Events
|
||||
ISSUE_CREATED = "issue.created"
|
||||
ISSUE_UPDATED = "issue.updated"
|
||||
ISSUE_ASSIGNED = "issue.assigned"
|
||||
ISSUE_CLOSED = "issue.closed"
|
||||
|
||||
# Sprint Events
|
||||
SPRINT_STARTED = "sprint.started"
|
||||
SPRINT_COMPLETED = "sprint.completed"
|
||||
|
||||
# Approval Events
|
||||
APPROVAL_REQUESTED = "approval.requested"
|
||||
APPROVAL_GRANTED = "approval.granted"
|
||||
APPROVAL_DENIED = "approval.denied"
|
||||
|
||||
# Project Events
|
||||
PROJECT_CREATED = "project.created"
|
||||
PROJECT_UPDATED = "project.updated"
|
||||
PROJECT_ARCHIVED = "project.archived"
|
||||
|
||||
# Workflow Events
|
||||
WORKFLOW_STARTED = "workflow.started"
|
||||
WORKFLOW_STEP_COMPLETED = "workflow.step_completed"
|
||||
WORKFLOW_COMPLETED = "workflow.completed"
|
||||
WORKFLOW_FAILED = "workflow.failed"
|
||||
|
||||
|
||||
ActorType = Literal["agent", "user", "system"]
|
||||
|
||||
|
||||
class Event(BaseModel):
|
||||
"""
|
||||
Base event schema for the EventBus.
|
||||
|
||||
All events published to the EventBus must conform to this schema.
|
||||
"""
|
||||
|
||||
id: str = Field(
|
||||
...,
|
||||
description="Unique event identifier (UUID string)",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440000"],
|
||||
)
|
||||
type: EventType = Field(
|
||||
...,
|
||||
description="Event type enum value",
|
||||
examples=[EventType.AGENT_MESSAGE],
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
...,
|
||||
description="When the event occurred (UTC)",
|
||||
examples=["2024-01-15T10:30:00Z"],
|
||||
)
|
||||
project_id: UUID = Field(
|
||||
...,
|
||||
description="Project this event belongs to",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440001"],
|
||||
)
|
||||
actor_id: UUID | None = Field(
|
||||
default=None,
|
||||
description="ID of the agent or user who triggered the event",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440002"],
|
||||
)
|
||||
actor_type: ActorType = Field(
|
||||
...,
|
||||
description="Type of actor: 'agent', 'user', or 'system'",
|
||||
examples=["agent"],
|
||||
)
|
||||
payload: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Event-specific payload data",
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"type": "agent.message",
|
||||
"timestamp": "2024-01-15T10:30:00Z",
|
||||
"project_id": "550e8400-e29b-41d4-a716-446655440001",
|
||||
"actor_id": "550e8400-e29b-41d4-a716-446655440002",
|
||||
"actor_type": "agent",
|
||||
"payload": {"message": "Processing task...", "progress": 50},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Specific payload schemas for type safety
|
||||
|
||||
|
||||
class AgentSpawnedPayload(BaseModel):
|
||||
"""Payload for AGENT_SPAWNED events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the spawned agent instance")
|
||||
agent_type_id: UUID = Field(..., description="ID of the agent type")
|
||||
agent_name: str = Field(..., description="Human-readable name of the agent")
|
||||
role: str = Field(..., description="Agent role (e.g., 'product_owner', 'engineer')")
|
||||
|
||||
|
||||
class AgentStatusChangedPayload(BaseModel):
|
||||
"""Payload for AGENT_STATUS_CHANGED events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||
previous_status: str = Field(..., description="Previous status")
|
||||
new_status: str = Field(..., description="New status")
|
||||
reason: str | None = Field(default=None, description="Reason for status change")
|
||||
|
||||
|
||||
class AgentMessagePayload(BaseModel):
|
||||
"""Payload for AGENT_MESSAGE events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||
message: str = Field(..., description="Message content")
|
||||
message_type: str = Field(
|
||||
default="info",
|
||||
description="Message type: 'info', 'warning', 'error', 'debug'",
|
||||
)
|
||||
metadata: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Additional metadata (e.g., token usage, model info)",
|
||||
)
|
||||
|
||||
|
||||
class AgentTerminatedPayload(BaseModel):
|
||||
"""Payload for AGENT_TERMINATED events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||
termination_reason: str = Field(..., description="Reason for termination")
|
||||
final_status: str = Field(..., description="Final status at termination")
|
||||
|
||||
|
||||
class IssueCreatedPayload(BaseModel):
|
||||
"""Payload for ISSUE_CREATED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
title: str = Field(..., description="Issue title")
|
||||
priority: str | None = Field(default=None, description="Issue priority")
|
||||
labels: list[str] = Field(default_factory=list, description="Issue labels")
|
||||
|
||||
|
||||
class IssueUpdatedPayload(BaseModel):
|
||||
"""Payload for ISSUE_UPDATED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
changes: dict = Field(..., description="Dictionary of field changes")
|
||||
|
||||
|
||||
class IssueAssignedPayload(BaseModel):
|
||||
"""Payload for ISSUE_ASSIGNED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
assignee_id: UUID | None = Field(
|
||||
default=None, description="Agent or user assigned to"
|
||||
)
|
||||
assignee_name: str | None = Field(default=None, description="Assignee name")
|
||||
|
||||
|
||||
class IssueClosedPayload(BaseModel):
|
||||
"""Payload for ISSUE_CLOSED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
resolution: str = Field(..., description="Resolution status")
|
||||
|
||||
|
||||
class SprintStartedPayload(BaseModel):
|
||||
"""Payload for SPRINT_STARTED events."""
|
||||
|
||||
sprint_id: UUID = Field(..., description="Sprint ID")
|
||||
sprint_name: str = Field(..., description="Sprint name")
|
||||
goal: str | None = Field(default=None, description="Sprint goal")
|
||||
issue_count: int = Field(default=0, description="Number of issues in sprint")
|
||||
|
||||
|
||||
class SprintCompletedPayload(BaseModel):
|
||||
"""Payload for SPRINT_COMPLETED events."""
|
||||
|
||||
sprint_id: UUID = Field(..., description="Sprint ID")
|
||||
sprint_name: str = Field(..., description="Sprint name")
|
||||
completed_issues: int = Field(default=0, description="Number of completed issues")
|
||||
incomplete_issues: int = Field(
|
||||
default=0, description="Number of incomplete issues"
|
||||
)
|
||||
|
||||
|
||||
class ApprovalRequestedPayload(BaseModel):
|
||||
"""Payload for APPROVAL_REQUESTED events."""
|
||||
|
||||
approval_id: UUID = Field(..., description="Approval request ID")
|
||||
approval_type: str = Field(..., description="Type of approval needed")
|
||||
description: str = Field(..., description="Description of what needs approval")
|
||||
requested_by: UUID | None = Field(
|
||||
default=None, description="Agent/user requesting approval"
|
||||
)
|
||||
timeout_minutes: int | None = Field(
|
||||
default=None, description="Minutes before auto-escalation"
|
||||
)
|
||||
|
||||
|
||||
class ApprovalGrantedPayload(BaseModel):
|
||||
"""Payload for APPROVAL_GRANTED events."""
|
||||
|
||||
approval_id: UUID = Field(..., description="Approval request ID")
|
||||
approved_by: UUID = Field(..., description="User who granted approval")
|
||||
comments: str | None = Field(default=None, description="Approval comments")
|
||||
|
||||
|
||||
class ApprovalDeniedPayload(BaseModel):
|
||||
"""Payload for APPROVAL_DENIED events."""
|
||||
|
||||
approval_id: UUID = Field(..., description="Approval request ID")
|
||||
denied_by: UUID = Field(..., description="User who denied approval")
|
||||
reason: str = Field(..., description="Reason for denial")
|
||||
|
||||
|
||||
class WorkflowStartedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_STARTED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
workflow_type: str = Field(..., description="Type of workflow")
|
||||
total_steps: int = Field(default=0, description="Total number of steps")
|
||||
|
||||
|
||||
class WorkflowStepCompletedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_STEP_COMPLETED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
step_name: str = Field(..., description="Name of completed step")
|
||||
step_number: int = Field(..., description="Step number (1-indexed)")
|
||||
total_steps: int = Field(..., description="Total number of steps")
|
||||
result: dict = Field(default_factory=dict, description="Step result data")
|
||||
|
||||
|
||||
class WorkflowCompletedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_COMPLETED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
duration_seconds: float = Field(..., description="Total execution duration")
|
||||
result: dict = Field(default_factory=dict, description="Workflow result data")
|
||||
|
||||
|
||||
class WorkflowFailedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_FAILED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
error_message: str = Field(..., description="Error message")
|
||||
failed_step: str | None = Field(default=None, description="Step that failed")
|
||||
recoverable: bool = Field(default=False, description="Whether error is recoverable")
|
||||
@@ -60,8 +60,8 @@ class OAuthAccountCreate(OAuthAccountBase):
|
||||
|
||||
user_id: UUID
|
||||
provider_user_id: str = Field(..., max_length=255)
|
||||
access_token: str | None = None
|
||||
refresh_token: str | None = None
|
||||
access_token_encrypted: str | None = None
|
||||
refresh_token_encrypted: str | None = None
|
||||
token_expires_at: datetime | None = None
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ class OrganizationCreate(OrganizationBase):
|
||||
"""Schema for creating a new organization."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255) # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
slug: str = Field(..., min_length=1, max_length=255)
|
||||
|
||||
|
||||
class OrganizationUpdate(BaseModel):
|
||||
|
||||
113
backend/app/schemas/syndarix/__init__.py
Normal file
113
backend/app/schemas/syndarix/__init__.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# app/schemas/syndarix/__init__.py
|
||||
"""
|
||||
Syndarix domain schemas.
|
||||
|
||||
This package contains Pydantic schemas for validating and serializing
|
||||
Syndarix domain entities.
|
||||
"""
|
||||
|
||||
from .agent_instance import (
|
||||
AgentInstanceCreate,
|
||||
AgentInstanceInDB,
|
||||
AgentInstanceListResponse,
|
||||
AgentInstanceMetrics,
|
||||
AgentInstanceResponse,
|
||||
AgentInstanceTerminate,
|
||||
AgentInstanceUpdate,
|
||||
)
|
||||
from .agent_type import (
|
||||
AgentTypeCreate,
|
||||
AgentTypeInDB,
|
||||
AgentTypeListResponse,
|
||||
AgentTypeResponse,
|
||||
AgentTypeUpdate,
|
||||
)
|
||||
from .enums import (
|
||||
AgentStatus,
|
||||
AutonomyLevel,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
ProjectStatus,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
from .issue import (
|
||||
IssueAssign,
|
||||
IssueClose,
|
||||
IssueCreate,
|
||||
IssueInDB,
|
||||
IssueListResponse,
|
||||
IssueResponse,
|
||||
IssueStats,
|
||||
IssueSyncUpdate,
|
||||
IssueUpdate,
|
||||
)
|
||||
from .project import (
|
||||
ProjectCreate,
|
||||
ProjectInDB,
|
||||
ProjectListResponse,
|
||||
ProjectResponse,
|
||||
ProjectUpdate,
|
||||
)
|
||||
from .sprint import (
|
||||
SprintBurndown,
|
||||
SprintComplete,
|
||||
SprintCreate,
|
||||
SprintInDB,
|
||||
SprintListResponse,
|
||||
SprintResponse,
|
||||
SprintStart,
|
||||
SprintUpdate,
|
||||
SprintVelocity,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# AgentInstance schemas
|
||||
"AgentInstanceCreate",
|
||||
"AgentInstanceInDB",
|
||||
"AgentInstanceListResponse",
|
||||
"AgentInstanceMetrics",
|
||||
"AgentInstanceResponse",
|
||||
"AgentInstanceTerminate",
|
||||
"AgentInstanceUpdate",
|
||||
# Enums
|
||||
"AgentStatus",
|
||||
# AgentType schemas
|
||||
"AgentTypeCreate",
|
||||
"AgentTypeInDB",
|
||||
"AgentTypeListResponse",
|
||||
"AgentTypeResponse",
|
||||
"AgentTypeUpdate",
|
||||
"AutonomyLevel",
|
||||
# Issue schemas
|
||||
"IssueAssign",
|
||||
"IssueClose",
|
||||
"IssueCreate",
|
||||
"IssueInDB",
|
||||
"IssueListResponse",
|
||||
"IssuePriority",
|
||||
"IssueResponse",
|
||||
"IssueStats",
|
||||
"IssueStatus",
|
||||
"IssueSyncUpdate",
|
||||
"IssueUpdate",
|
||||
# Project schemas
|
||||
"ProjectCreate",
|
||||
"ProjectInDB",
|
||||
"ProjectListResponse",
|
||||
"ProjectResponse",
|
||||
"ProjectStatus",
|
||||
"ProjectUpdate",
|
||||
# Sprint schemas
|
||||
"SprintBurndown",
|
||||
"SprintComplete",
|
||||
"SprintCreate",
|
||||
"SprintInDB",
|
||||
"SprintListResponse",
|
||||
"SprintResponse",
|
||||
"SprintStart",
|
||||
"SprintStatus",
|
||||
"SprintUpdate",
|
||||
"SprintVelocity",
|
||||
"SyncStatus",
|
||||
]
|
||||
122
backend/app/schemas/syndarix/agent_instance.py
Normal file
122
backend/app/schemas/syndarix/agent_instance.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# app/schemas/syndarix/agent_instance.py
|
||||
"""
|
||||
Pydantic schemas for AgentInstance entity.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from .enums import AgentStatus
|
||||
|
||||
|
||||
class AgentInstanceBase(BaseModel):
|
||||
"""Base agent instance schema with common fields."""
|
||||
|
||||
agent_type_id: UUID
|
||||
project_id: UUID
|
||||
status: AgentStatus = AgentStatus.IDLE
|
||||
current_task: str | None = None
|
||||
short_term_memory: dict[str, Any] = Field(default_factory=dict)
|
||||
long_term_memory_ref: str | None = Field(None, max_length=500)
|
||||
session_id: str | None = Field(None, max_length=255)
|
||||
|
||||
|
||||
class AgentInstanceCreate(BaseModel):
|
||||
"""Schema for creating a new agent instance."""
|
||||
|
||||
agent_type_id: UUID
|
||||
project_id: UUID
|
||||
status: AgentStatus = AgentStatus.IDLE
|
||||
current_task: str | None = None
|
||||
short_term_memory: dict[str, Any] = Field(default_factory=dict)
|
||||
long_term_memory_ref: str | None = Field(None, max_length=500)
|
||||
session_id: str | None = Field(None, max_length=255)
|
||||
|
||||
|
||||
class AgentInstanceUpdate(BaseModel):
|
||||
"""Schema for updating an agent instance."""
|
||||
|
||||
status: AgentStatus | None = None
|
||||
current_task: str | None = None
|
||||
short_term_memory: dict[str, Any] | None = None
|
||||
long_term_memory_ref: str | None = None
|
||||
session_id: str | None = None
|
||||
last_activity_at: datetime | None = None
|
||||
tasks_completed: int | None = Field(None, ge=0)
|
||||
tokens_used: int | None = Field(None, ge=0)
|
||||
cost_incurred: Decimal | None = Field(None, ge=0)
|
||||
|
||||
|
||||
class AgentInstanceTerminate(BaseModel):
|
||||
"""Schema for terminating an agent instance."""
|
||||
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class AgentInstanceInDB(AgentInstanceBase):
|
||||
"""Schema for agent instance in database."""
|
||||
|
||||
id: UUID
|
||||
last_activity_at: datetime | None = None
|
||||
terminated_at: datetime | None = None
|
||||
tasks_completed: int = 0
|
||||
tokens_used: int = 0
|
||||
cost_incurred: Decimal = Decimal("0.0000")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentInstanceResponse(BaseModel):
|
||||
"""Schema for agent instance API responses."""
|
||||
|
||||
id: UUID
|
||||
agent_type_id: UUID
|
||||
project_id: UUID
|
||||
status: AgentStatus
|
||||
current_task: str | None = None
|
||||
short_term_memory: dict[str, Any] = Field(default_factory=dict)
|
||||
long_term_memory_ref: str | None = None
|
||||
session_id: str | None = None
|
||||
last_activity_at: datetime | None = None
|
||||
terminated_at: datetime | None = None
|
||||
tasks_completed: int = 0
|
||||
tokens_used: int = 0
|
||||
cost_incurred: Decimal = Decimal("0.0000")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# Expanded fields from relationships
|
||||
agent_type_name: str | None = None
|
||||
agent_type_slug: str | None = None
|
||||
project_name: str | None = None
|
||||
project_slug: str | None = None
|
||||
assigned_issues_count: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentInstanceListResponse(BaseModel):
|
||||
"""Schema for paginated agent instance list responses."""
|
||||
|
||||
agent_instances: list[AgentInstanceResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class AgentInstanceMetrics(BaseModel):
|
||||
"""Schema for agent instance metrics summary."""
|
||||
|
||||
total_instances: int
|
||||
active_instances: int
|
||||
idle_instances: int
|
||||
total_tasks_completed: int
|
||||
total_tokens_used: int
|
||||
total_cost_incurred: Decimal
|
||||
151
backend/app/schemas/syndarix/agent_type.py
Normal file
151
backend/app/schemas/syndarix/agent_type.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# app/schemas/syndarix/agent_type.py
|
||||
"""
|
||||
Pydantic schemas for AgentType entity.
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class AgentTypeBase(BaseModel):
|
||||
"""Base agent type schema with common fields."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
expertise: list[str] = Field(default_factory=list)
|
||||
personality_prompt: str = Field(..., min_length=1)
|
||||
primary_model: str = Field(..., min_length=1, max_length=100)
|
||||
fallback_models: list[str] = Field(default_factory=list)
|
||||
model_params: dict[str, Any] = Field(default_factory=dict)
|
||||
mcp_servers: list[str] = Field(default_factory=list)
|
||||
tool_permissions: dict[str, Any] = Field(default_factory=dict)
|
||||
is_active: bool = True
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate agent type name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Agent type name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@field_validator("expertise")
|
||||
@classmethod
|
||||
def validate_expertise(cls, v: list[str]) -> list[str]:
|
||||
"""Validate and normalize expertise list."""
|
||||
return [e.strip().lower() for e in v if e.strip()]
|
||||
|
||||
@field_validator("mcp_servers")
|
||||
@classmethod
|
||||
def validate_mcp_servers(cls, v: list[str]) -> list[str]:
|
||||
"""Validate MCP server list."""
|
||||
return [s.strip() for s in v if s.strip()]
|
||||
|
||||
|
||||
class AgentTypeCreate(AgentTypeBase):
|
||||
"""Schema for creating a new agent type."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255)
|
||||
personality_prompt: str = Field(..., min_length=1)
|
||||
primary_model: str = Field(..., min_length=1, max_length=100)
|
||||
|
||||
|
||||
class AgentTypeUpdate(BaseModel):
|
||||
"""Schema for updating an agent type."""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
expertise: list[str] | None = None
|
||||
personality_prompt: str | None = None
|
||||
primary_model: str | None = Field(None, min_length=1, max_length=100)
|
||||
fallback_models: list[str] | None = None
|
||||
model_params: dict[str, Any] | None = None
|
||||
mcp_servers: list[str] | None = None
|
||||
tool_permissions: dict[str, Any] | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate agent type name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError("Agent type name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
@field_validator("expertise")
|
||||
@classmethod
|
||||
def validate_expertise(cls, v: list[str] | None) -> list[str] | None:
|
||||
"""Validate and normalize expertise list."""
|
||||
if v is None:
|
||||
return v
|
||||
return [e.strip().lower() for e in v if e.strip()]
|
||||
|
||||
|
||||
class AgentTypeInDB(AgentTypeBase):
|
||||
"""Schema for agent type in database."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentTypeResponse(AgentTypeBase):
|
||||
"""Schema for agent type API responses."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
instance_count: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentTypeListResponse(BaseModel):
|
||||
"""Schema for paginated agent type list responses."""
|
||||
|
||||
agent_types: list[AgentTypeResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
26
backend/app/schemas/syndarix/enums.py
Normal file
26
backend/app/schemas/syndarix/enums.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# app/schemas/syndarix/enums.py
|
||||
"""
|
||||
Re-export enums from models for use in schemas.
|
||||
|
||||
This allows schemas to import enums without depending on SQLAlchemy models directly.
|
||||
"""
|
||||
|
||||
from app.models.syndarix.enums import (
|
||||
AgentStatus,
|
||||
AutonomyLevel,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
ProjectStatus,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentStatus",
|
||||
"AutonomyLevel",
|
||||
"IssuePriority",
|
||||
"IssueStatus",
|
||||
"ProjectStatus",
|
||||
"SprintStatus",
|
||||
"SyncStatus",
|
||||
]
|
||||
193
backend/app/schemas/syndarix/issue.py
Normal file
193
backend/app/schemas/syndarix/issue.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# app/schemas/syndarix/issue.py
|
||||
"""
|
||||
Pydantic schemas for Issue entity.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from .enums import IssuePriority, IssueStatus, SyncStatus
|
||||
|
||||
|
||||
class IssueBase(BaseModel):
|
||||
"""Base issue schema with common fields."""
|
||||
|
||||
title: str = Field(..., min_length=1, max_length=500)
|
||||
body: str = ""
|
||||
status: IssueStatus = IssueStatus.OPEN
|
||||
priority: IssuePriority = IssuePriority.MEDIUM
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
story_points: int | None = Field(None, ge=0, le=100)
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def validate_title(cls, v: str) -> str:
|
||||
"""Validate issue title."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Issue title cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@field_validator("labels")
|
||||
@classmethod
|
||||
def validate_labels(cls, v: list[str]) -> list[str]:
|
||||
"""Validate and normalize labels."""
|
||||
return [label.strip().lower() for label in v if label.strip()]
|
||||
|
||||
|
||||
class IssueCreate(IssueBase):
|
||||
"""Schema for creating a new issue."""
|
||||
|
||||
project_id: UUID
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = Field(None, max_length=255)
|
||||
sprint_id: UUID | None = None
|
||||
|
||||
# External tracker fields (optional, for importing from external systems)
|
||||
external_tracker: Literal["gitea", "github", "gitlab"] | None = None
|
||||
external_id: str | None = Field(None, max_length=255)
|
||||
external_url: str | None = Field(None, max_length=1000)
|
||||
external_number: int | None = None
|
||||
|
||||
|
||||
class IssueUpdate(BaseModel):
|
||||
"""Schema for updating an issue."""
|
||||
|
||||
title: str | None = Field(None, min_length=1, max_length=500)
|
||||
body: str | None = None
|
||||
status: IssueStatus | None = None
|
||||
priority: IssuePriority | None = None
|
||||
labels: list[str] | None = None
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = Field(None, max_length=255)
|
||||
sprint_id: UUID | None = None
|
||||
story_points: int | None = Field(None, ge=0, le=100)
|
||||
sync_status: SyncStatus | None = None
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def validate_title(cls, v: str | None) -> str | None:
|
||||
"""Validate issue title."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError("Issue title cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
@field_validator("labels")
|
||||
@classmethod
|
||||
def validate_labels(cls, v: list[str] | None) -> list[str] | None:
|
||||
"""Validate and normalize labels."""
|
||||
if v is None:
|
||||
return v
|
||||
return [label.strip().lower() for label in v if label.strip()]
|
||||
|
||||
|
||||
class IssueClose(BaseModel):
|
||||
"""Schema for closing an issue."""
|
||||
|
||||
resolution: str | None = None # Optional resolution note
|
||||
|
||||
|
||||
class IssueAssign(BaseModel):
|
||||
"""Schema for assigning an issue."""
|
||||
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = Field(None, max_length=255)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_assignment(self) -> "IssueAssign":
|
||||
"""Ensure only one type of assignee is set."""
|
||||
if self.assigned_agent_id and self.human_assignee:
|
||||
raise ValueError(
|
||||
"Cannot assign to both an agent and a human. Choose one."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class IssueSyncUpdate(BaseModel):
|
||||
"""Schema for updating sync-related fields."""
|
||||
|
||||
sync_status: SyncStatus
|
||||
last_synced_at: datetime | None = None
|
||||
external_updated_at: datetime | None = None
|
||||
|
||||
|
||||
class IssueInDB(IssueBase):
|
||||
"""Schema for issue in database."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = None
|
||||
sprint_id: UUID | None = None
|
||||
external_tracker: str | None = None
|
||||
external_id: str | None = None
|
||||
external_url: str | None = None
|
||||
external_number: int | None = None
|
||||
sync_status: SyncStatus = SyncStatus.SYNCED
|
||||
last_synced_at: datetime | None = None
|
||||
external_updated_at: datetime | None = None
|
||||
closed_at: datetime | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class IssueResponse(BaseModel):
|
||||
"""Schema for issue API responses."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
title: str
|
||||
body: str
|
||||
status: IssueStatus
|
||||
priority: IssuePriority
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = None
|
||||
sprint_id: UUID | None = None
|
||||
story_points: int | None = None
|
||||
external_tracker: str | None = None
|
||||
external_id: str | None = None
|
||||
external_url: str | None = None
|
||||
external_number: int | None = None
|
||||
sync_status: SyncStatus = SyncStatus.SYNCED
|
||||
last_synced_at: datetime | None = None
|
||||
external_updated_at: datetime | None = None
|
||||
closed_at: datetime | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# Expanded fields from relationships
|
||||
project_name: str | None = None
|
||||
project_slug: str | None = None
|
||||
sprint_name: str | None = None
|
||||
assigned_agent_type_name: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class IssueListResponse(BaseModel):
|
||||
"""Schema for paginated issue list responses."""
|
||||
|
||||
issues: list[IssueResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class IssueStats(BaseModel):
|
||||
"""Schema for issue statistics."""
|
||||
|
||||
total: int
|
||||
open: int
|
||||
in_progress: int
|
||||
in_review: int
|
||||
blocked: int
|
||||
closed: int
|
||||
by_priority: dict[str, int]
|
||||
total_story_points: int | None = None
|
||||
completed_story_points: int | None = None
|
||||
127
backend/app/schemas/syndarix/project.py
Normal file
127
backend/app/schemas/syndarix/project.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# app/schemas/syndarix/project.py
|
||||
"""
|
||||
Pydantic schemas for Project entity.
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from .enums import AutonomyLevel, ProjectStatus
|
||||
|
||||
|
||||
class ProjectBase(BaseModel):
|
||||
"""Base project schema with common fields."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE
|
||||
status: ProjectStatus = ProjectStatus.ACTIVE
|
||||
settings: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate project name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Project name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
"""Schema for creating a new project."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255)
|
||||
owner_id: UUID | None = None
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
"""Schema for updating a project."""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
autonomy_level: AutonomyLevel | None = None
|
||||
status: ProjectStatus | None = None
|
||||
settings: dict[str, Any] | None = None
|
||||
owner_id: UUID | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate project name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError("Project name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
|
||||
class ProjectInDB(ProjectBase):
|
||||
"""Schema for project in database."""
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ProjectResponse(ProjectBase):
|
||||
"""Schema for project API responses."""
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
agent_count: int | None = 0
|
||||
issue_count: int | None = 0
|
||||
active_sprint_name: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ProjectListResponse(BaseModel):
|
||||
"""Schema for paginated project list responses."""
|
||||
|
||||
projects: list[ProjectResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
135
backend/app/schemas/syndarix/sprint.py
Normal file
135
backend/app/schemas/syndarix/sprint.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# app/schemas/syndarix/sprint.py
|
||||
"""
|
||||
Pydantic schemas for Sprint entity.
|
||||
"""
|
||||
|
||||
from datetime import date, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from .enums import SprintStatus
|
||||
|
||||
|
||||
class SprintBase(BaseModel):
|
||||
"""Base sprint schema with common fields."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
number: int = Field(..., ge=1)
|
||||
goal: str | None = None
|
||||
start_date: date
|
||||
end_date: date
|
||||
status: SprintStatus = SprintStatus.PLANNED
|
||||
planned_points: int | None = Field(None, ge=0)
|
||||
completed_points: int | None = Field(None, ge=0)
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate sprint name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Sprint name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_dates(self) -> "SprintBase":
|
||||
"""Validate that end_date is after start_date."""
|
||||
if self.end_date < self.start_date:
|
||||
raise ValueError("End date must be after or equal to start date")
|
||||
return self
|
||||
|
||||
|
||||
class SprintCreate(SprintBase):
|
||||
"""Schema for creating a new sprint."""
|
||||
|
||||
project_id: UUID
|
||||
|
||||
|
||||
class SprintUpdate(BaseModel):
|
||||
"""Schema for updating a sprint."""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
goal: str | None = None
|
||||
start_date: date | None = None
|
||||
end_date: date | None = None
|
||||
status: SprintStatus | None = None
|
||||
planned_points: int | None = Field(None, ge=0)
|
||||
completed_points: int | None = Field(None, ge=0)
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate sprint name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError("Sprint name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
|
||||
class SprintStart(BaseModel):
|
||||
"""Schema for starting a sprint."""
|
||||
|
||||
start_date: date | None = None # Optionally override start date
|
||||
|
||||
|
||||
class SprintComplete(BaseModel):
|
||||
"""Schema for completing a sprint."""
|
||||
|
||||
completed_points: int | None = Field(None, ge=0)
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
class SprintInDB(SprintBase):
|
||||
"""Schema for sprint in database."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class SprintResponse(SprintBase):
|
||||
"""Schema for sprint API responses."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# Expanded fields from relationships
|
||||
project_name: str | None = None
|
||||
project_slug: str | None = None
|
||||
issue_count: int | None = 0
|
||||
open_issues: int | None = 0
|
||||
completed_issues: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class SprintListResponse(BaseModel):
|
||||
"""Schema for paginated sprint list responses."""
|
||||
|
||||
sprints: list[SprintResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class SprintVelocity(BaseModel):
|
||||
"""Schema for sprint velocity metrics."""
|
||||
|
||||
sprint_number: int
|
||||
sprint_name: str
|
||||
planned_points: int | None
|
||||
completed_points: int | None
|
||||
velocity: float | None # completed/planned ratio
|
||||
|
||||
|
||||
class SprintBurndown(BaseModel):
|
||||
"""Schema for sprint burndown data point."""
|
||||
|
||||
date: date
|
||||
remaining_points: int
|
||||
ideal_remaining: float
|
||||
@@ -1,19 +1,5 @@
|
||||
# app/services/__init__.py
|
||||
from . import oauth_provider_service
|
||||
from .auth_service import AuthService
|
||||
from .oauth_service import OAuthService
|
||||
from .organization_service import OrganizationService, organization_service
|
||||
from .session_service import SessionService, session_service
|
||||
from .user_service import UserService, user_service
|
||||
|
||||
__all__ = [
|
||||
"AuthService",
|
||||
"OAuthService",
|
||||
"OrganizationService",
|
||||
"SessionService",
|
||||
"UserService",
|
||||
"oauth_provider_service",
|
||||
"organization_service",
|
||||
"session_service",
|
||||
"user_service",
|
||||
]
|
||||
__all__ = ["AuthService", "OAuthService"]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import (
|
||||
@@ -13,18 +14,12 @@ from app.core.auth import (
|
||||
verify_password_async,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AuthenticationError, DuplicateError
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.core.exceptions import AuthenticationError
|
||||
from app.models.user import User
|
||||
from app.repositories.user import user_repo
|
||||
from app.schemas.users import Token, UserCreate, UserResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pre-computed bcrypt hash used for constant-time comparison when user is not found,
|
||||
# preventing timing attacks that could enumerate valid email addresses.
|
||||
_DUMMY_HASH = "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36zLFbnJHfxPSEFBzXKiHia"
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Service for handling authentication operations"""
|
||||
@@ -44,12 +39,10 @@ class AuthService:
|
||||
Returns:
|
||||
User if authenticated, None otherwise
|
||||
"""
|
||||
user = await user_repo.get_by_email(db, email=email)
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
# Perform a dummy verification to match timing of a real bcrypt check,
|
||||
# preventing email enumeration via response-time differences.
|
||||
await verify_password_async(password, _DUMMY_HASH)
|
||||
return None
|
||||
|
||||
# Verify password asynchronously to avoid blocking event loop
|
||||
@@ -78,23 +71,40 @@ class AuthService:
|
||||
"""
|
||||
try:
|
||||
# Check if user already exists
|
||||
existing_user = await user_repo.get_by_email(db, email=user_data.email)
|
||||
result = await db.execute(select(User).where(User.email == user_data.email))
|
||||
existing_user = result.scalar_one_or_none()
|
||||
if existing_user:
|
||||
raise DuplicateError("User with this email already exists")
|
||||
raise AuthenticationError("User with this email already exists")
|
||||
|
||||
# Delegate creation (hashing + commit) to the repository
|
||||
user = await user_repo.create(db, obj_in=user_data)
|
||||
# Create new user with async password hashing
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
hashed_password = await get_password_hash_async(user_data.password)
|
||||
|
||||
logger.info("User created successfully: %s", user.email)
|
||||
# Create user object from model
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
password_hash=hashed_password,
|
||||
first_name=user_data.first_name,
|
||||
last_name=user_data.last_name,
|
||||
phone_number=user_data.phone_number,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
logger.info(f"User created successfully: {user.email}")
|
||||
return user
|
||||
|
||||
except (AuthenticationError, DuplicateError):
|
||||
# Re-raise API exceptions without rollback
|
||||
except AuthenticationError:
|
||||
# Re-raise authentication errors without rollback
|
||||
raise
|
||||
except DuplicateEntryError as e:
|
||||
raise DuplicateError(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error creating user: %s", e)
|
||||
# Rollback on any database errors
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating user: {e!s}", exc_info=True)
|
||||
raise AuthenticationError(f"Failed to create user: {e!s}")
|
||||
|
||||
@staticmethod
|
||||
@@ -158,7 +168,8 @@ class AuthService:
|
||||
user_id = token_data.user_id
|
||||
|
||||
# Get user from database
|
||||
user = await user_repo.get(db, id=str(user_id))
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.is_active:
|
||||
raise TokenInvalidError("Invalid user or inactive account")
|
||||
|
||||
@@ -166,7 +177,7 @@ class AuthService:
|
||||
return AuthService.create_tokens(user)
|
||||
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
logger.warning("Token refresh failed: %s", e)
|
||||
logger.warning(f"Token refresh failed: {e!s}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
@@ -189,7 +200,8 @@ class AuthService:
|
||||
AuthenticationError: If current password is incorrect or update fails
|
||||
"""
|
||||
try:
|
||||
user = await user_repo.get(db, id=str(user_id))
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
@@ -198,10 +210,10 @@ class AuthService:
|
||||
raise AuthenticationError("Current password is incorrect")
|
||||
|
||||
# Hash new password asynchronously to avoid blocking event loop
|
||||
new_hash = await get_password_hash_async(new_password)
|
||||
await user_repo.update_password(db, user=user, password_hash=new_hash)
|
||||
user.password_hash = await get_password_hash_async(new_password)
|
||||
await db.commit()
|
||||
|
||||
logger.info("Password changed successfully for user %s", user_id)
|
||||
logger.info(f"Password changed successfully for user {user_id}")
|
||||
return True
|
||||
|
||||
except AuthenticationError:
|
||||
@@ -210,34 +222,7 @@ class AuthService:
|
||||
except Exception as e:
|
||||
# Rollback on any database errors
|
||||
await db.rollback()
|
||||
logger.exception("Error changing password for user %s: %s", user_id, e)
|
||||
logger.error(
|
||||
f"Error changing password for user {user_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise AuthenticationError(f"Failed to change password: {e!s}")
|
||||
|
||||
@staticmethod
|
||||
async def reset_password(
|
||||
db: AsyncSession, *, email: str, new_password: str
|
||||
) -> User:
|
||||
"""
|
||||
Reset a user's password without requiring the current password.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
email: User email address
|
||||
new_password: New password to set
|
||||
|
||||
Returns:
|
||||
Updated user
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If user not found or inactive
|
||||
"""
|
||||
user = await user_repo.get_by_email(db, email=email)
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
if not user.is_active:
|
||||
raise AuthenticationError("User account is inactive")
|
||||
|
||||
new_hash = await get_password_hash_async(new_password)
|
||||
user = await user_repo.update_password(db, user=user, password_hash=new_hash)
|
||||
logger.info("Password reset successfully for %s", email)
|
||||
return user
|
||||
|
||||
@@ -58,8 +58,8 @@ class ConsoleEmailBackend(EmailBackend):
|
||||
logger.info("=" * 80)
|
||||
logger.info("EMAIL SENT (Console Backend)")
|
||||
logger.info("=" * 80)
|
||||
logger.info("To: %s", ", ".join(to))
|
||||
logger.info("Subject: %s", subject)
|
||||
logger.info(f"To: {', '.join(to)}")
|
||||
logger.info(f"Subject: {subject}")
|
||||
logger.info("-" * 80)
|
||||
if text_content:
|
||||
logger.info("Plain Text Content:")
|
||||
@@ -199,7 +199,7 @@ The {settings.PROJECT_NAME} Team
|
||||
text_content=text_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to send password reset email to %s: %s", to_email, e)
|
||||
logger.error(f"Failed to send password reset email to {to_email}: {e!s}")
|
||||
return False
|
||||
|
||||
async def send_email_verification(
|
||||
@@ -287,7 +287,7 @@ The {settings.PROJECT_NAME} Team
|
||||
text_content=text_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to send verification email to %s: %s", to_email, e)
|
||||
logger.error(f"Failed to send verification email to {to_email}: {e!s}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
622
backend/app/services/event_bus.py
Normal file
622
backend/app/services/event_bus.py
Normal file
@@ -0,0 +1,622 @@
|
||||
"""
|
||||
EventBus service for Redis Pub/Sub communication.
|
||||
|
||||
This module provides a centralized event bus for publishing and subscribing to
|
||||
events across the Syndarix platform. It uses Redis Pub/Sub for real-time
|
||||
message delivery between services, agents, and the frontend.
|
||||
|
||||
Architecture:
|
||||
- Publishers emit events to project/agent-specific Redis channels
|
||||
- SSE endpoints subscribe to channels and stream events to clients
|
||||
- Events include metadata for reconnection support (Last-Event-ID)
|
||||
- Events are typed with the EventType enum for consistency
|
||||
|
||||
Usage:
|
||||
# Publishing events
|
||||
event_bus = EventBus()
|
||||
await event_bus.connect()
|
||||
|
||||
event = event_bus.create_event(
|
||||
event_type=EventType.AGENT_MESSAGE,
|
||||
project_id=project_id,
|
||||
actor_type="agent",
|
||||
payload={"message": "Processing..."}
|
||||
)
|
||||
await event_bus.publish(event_bus.get_project_channel(project_id), event)
|
||||
|
||||
# Subscribing to events
|
||||
async for event in event_bus.subscribe(["project:123", "agent:456"]):
|
||||
handle_event(event)
|
||||
|
||||
# Cleanup
|
||||
await event_bus.disconnect()
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import redis.asyncio as redis
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.events import ActorType, Event, EventType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventBusError(Exception):
|
||||
"""Base exception for EventBus errors."""
|
||||
|
||||
|
||||
|
||||
class EventBusConnectionError(EventBusError):
|
||||
"""Raised when connection to Redis fails."""
|
||||
|
||||
|
||||
|
||||
class EventBusPublishError(EventBusError):
|
||||
"""Raised when publishing an event fails."""
|
||||
|
||||
|
||||
|
||||
class EventBusSubscriptionError(EventBusError):
|
||||
"""Raised when subscribing to channels fails."""
|
||||
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""
|
||||
EventBus for Redis Pub/Sub communication.
|
||||
|
||||
Provides methods to publish events to channels and subscribe to events
|
||||
from multiple channels. Handles connection management, serialization,
|
||||
and error recovery.
|
||||
|
||||
This class provides:
|
||||
- Event publishing to project/agent-specific channels
|
||||
- Subscription management for SSE endpoints
|
||||
- Reconnection support via event IDs and sequence numbers
|
||||
- Keepalive messages for connection health
|
||||
- Type-safe event creation with the Event schema
|
||||
|
||||
Attributes:
|
||||
redis_url: Redis connection URL
|
||||
redis_client: Async Redis client instance
|
||||
pubsub: Redis PubSub instance for subscriptions
|
||||
"""
|
||||
|
||||
# Channel prefixes for different entity types
|
||||
PROJECT_CHANNEL_PREFIX = "project"
|
||||
AGENT_CHANNEL_PREFIX = "agent"
|
||||
USER_CHANNEL_PREFIX = "user"
|
||||
GLOBAL_CHANNEL = "syndarix:global"
|
||||
|
||||
def __init__(self, redis_url: str | None = None) -> None:
|
||||
"""
|
||||
Initialize the EventBus.
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL. Defaults to settings.REDIS_URL.
|
||||
"""
|
||||
self.redis_url = redis_url or settings.REDIS_URL
|
||||
self._redis_client: redis.Redis | None = None
|
||||
self._pubsub: redis.client.PubSub | None = None
|
||||
self._connected = False
|
||||
self._sequence_counters: dict[str, int] = {}
|
||||
|
||||
@property
|
||||
def redis_client(self) -> redis.Redis:
|
||||
"""Get the Redis client, raising if not connected."""
|
||||
if self._redis_client is None:
|
||||
raise EventBusConnectionError(
|
||||
"EventBus not connected. Call connect() first."
|
||||
)
|
||||
return self._redis_client
|
||||
|
||||
@property
|
||||
def pubsub(self) -> redis.client.PubSub:
|
||||
"""Get the PubSub instance, raising if not connected."""
|
||||
if self._pubsub is None:
|
||||
raise EventBusConnectionError(
|
||||
"EventBus not connected. Call connect() first."
|
||||
)
|
||||
return self._pubsub
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the EventBus is connected to Redis."""
|
||||
return self._connected and self._redis_client is not None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
Connect to Redis and initialize the PubSub client.
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If connection to Redis fails.
|
||||
"""
|
||||
if self._connected:
|
||||
logger.debug("EventBus already connected")
|
||||
return
|
||||
|
||||
try:
|
||||
self._redis_client = redis.from_url(
|
||||
self.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
# Test connection - ping() returns a coroutine for async Redis
|
||||
ping_result = self._redis_client.ping()
|
||||
if hasattr(ping_result, "__await__"):
|
||||
await ping_result
|
||||
self._pubsub = self._redis_client.pubsub()
|
||||
self._connected = True
|
||||
logger.info("EventBus connected to Redis")
|
||||
except redis.ConnectionError as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}", exc_info=True)
|
||||
raise EventBusConnectionError(f"Failed to connect to Redis: {e}") from e
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Redis error during connection: {e}", exc_info=True)
|
||||
raise EventBusConnectionError(f"Redis error: {e}") from e
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnect from Redis and cleanup resources.
|
||||
"""
|
||||
if self._pubsub:
|
||||
try:
|
||||
await self._pubsub.unsubscribe()
|
||||
await self._pubsub.close()
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Error closing PubSub: {e}")
|
||||
finally:
|
||||
self._pubsub = None
|
||||
|
||||
if self._redis_client:
|
||||
try:
|
||||
await self._redis_client.aclose()
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Error closing Redis client: {e}")
|
||||
finally:
|
||||
self._redis_client = None
|
||||
|
||||
self._connected = False
|
||||
logger.info("EventBus disconnected from Redis")
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self) -> AsyncIterator["EventBus"]:
|
||||
"""
|
||||
Context manager for automatic connection handling.
|
||||
|
||||
Usage:
|
||||
async with event_bus.connection() as bus:
|
||||
await bus.publish(channel, event)
|
||||
"""
|
||||
await self.connect()
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
await self.disconnect()
|
||||
|
||||
def get_project_channel(self, project_id: UUID | str) -> str:
|
||||
"""
|
||||
Get the channel name for a project.
|
||||
|
||||
Args:
|
||||
project_id: The project UUID or string
|
||||
|
||||
Returns:
|
||||
Channel name string in format "project:{uuid}"
|
||||
"""
|
||||
return f"{self.PROJECT_CHANNEL_PREFIX}:{project_id}"
|
||||
|
||||
def get_agent_channel(self, agent_id: UUID | str) -> str:
|
||||
"""
|
||||
Get the channel name for an agent instance.
|
||||
|
||||
Args:
|
||||
agent_id: The agent instance UUID or string
|
||||
|
||||
Returns:
|
||||
Channel name string in format "agent:{uuid}"
|
||||
"""
|
||||
return f"{self.AGENT_CHANNEL_PREFIX}:{agent_id}"
|
||||
|
||||
def get_user_channel(self, user_id: UUID | str) -> str:
|
||||
"""
|
||||
Get the channel name for a user (personal notifications).
|
||||
|
||||
Args:
|
||||
user_id: The user UUID or string
|
||||
|
||||
Returns:
|
||||
Channel name string in format "user:{uuid}"
|
||||
"""
|
||||
return f"{self.USER_CHANNEL_PREFIX}:{user_id}"
|
||||
|
||||
def _get_next_sequence(self, channel: str) -> int:
|
||||
"""Get the next sequence number for a channel's events."""
|
||||
current = self._sequence_counters.get(channel, 0)
|
||||
self._sequence_counters[channel] = current + 1
|
||||
return current + 1
|
||||
|
||||
@staticmethod
|
||||
def create_event(
|
||||
event_type: EventType,
|
||||
project_id: UUID,
|
||||
actor_type: ActorType,
|
||||
payload: dict | None = None,
|
||||
actor_id: UUID | None = None,
|
||||
event_id: str | None = None,
|
||||
timestamp: datetime | None = None,
|
||||
) -> Event:
|
||||
"""
|
||||
Factory method to create a new Event.
|
||||
|
||||
Args:
|
||||
event_type: The type of event
|
||||
project_id: The project this event belongs to
|
||||
actor_type: Type of actor ('agent', 'user', or 'system')
|
||||
payload: Event-specific payload data
|
||||
actor_id: ID of the agent or user who triggered the event
|
||||
event_id: Optional custom event ID (UUID string)
|
||||
timestamp: Optional custom timestamp (defaults to now UTC)
|
||||
|
||||
Returns:
|
||||
A new Event instance
|
||||
"""
|
||||
return Event(
|
||||
id=event_id or str(uuid4()),
|
||||
type=event_type,
|
||||
timestamp=timestamp or datetime.now(UTC),
|
||||
project_id=project_id,
|
||||
actor_id=actor_id,
|
||||
actor_type=actor_type,
|
||||
payload=payload or {},
|
||||
)
|
||||
|
||||
def _serialize_event(self, event: Event) -> str:
|
||||
"""
|
||||
Serialize an event to JSON string.
|
||||
|
||||
Args:
|
||||
event: The Event to serialize
|
||||
|
||||
Returns:
|
||||
JSON string representation of the event
|
||||
"""
|
||||
return event.model_dump_json()
|
||||
|
||||
def _deserialize_event(self, data: str) -> Event:
|
||||
"""
|
||||
Deserialize a JSON string to an Event.
|
||||
|
||||
Args:
|
||||
data: JSON string to deserialize
|
||||
|
||||
Returns:
|
||||
Deserialized Event instance
|
||||
|
||||
Raises:
|
||||
ValidationError: If the data doesn't match the Event schema
|
||||
"""
|
||||
return Event.model_validate_json(data)
|
||||
|
||||
async def publish(self, channel: str, event: Event) -> int:
|
||||
"""
|
||||
Publish an event to a channel.
|
||||
|
||||
Args:
|
||||
channel: The channel name to publish to
|
||||
event: The Event to publish
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If not connected to Redis
|
||||
EventBusPublishError: If publishing fails
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise EventBusConnectionError("EventBus not connected")
|
||||
|
||||
try:
|
||||
message = self._serialize_event(event)
|
||||
subscriber_count = await self.redis_client.publish(channel, message)
|
||||
logger.debug(
|
||||
f"Published event {event.type} to {channel} "
|
||||
f"(received by {subscriber_count} subscribers)"
|
||||
)
|
||||
return subscriber_count
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Failed to publish event to {channel}: {e}", exc_info=True)
|
||||
raise EventBusPublishError(f"Failed to publish event: {e}") from e
|
||||
|
||||
async def publish_to_project(self, event: Event) -> int:
|
||||
"""
|
||||
Publish an event to the project's channel.
|
||||
|
||||
Convenience method that publishes to the project channel based on
|
||||
the event's project_id.
|
||||
|
||||
Args:
|
||||
event: The Event to publish (must have project_id set)
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message
|
||||
"""
|
||||
channel = self.get_project_channel(event.project_id)
|
||||
return await self.publish(channel, event)
|
||||
|
||||
async def publish_multi(self, channels: list[str], event: Event) -> dict[str, int]:
|
||||
"""
|
||||
Publish an event to multiple channels.
|
||||
|
||||
Args:
|
||||
channels: List of channel names to publish to
|
||||
event: The Event to publish
|
||||
|
||||
Returns:
|
||||
Dictionary mapping channel names to subscriber counts
|
||||
"""
|
||||
results = {}
|
||||
for channel in channels:
|
||||
try:
|
||||
results[channel] = await self.publish(channel, event)
|
||||
except EventBusPublishError as e:
|
||||
logger.warning(f"Failed to publish to {channel}: {e}")
|
||||
results[channel] = 0
|
||||
return results
|
||||
|
||||
async def subscribe(
|
||||
self, channels: list[str], *, max_wait: float | None = None
|
||||
) -> AsyncIterator[Event]:
|
||||
"""
|
||||
Subscribe to one or more channels and yield events.
|
||||
|
||||
This is an async generator that yields Event objects as they arrive.
|
||||
Use max_wait to limit how long to wait for messages.
|
||||
|
||||
Args:
|
||||
channels: List of channel names to subscribe to
|
||||
max_wait: Optional maximum wait time in seconds for each message.
|
||||
If None, waits indefinitely.
|
||||
|
||||
Yields:
|
||||
Event objects received from subscribed channels
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If not connected to Redis
|
||||
EventBusSubscriptionError: If subscription fails
|
||||
|
||||
Example:
|
||||
async for event in event_bus.subscribe(["project:123"], max_wait=30):
|
||||
print(f"Received: {event.type}")
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise EventBusConnectionError("EventBus not connected")
|
||||
|
||||
# Create a new pubsub for this subscription
|
||||
subscription_pubsub = self.redis_client.pubsub()
|
||||
|
||||
try:
|
||||
await subscription_pubsub.subscribe(*channels)
|
||||
logger.info(f"Subscribed to channels: {channels}")
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Failed to subscribe to channels: {e}", exc_info=True)
|
||||
await subscription_pubsub.close()
|
||||
raise EventBusSubscriptionError(f"Failed to subscribe: {e}") from e
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if max_wait is not None:
|
||||
async with asyncio.timeout(max_wait):
|
||||
message = await subscription_pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
else:
|
||||
message = await subscription_pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
except TimeoutError:
|
||||
# Timeout reached, stop iteration
|
||||
return
|
||||
|
||||
if message is None:
|
||||
continue
|
||||
|
||||
if message["type"] == "message":
|
||||
try:
|
||||
event = self._deserialize_event(message["data"])
|
||||
yield event
|
||||
except ValidationError as e:
|
||||
logger.warning(
|
||||
f"Invalid event data received: {e}",
|
||||
extra={"channel": message.get("channel")},
|
||||
)
|
||||
continue
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"Failed to decode event JSON: {e}",
|
||||
extra={"channel": message.get("channel")},
|
||||
)
|
||||
continue
|
||||
finally:
|
||||
try:
|
||||
await subscription_pubsub.unsubscribe(*channels)
|
||||
await subscription_pubsub.close()
|
||||
logger.debug(f"Unsubscribed from channels: {channels}")
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Error unsubscribing from channels: {e}")
|
||||
|
||||
async def subscribe_sse(
|
||||
self,
|
||||
project_id: str | UUID,
|
||||
last_event_id: str | None = None,
|
||||
keepalive_interval: int = 30,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Subscribe to events for a project in SSE format.
|
||||
|
||||
This is an async generator that yields SSE-formatted event strings.
|
||||
It includes keepalive messages at the specified interval.
|
||||
|
||||
Args:
|
||||
project_id: The project to subscribe to
|
||||
last_event_id: Optional last received event ID for reconnection
|
||||
keepalive_interval: Seconds between keepalive messages (default 30)
|
||||
|
||||
Yields:
|
||||
SSE-formatted event strings (ready to send to client)
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise EventBusConnectionError("EventBus not connected")
|
||||
|
||||
project_id_str = str(project_id)
|
||||
channel = self.get_project_channel(project_id_str)
|
||||
|
||||
subscription_pubsub = self.redis_client.pubsub()
|
||||
await subscription_pubsub.subscribe(channel)
|
||||
|
||||
logger.info(
|
||||
f"Subscribed to SSE events for project {project_id_str} "
|
||||
f"(last_event_id={last_event_id})"
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for messages with a timeout for keepalive
|
||||
message = await asyncio.wait_for(
|
||||
subscription_pubsub.get_message(ignore_subscribe_messages=True),
|
||||
timeout=keepalive_interval,
|
||||
)
|
||||
|
||||
if message is not None and message["type"] == "message":
|
||||
event_data = message["data"]
|
||||
|
||||
# If reconnecting, check if we should skip this event
|
||||
if last_event_id:
|
||||
try:
|
||||
event_dict = json.loads(event_data)
|
||||
if event_dict.get("id") == last_event_id:
|
||||
# Found the last event, start yielding from next
|
||||
last_event_id = None
|
||||
continue
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
yield event_data
|
||||
|
||||
except TimeoutError:
|
||||
# Send keepalive comment
|
||||
yield "" # Empty string signals keepalive
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"SSE subscription cancelled for project {project_id_str}")
|
||||
raise
|
||||
finally:
|
||||
await subscription_pubsub.unsubscribe(channel)
|
||||
await subscription_pubsub.close()
|
||||
logger.info(f"Unsubscribed SSE from project {project_id_str}")
|
||||
|
||||
async def subscribe_with_callback(
|
||||
self,
|
||||
channels: list[str],
|
||||
callback: Any, # Callable[[Event], Awaitable[None]]
|
||||
stop_event: asyncio.Event | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Subscribe to channels and process events with a callback.
|
||||
|
||||
This method runs until stop_event is set or an unrecoverable error occurs.
|
||||
|
||||
Args:
|
||||
channels: List of channel names to subscribe to
|
||||
callback: Async function to call for each event
|
||||
stop_event: Optional asyncio.Event to signal stop
|
||||
|
||||
Example:
|
||||
async def handle_event(event: Event):
|
||||
print(f"Handling: {event.type}")
|
||||
|
||||
stop = asyncio.Event()
|
||||
asyncio.create_task(
|
||||
event_bus.subscribe_with_callback(["project:123"], handle_event, stop)
|
||||
)
|
||||
# Later...
|
||||
stop.set()
|
||||
"""
|
||||
if stop_event is None:
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
try:
|
||||
async for event in self.subscribe(channels):
|
||||
if stop_event.is_set():
|
||||
break
|
||||
try:
|
||||
await callback(event)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in event callback: {e}", exc_info=True)
|
||||
except EventBusSubscriptionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in subscription loop: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance for application-wide use
|
||||
_event_bus: EventBus | None = None
|
||||
|
||||
|
||||
def get_event_bus() -> EventBus:
|
||||
"""
|
||||
Get the singleton EventBus instance.
|
||||
|
||||
Creates a new instance if one doesn't exist. Note that you still need
|
||||
to call connect() before using the EventBus.
|
||||
|
||||
Returns:
|
||||
The singleton EventBus instance
|
||||
"""
|
||||
global _event_bus
|
||||
if _event_bus is None:
|
||||
_event_bus = EventBus()
|
||||
return _event_bus
|
||||
|
||||
|
||||
async def get_connected_event_bus() -> EventBus:
|
||||
"""
|
||||
Get a connected EventBus instance.
|
||||
|
||||
Ensures the EventBus is connected before returning. For use in
|
||||
FastAPI dependency injection.
|
||||
|
||||
Returns:
|
||||
A connected EventBus instance
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If connection fails
|
||||
"""
|
||||
event_bus = get_event_bus()
|
||||
if not event_bus.is_connected:
|
||||
await event_bus.connect()
|
||||
return event_bus
|
||||
|
||||
|
||||
async def close_event_bus() -> None:
|
||||
"""
|
||||
Close the global EventBus instance.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
global _event_bus
|
||||
if _event_bus is not None:
|
||||
await _event_bus.disconnect()
|
||||
_event_bus = None
|
||||
@@ -25,19 +25,15 @@ from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
|
||||
from jose import jwt
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.oauth_authorization_code import OAuthAuthorizationCode
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
|
||||
from app.models.user import User
|
||||
from app.repositories.oauth_authorization_code import oauth_authorization_code_repo
|
||||
from app.repositories.oauth_client import oauth_client_repo
|
||||
from app.repositories.oauth_consent import oauth_consent_repo
|
||||
from app.repositories.oauth_provider_token import oauth_provider_token_repo
|
||||
from app.repositories.user import user_repo
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -139,7 +135,7 @@ def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool:
|
||||
if method != "S256":
|
||||
# SECURITY: Reject any method other than S256
|
||||
# 'plain' method provides no security against code interception attacks
|
||||
logger.warning("PKCE verification rejected for unsupported method: %s", method)
|
||||
logger.warning(f"PKCE verification rejected for unsupported method: {method}")
|
||||
return False
|
||||
|
||||
# SHA-256 hash, then base64url encode (RFC 7636 Section 4.2)
|
||||
@@ -165,7 +161,15 @@ def join_scope(scopes: list[str]) -> str:
|
||||
|
||||
async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
|
||||
"""Get OAuth client by client_id."""
|
||||
return await oauth_client_repo.get_by_client_id(db, client_id=client_id)
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def validate_client(
|
||||
@@ -200,19 +204,21 @@ async def validate_client(
|
||||
if not client.client_secret_hash:
|
||||
raise InvalidClientError("Client not configured with secret")
|
||||
|
||||
# SECURITY: Verify secret using bcrypt
|
||||
# SECURITY: Verify secret using bcrypt (not SHA-256)
|
||||
# Supports both bcrypt and legacy SHA-256 hashes for migration
|
||||
from app.core.auth import verify_password
|
||||
|
||||
stored_hash = str(client.client_secret_hash)
|
||||
|
||||
if not stored_hash.startswith("$2"):
|
||||
raise InvalidClientError(
|
||||
"Client secret uses deprecated hash format. "
|
||||
"Please regenerate your client credentials."
|
||||
)
|
||||
|
||||
if not verify_password(client_secret, stored_hash):
|
||||
raise InvalidClientError("Invalid client secret")
|
||||
if stored_hash.startswith("$2"):
|
||||
# New bcrypt format
|
||||
if not verify_password(client_secret, stored_hash):
|
||||
raise InvalidClientError("Invalid client secret")
|
||||
else:
|
||||
# Legacy SHA-256 format
|
||||
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
if not secrets.compare_digest(computed_hash, stored_hash):
|
||||
raise InvalidClientError("Invalid client secret")
|
||||
|
||||
return client
|
||||
|
||||
@@ -257,9 +263,7 @@ def validate_scopes(client: OAuthClient, requested_scopes: list[str]) -> list[st
|
||||
# Warn if some scopes were filtered out
|
||||
invalid = requested - allowed
|
||||
if invalid:
|
||||
logger.warning(
|
||||
"Client %s requested invalid scopes: %s", client.client_id, invalid
|
||||
)
|
||||
logger.warning(f"Client {client.client_id} requested invalid scopes: {invalid}")
|
||||
|
||||
return list(valid)
|
||||
|
||||
@@ -307,24 +311,25 @@ async def create_authorization_code(
|
||||
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
|
||||
)
|
||||
|
||||
await oauth_authorization_code_repo.create_code(
|
||||
db,
|
||||
auth_code = OAuthAuthorizationCode(
|
||||
code=code,
|
||||
client_id=client.client_id,
|
||||
user_id=user.id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
expires_at=expires_at,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
expires_at=expires_at,
|
||||
used=False,
|
||||
)
|
||||
|
||||
db.add(auth_code)
|
||||
await db.commit()
|
||||
|
||||
logger.info(
|
||||
"Created authorization code for user %s and client %s",
|
||||
user.id,
|
||||
client.client_id,
|
||||
f"Created authorization code for user {user.id} and client {client.client_id}"
|
||||
)
|
||||
return code
|
||||
|
||||
@@ -361,20 +366,35 @@ async def exchange_authorization_code(
|
||||
"""
|
||||
# Atomically mark code as used and fetch it (prevents race condition)
|
||||
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
|
||||
updated_id = await oauth_authorization_code_repo.consume_code_atomically(
|
||||
db, code=code
|
||||
from sqlalchemy import update
|
||||
|
||||
# First, atomically mark the code as used and get affected count
|
||||
update_stmt = (
|
||||
update(OAuthAuthorizationCode)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAuthorizationCode.code == code,
|
||||
OAuthAuthorizationCode.used == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(used=True)
|
||||
.returning(OAuthAuthorizationCode.id)
|
||||
)
|
||||
result = await db.execute(update_stmt)
|
||||
updated_id = result.scalar_one_or_none()
|
||||
|
||||
if not updated_id:
|
||||
# Either code doesn't exist or was already used
|
||||
# Check if it exists to provide appropriate error
|
||||
existing_code = await oauth_authorization_code_repo.get_by_code(db, code=code)
|
||||
check_result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
||||
)
|
||||
existing_code = check_result.scalar_one_or_none()
|
||||
|
||||
if existing_code and existing_code.used:
|
||||
# Code reuse is a security incident - revoke all tokens for this grant
|
||||
logger.warning(
|
||||
"Authorization code reuse detected for client %s",
|
||||
existing_code.client_id,
|
||||
f"Authorization code reuse detected for client {existing_code.client_id}"
|
||||
)
|
||||
await revoke_tokens_for_user_client(
|
||||
db, UUID(str(existing_code.user_id)), str(existing_code.client_id)
|
||||
@@ -384,9 +404,11 @@ async def exchange_authorization_code(
|
||||
raise InvalidGrantError("Invalid authorization code")
|
||||
|
||||
# Now fetch the full auth code record
|
||||
auth_code = await oauth_authorization_code_repo.get_by_id(db, code_id=updated_id)
|
||||
if auth_code is None:
|
||||
raise InvalidGrantError("Authorization code not found after consumption")
|
||||
auth_code_result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
|
||||
)
|
||||
auth_code = auth_code_result.scalar_one()
|
||||
await db.commit()
|
||||
|
||||
if auth_code.is_expired:
|
||||
raise InvalidGrantError("Authorization code has expired")
|
||||
@@ -430,7 +452,8 @@ async def exchange_authorization_code(
|
||||
raise InvalidGrantError("PKCE required for public clients")
|
||||
|
||||
# Get user
|
||||
user = await user_repo.get(db, id=str(auth_code.user_id))
|
||||
user_result = await db.execute(select(User).where(User.id == auth_code.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if not user or not user.is_active:
|
||||
raise InvalidGrantError("User not found or inactive")
|
||||
|
||||
@@ -520,8 +543,7 @@ async def create_tokens(
|
||||
refresh_token_hash = hash_token(refresh_token)
|
||||
|
||||
# Store refresh token in database
|
||||
await oauth_provider_token_repo.create_token(
|
||||
db,
|
||||
refresh_token_record = OAuthProviderRefreshToken(
|
||||
token_hash=refresh_token_hash,
|
||||
jti=jti,
|
||||
client_id=client.client_id,
|
||||
@@ -531,8 +553,10 @@ async def create_tokens(
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
db.add(refresh_token_record)
|
||||
await db.commit()
|
||||
|
||||
logger.info("Issued tokens for user %s to client %s", user.id, client.client_id)
|
||||
logger.info(f"Issued tokens for user {user.id} to client {client.client_id}")
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
@@ -575,9 +599,12 @@ async def refresh_tokens(
|
||||
"""
|
||||
# Find refresh token
|
||||
token_hash = hash_token(refresh_token)
|
||||
token_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||
db, token_hash=token_hash
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.token_hash == token_hash
|
||||
)
|
||||
)
|
||||
token_record: OAuthProviderRefreshToken | None = result.scalar_one_or_none()
|
||||
|
||||
if not token_record:
|
||||
raise InvalidGrantError("Invalid refresh token")
|
||||
@@ -585,7 +612,7 @@ async def refresh_tokens(
|
||||
if token_record.revoked:
|
||||
# Token reuse after revocation - security incident
|
||||
logger.warning(
|
||||
"Revoked refresh token reuse detected for client %s", token_record.client_id
|
||||
f"Revoked refresh token reuse detected for client {token_record.client_id}"
|
||||
)
|
||||
raise InvalidGrantError("Refresh token has been revoked")
|
||||
|
||||
@@ -604,7 +631,8 @@ async def refresh_tokens(
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = await user_repo.get(db, id=str(token_record.user_id))
|
||||
user_result = await db.execute(select(User).where(User.id == token_record.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if not user or not user.is_active:
|
||||
raise InvalidGrantError("User not found or inactive")
|
||||
|
||||
@@ -620,7 +648,9 @@ async def refresh_tokens(
|
||||
final_scope = token_scope
|
||||
|
||||
# Revoke old refresh token (token rotation)
|
||||
await oauth_provider_token_repo.revoke(db, token=token_record)
|
||||
token_record.revoked = True # type: ignore[assignment]
|
||||
token_record.last_used_at = datetime.now(UTC) # type: ignore[assignment]
|
||||
await db.commit()
|
||||
|
||||
# Issue new tokens
|
||||
device = str(token_record.device_info) if token_record.device_info else None
|
||||
@@ -667,22 +697,28 @@ async def revoke_token(
|
||||
# Try as refresh token first (more likely)
|
||||
if token_type_hint != "access_token":
|
||||
token_hash = hash_token(token)
|
||||
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||
db, token_hash=token_hash
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.token_hash == token_hash
|
||||
)
|
||||
)
|
||||
refresh_record = result.scalar_one_or_none()
|
||||
|
||||
if refresh_record:
|
||||
# Validate client if provided
|
||||
if client_id and refresh_record.client_id != client_id:
|
||||
raise InvalidClientError("Token was not issued to this client")
|
||||
|
||||
await oauth_provider_token_repo.revoke(db, token=refresh_record)
|
||||
logger.info("Revoked refresh token %s...", refresh_record.jti[:8])
|
||||
refresh_record.revoked = True # type: ignore[assignment]
|
||||
await db.commit()
|
||||
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
|
||||
return True
|
||||
|
||||
# Try as access token (JWT)
|
||||
if token_type_hint != "refresh_token":
|
||||
try:
|
||||
from jose.exceptions import JWTError
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
@@ -695,18 +731,22 @@ async def revoke_token(
|
||||
jti = payload.get("jti")
|
||||
if jti:
|
||||
# Find and revoke the associated refresh token
|
||||
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.jti == jti
|
||||
)
|
||||
)
|
||||
refresh_record = result.scalar_one_or_none()
|
||||
if refresh_record:
|
||||
if client_id and refresh_record.client_id != client_id:
|
||||
raise InvalidClientError("Token was not issued to this client")
|
||||
await oauth_provider_token_repo.revoke(db, token=refresh_record)
|
||||
refresh_record.revoked = True # type: ignore[assignment]
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"Revoked refresh token via access token JTI %s...", jti[:8]
|
||||
f"Revoked refresh token via access token JTI {jti[:8]}..."
|
||||
)
|
||||
return True
|
||||
except InvalidTokenError:
|
||||
pass
|
||||
except Exception: # noqa: S110 - Intentional: invalid JWT not an error
|
||||
except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT not an error
|
||||
pass
|
||||
|
||||
return False
|
||||
@@ -730,13 +770,26 @@ async def revoke_tokens_for_user_client(
|
||||
Returns:
|
||||
Number of tokens revoked
|
||||
"""
|
||||
count = await oauth_provider_token_repo.revoke_all_for_user_client(
|
||||
db, user_id=user_id, client_id=client_id
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
and_(
|
||||
OAuthProviderRefreshToken.user_id == user_id,
|
||||
OAuthProviderRefreshToken.client_id == client_id,
|
||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
tokens = result.scalars().all()
|
||||
|
||||
count = 0
|
||||
for token in tokens:
|
||||
token.revoked = True # type: ignore[assignment]
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
await db.commit()
|
||||
logger.warning(
|
||||
"Revoked %s tokens for user %s and client %s", count, user_id, client_id
|
||||
f"Revoked {count} tokens for user {user_id} and client {client_id}"
|
||||
)
|
||||
|
||||
return count
|
||||
@@ -755,10 +808,24 @@ async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
|
||||
Returns:
|
||||
Number of tokens revoked
|
||||
"""
|
||||
count = await oauth_provider_token_repo.revoke_all_for_user(db, user_id=user_id)
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
and_(
|
||||
OAuthProviderRefreshToken.user_id == user_id,
|
||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
tokens = result.scalars().all()
|
||||
|
||||
count = 0
|
||||
for token in tokens:
|
||||
token.revoked = True # type: ignore[assignment]
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
logger.info("Revoked %s OAuth provider tokens for user %s", count, user_id)
|
||||
await db.commit()
|
||||
logger.info(f"Revoked {count} OAuth provider tokens for user {user_id}")
|
||||
|
||||
return count
|
||||
|
||||
@@ -797,6 +864,8 @@ async def introspect_token(
|
||||
# Try as access token (JWT) first
|
||||
if token_type_hint != "refresh_token":
|
||||
try:
|
||||
from jose.exceptions import ExpiredSignatureError, JWTError
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
@@ -809,7 +878,12 @@ async def introspect_token(
|
||||
# Check if associated refresh token is revoked
|
||||
jti = payload.get("jti")
|
||||
if jti:
|
||||
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.jti == jti
|
||||
)
|
||||
)
|
||||
refresh_record = result.scalar_one_or_none()
|
||||
if refresh_record and refresh_record.revoked:
|
||||
return {"active": False}
|
||||
|
||||
@@ -827,17 +901,18 @@ async def introspect_token(
|
||||
}
|
||||
except ExpiredSignatureError:
|
||||
return {"active": False}
|
||||
except InvalidTokenError:
|
||||
pass
|
||||
except Exception: # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
|
||||
except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
|
||||
pass
|
||||
|
||||
# Try as refresh token
|
||||
if token_type_hint != "access_token":
|
||||
token_hash = hash_token(token)
|
||||
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||
db, token_hash=token_hash
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.token_hash == token_hash
|
||||
)
|
||||
)
|
||||
refresh_record = result.scalar_one_or_none()
|
||||
|
||||
if refresh_record and refresh_record.is_valid:
|
||||
return {
|
||||
@@ -862,11 +937,17 @@ async def get_consent(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
):
|
||||
) -> OAuthConsent | None:
|
||||
"""Get existing consent record for user-client pair."""
|
||||
return await oauth_consent_repo.get_consent(
|
||||
db, user_id=user_id, client_id=client_id
|
||||
result = await db.execute(
|
||||
select(OAuthConsent).where(
|
||||
and_(
|
||||
OAuthConsent.user_id == user_id,
|
||||
OAuthConsent.client_id == client_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def check_consent(
|
||||
@@ -891,15 +972,31 @@ async def grant_consent(
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
scopes: list[str],
|
||||
):
|
||||
) -> OAuthConsent:
|
||||
"""
|
||||
Grant or update consent for a user-client pair.
|
||||
|
||||
If consent already exists, updates the granted scopes.
|
||||
"""
|
||||
return await oauth_consent_repo.grant_consent(
|
||||
db, user_id=user_id, client_id=client_id, scopes=scopes
|
||||
)
|
||||
consent = await get_consent(db, user_id, client_id)
|
||||
|
||||
if consent:
|
||||
# Merge scopes
|
||||
granted = str(consent.granted_scopes) if consent.granted_scopes else ""
|
||||
existing = set(parse_scope(granted))
|
||||
new_scopes = existing | set(scopes)
|
||||
consent.granted_scopes = join_scope(list(new_scopes)) # type: ignore[assignment]
|
||||
else:
|
||||
consent = OAuthConsent(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
granted_scopes=join_scope(scopes),
|
||||
)
|
||||
db.add(consent)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(consent)
|
||||
return consent
|
||||
|
||||
|
||||
async def revoke_consent(
|
||||
@@ -912,13 +1009,21 @@ async def revoke_consent(
|
||||
|
||||
Returns True if consent was found and revoked.
|
||||
"""
|
||||
# Revoke all tokens first
|
||||
# Delete consent record
|
||||
result = await db.execute(
|
||||
delete(OAuthConsent).where(
|
||||
and_(
|
||||
OAuthConsent.user_id == user_id,
|
||||
OAuthConsent.client_id == client_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Revoke all tokens
|
||||
await revoke_tokens_for_user_client(db, user_id, client_id)
|
||||
|
||||
# Delete consent record
|
||||
return await oauth_consent_repo.revoke_consent(
|
||||
db, user_id=user_id, client_id=client_id
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -926,26 +1031,6 @@ async def revoke_consent(
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def register_client(db: AsyncSession, client_data: OAuthClientCreate) -> tuple:
|
||||
"""Create a new OAuth client. Returns (client, secret)."""
|
||||
return await oauth_client_repo.create_client(db, obj_in=client_data)
|
||||
|
||||
|
||||
async def list_clients(db: AsyncSession) -> list:
|
||||
"""List all registered OAuth clients."""
|
||||
return await oauth_client_repo.get_all_clients(db)
|
||||
|
||||
|
||||
async def delete_client_by_id(db: AsyncSession, client_id: str) -> None:
|
||||
"""Delete an OAuth client by client_id."""
|
||||
await oauth_client_repo.delete_client(db, client_id=client_id)
|
||||
|
||||
|
||||
async def list_user_consents(db: AsyncSession, user_id: UUID) -> list[dict]:
|
||||
"""Get all OAuth consents for a user with client details."""
|
||||
return await oauth_consent_repo.get_user_consents_with_clients(db, user_id=user_id)
|
||||
|
||||
|
||||
async def cleanup_expired_codes(db: AsyncSession) -> int:
|
||||
"""
|
||||
Delete expired authorization codes.
|
||||
@@ -955,7 +1040,13 @@ async def cleanup_expired_codes(db: AsyncSession) -> int:
|
||||
Returns:
|
||||
Number of codes deleted
|
||||
"""
|
||||
return await oauth_authorization_code_repo.cleanup_expired(db)
|
||||
result = await db.execute(
|
||||
delete(OAuthAuthorizationCode).where(
|
||||
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
||||
@@ -967,4 +1058,12 @@ async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
||||
Returns:
|
||||
Number of tokens deleted
|
||||
"""
|
||||
return await oauth_provider_token_repo.cleanup_expired(db, cutoff_days=7)
|
||||
# Delete tokens that are both expired AND revoked (or just very old)
|
||||
cutoff = datetime.now(UTC) - timedelta(days=7)
|
||||
result = await db.execute(
|
||||
delete(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.expires_at < cutoff
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
|
||||
@@ -19,15 +19,14 @@ from typing import TypedDict, cast
|
||||
from uuid import UUID
|
||||
|
||||
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import create_access_token, create_refresh_token
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AuthenticationError
|
||||
from app.crud import oauth_account, oauth_state
|
||||
from app.models.user import User
|
||||
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||
from app.repositories.oauth_state import oauth_state_repo as oauth_state
|
||||
from app.repositories.user import user_repo
|
||||
from app.schemas.oauth import (
|
||||
OAuthAccountCreate,
|
||||
OAuthCallbackResponse,
|
||||
@@ -39,22 +38,19 @@ from app.schemas.oauth import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _OAuthProviderConfigRequired(TypedDict):
|
||||
class OAuthProviderConfig(TypedDict, total=False):
|
||||
"""Type definition for OAuth provider configuration."""
|
||||
|
||||
name: str
|
||||
icon: str
|
||||
authorize_url: str
|
||||
token_url: str
|
||||
userinfo_url: str
|
||||
email_url: str # Optional, GitHub-only
|
||||
scopes: list[str]
|
||||
supports_pkce: bool
|
||||
|
||||
|
||||
class OAuthProviderConfig(_OAuthProviderConfigRequired, total=False):
|
||||
"""Type definition for OAuth provider configuration."""
|
||||
|
||||
email_url: str # Optional, GitHub-only
|
||||
|
||||
|
||||
# Provider configurations
|
||||
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
|
||||
"google": {
|
||||
@@ -219,7 +215,7 @@ class OAuthService:
|
||||
**auth_params,
|
||||
)
|
||||
|
||||
logger.info("OAuth authorization URL created for %s", provider)
|
||||
logger.info(f"OAuth authorization URL created for {provider}")
|
||||
return url, state
|
||||
|
||||
@staticmethod
|
||||
@@ -254,9 +250,8 @@ class OAuthService:
|
||||
# This prevents authorization code injection attacks (RFC 6749 Section 10.6)
|
||||
if state_record.redirect_uri != redirect_uri:
|
||||
logger.warning(
|
||||
"OAuth redirect_uri mismatch: expected %s, got %s",
|
||||
state_record.redirect_uri,
|
||||
redirect_uri,
|
||||
f"OAuth redirect_uri mismatch: expected {state_record.redirect_uri}, "
|
||||
f"got {redirect_uri}"
|
||||
)
|
||||
raise AuthenticationError("Redirect URI mismatch")
|
||||
|
||||
@@ -300,7 +295,7 @@ class OAuthService:
|
||||
except AuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("OAuth token exchange failed: %s", e)
|
||||
logger.error(f"OAuth token exchange failed: {e!s}")
|
||||
raise AuthenticationError("Failed to exchange authorization code")
|
||||
|
||||
# Get user info from provider
|
||||
@@ -313,7 +308,7 @@ class OAuthService:
|
||||
client, provider, config, access_token
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user info: %s", e)
|
||||
logger.error(f"Failed to get user info: {e!s}")
|
||||
raise AuthenticationError(
|
||||
"Failed to get user information from provider"
|
||||
)
|
||||
@@ -348,17 +343,18 @@ class OAuthService:
|
||||
await oauth_account.update_tokens(
|
||||
db,
|
||||
account=existing_oauth,
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600)),
|
||||
)
|
||||
|
||||
logger.info("OAuth login successful for %s via %s", user.email, provider)
|
||||
logger.info(f"OAuth login successful for {user.email} via {provider}")
|
||||
|
||||
elif state_record.user_id:
|
||||
# Account linking flow (user is already logged in)
|
||||
user = await user_repo.get(db, id=str(state_record.user_id))
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == state_record.user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise AuthenticationError("User not found for account linking")
|
||||
@@ -379,23 +375,24 @@ class OAuthService:
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=provider_email,
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
logger.info("OAuth account linked: %s -> %s", provider, user.email)
|
||||
logger.info(f"OAuth account linked: {provider} -> {user.email}")
|
||||
|
||||
else:
|
||||
# New OAuth login - check for existing user by email
|
||||
user = None
|
||||
|
||||
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
|
||||
user = await user_repo.get_by_email(db, email=provider_email)
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == provider_email)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user:
|
||||
# Auto-link to existing user
|
||||
@@ -410,9 +407,7 @@ class OAuthService:
|
||||
if existing_provider:
|
||||
# This shouldn't happen if we got here, but safety check
|
||||
logger.warning(
|
||||
"OAuth account already linked (race condition?): %s -> %s",
|
||||
provider,
|
||||
user.email,
|
||||
f"OAuth account already linked (race condition?): {provider} -> {user.email}"
|
||||
)
|
||||
else:
|
||||
# Create OAuth account link
|
||||
@@ -421,8 +416,8 @@ class OAuthService:
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=provider_email,
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
access_token_encrypted=token.get("access_token"),
|
||||
refresh_token_encrypted=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
@@ -430,9 +425,7 @@ class OAuthService:
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
logger.info(
|
||||
"OAuth auto-linked by email: %s -> %s", provider, user.email
|
||||
)
|
||||
logger.info(f"OAuth auto-linked by email: {provider} -> {user.email}")
|
||||
|
||||
else:
|
||||
# Create new user
|
||||
@@ -452,7 +445,7 @@ class OAuthService:
|
||||
)
|
||||
is_new_user = True
|
||||
|
||||
logger.info("New user created via OAuth: %s (%s)", user.email, provider)
|
||||
logger.info(f"New user created via OAuth: {user.email} ({provider})")
|
||||
|
||||
# Generate JWT tokens
|
||||
claims = {
|
||||
@@ -493,7 +486,7 @@ class OAuthService:
|
||||
# GitHub requires separate request for email
|
||||
if provider == "github" and not user_info.get("email"):
|
||||
email_resp = await client.get(
|
||||
config["email_url"], # pyright: ignore[reportTypedDictNotRequiredAccess]
|
||||
config["email_url"],
|
||||
headers=headers,
|
||||
)
|
||||
email_resp.raise_for_status()
|
||||
@@ -537,9 +530,8 @@ class OAuthService:
|
||||
AuthenticationError: If verification fails
|
||||
"""
|
||||
import httpx
|
||||
import jwt as pyjwt
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
from jose import jwt as jose_jwt
|
||||
from jose.exceptions import JWTError
|
||||
|
||||
try:
|
||||
# Fetch Google's public keys (JWKS)
|
||||
@@ -553,27 +545,24 @@ class OAuthService:
|
||||
jwks = jwks_response.json()
|
||||
|
||||
# Get the key ID from the token header
|
||||
unverified_header = pyjwt.get_unverified_header(id_token)
|
||||
unverified_header = jose_jwt.get_unverified_header(id_token)
|
||||
kid = unverified_header.get("kid")
|
||||
if not kid:
|
||||
raise AuthenticationError("ID token missing key ID (kid)")
|
||||
|
||||
# Find the matching public key
|
||||
jwk_data = None
|
||||
public_key = None
|
||||
for key in jwks.get("keys", []):
|
||||
if key.get("kid") == kid:
|
||||
jwk_data = key
|
||||
public_key = key
|
||||
break
|
||||
|
||||
if not jwk_data:
|
||||
if not public_key:
|
||||
raise AuthenticationError("ID token signed with unknown key")
|
||||
|
||||
# Convert JWK to a public key object for PyJWT
|
||||
public_key = RSAAlgorithm.from_jwk(jwk_data)
|
||||
|
||||
# Verify the token signature and decode claims
|
||||
# PyJWT will verify signature against the RSA public key
|
||||
payload = pyjwt.decode(
|
||||
# jose library will verify signature against the JWK
|
||||
payload = jose_jwt.decode(
|
||||
id_token,
|
||||
public_key,
|
||||
algorithms=["RS256"], # Google uses RS256
|
||||
@@ -592,24 +581,23 @@ class OAuthService:
|
||||
token_nonce = payload.get("nonce")
|
||||
if token_nonce != expected_nonce:
|
||||
logger.warning(
|
||||
"OAuth ID token nonce mismatch: expected %s, got %s",
|
||||
expected_nonce,
|
||||
token_nonce,
|
||||
f"OAuth ID token nonce mismatch: expected {expected_nonce}, "
|
||||
f"got {token_nonce}"
|
||||
)
|
||||
raise AuthenticationError("Invalid ID token nonce")
|
||||
|
||||
logger.debug("Google ID token verified successfully")
|
||||
return payload
|
||||
|
||||
except InvalidTokenError as e:
|
||||
logger.warning("Google ID token verification failed: %s", e)
|
||||
except JWTError as e:
|
||||
logger.warning(f"Google ID token verification failed: {e}")
|
||||
raise AuthenticationError("Invalid ID token signature")
|
||||
except httpx.HTTPError as e:
|
||||
logger.error("Failed to fetch Google JWKS: %s", e)
|
||||
logger.error(f"Failed to fetch Google JWKS: {e}")
|
||||
# If we can't verify the ID token, fail closed for security
|
||||
raise AuthenticationError("Failed to verify ID token")
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error verifying Google ID token: %s", e)
|
||||
logger.error(f"Unexpected error verifying Google ID token: {e}")
|
||||
raise AuthenticationError("ID token verification error")
|
||||
|
||||
@staticmethod
|
||||
@@ -656,15 +644,14 @@ class OAuthService:
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=email,
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
return user
|
||||
@@ -711,23 +698,9 @@ class OAuthService:
|
||||
if not deleted:
|
||||
raise AuthenticationError(f"No {provider} account found to unlink")
|
||||
|
||||
logger.info("OAuth provider unlinked: %s from %s", provider, user.email)
|
||||
logger.info(f"OAuth provider unlinked: {provider} from {user.email}")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def get_user_accounts(db: AsyncSession, *, user_id: UUID) -> list:
|
||||
"""Get all OAuth accounts linked to a user."""
|
||||
return await oauth_account.get_user_accounts(db, user_id=user_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_account_by_provider(
|
||||
db: AsyncSession, *, user_id: UUID, provider: str
|
||||
):
|
||||
"""Get a specific OAuth account for a user and provider."""
|
||||
return await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_expired_states(db: AsyncSession) -> int:
|
||||
"""
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
# app/services/organization_service.py
|
||||
"""Service layer for organization operations — delegates to OrganizationRepository."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import NotFoundError
|
||||
from app.models.organization import Organization
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
from app.repositories.organization import OrganizationRepository, organization_repo
|
||||
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrganizationService:
|
||||
"""Service for organization management operations."""
|
||||
|
||||
def __init__(
|
||||
self, organization_repository: OrganizationRepository | None = None
|
||||
) -> None:
|
||||
self._repo = organization_repository or organization_repo
|
||||
|
||||
async def get_organization(self, db: AsyncSession, org_id: str) -> Organization:
|
||||
"""Get organization by ID, raising NotFoundError if not found."""
|
||||
org = await self._repo.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(f"Organization {org_id} not found")
|
||||
return org
|
||||
|
||||
async def create_organization(
|
||||
self, db: AsyncSession, *, obj_in: OrganizationCreate
|
||||
) -> Organization:
|
||||
"""Create a new organization."""
|
||||
return await self._repo.create(db, obj_in=obj_in)
|
||||
|
||||
async def update_organization(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
org: Organization,
|
||||
obj_in: OrganizationUpdate | dict[str, Any],
|
||||
) -> Organization:
|
||||
"""Update an existing organization."""
|
||||
return await self._repo.update(db, db_obj=org, obj_in=obj_in)
|
||||
|
||||
async def remove_organization(self, db: AsyncSession, org_id: str) -> None:
|
||||
"""Permanently delete an organization by ID."""
|
||||
await self._repo.remove(db, id=org_id)
|
||||
|
||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||
"""Get number of active members in an organization."""
|
||||
return await self._repo.get_member_count(db, organization_id=organization_id)
|
||||
|
||||
async def get_multi_with_member_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""List organizations with member counts and pagination."""
|
||||
return await self._repo.get_multi_with_member_counts(
|
||||
db, skip=skip, limit=limit, is_active=is_active, search=search
|
||||
)
|
||||
|
||||
async def get_user_organizations_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get all organizations a user belongs to, with membership details."""
|
||||
return await self._repo.get_user_organizations_with_details(
|
||||
db, user_id=user_id, is_active=is_active
|
||||
)
|
||||
|
||||
async def get_organization_members(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = True,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get members of an organization with pagination."""
|
||||
return await self._repo.get_organization_members(
|
||||
db,
|
||||
organization_id=organization_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
async def add_member(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole = OrganizationRole.MEMBER,
|
||||
) -> UserOrganization:
|
||||
"""Add a user to an organization."""
|
||||
return await self._repo.add_user(
|
||||
db, organization_id=organization_id, user_id=user_id, role=role
|
||||
)
|
||||
|
||||
async def remove_member(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
) -> bool:
|
||||
"""Remove a user from an organization. Returns True if found and removed."""
|
||||
return await self._repo.remove_user(
|
||||
db, organization_id=organization_id, user_id=user_id
|
||||
)
|
||||
|
||||
async def get_user_role_in_org(
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> OrganizationRole | None:
|
||||
"""Get the role of a user in an organization."""
|
||||
return await self._repo.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
|
||||
async def get_org_distribution(
|
||||
self, db: AsyncSession, *, limit: int = 6
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return top organizations by member count for admin dashboard."""
|
||||
from sqlalchemy import func, select
|
||||
|
||||
result = await db.execute(
|
||||
select(
|
||||
Organization.name,
|
||||
func.count(UserOrganization.user_id).label("count"),
|
||||
)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.group_by(Organization.name)
|
||||
.order_by(func.count(UserOrganization.user_id).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
return [{"name": row.name, "value": row.count} for row in result.all()]
|
||||
|
||||
|
||||
# Default singleton
|
||||
organization_service = OrganizationService()
|
||||
@@ -8,7 +8,7 @@ import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.repositories.session import session_repo as session_repo
|
||||
from app.crud.session import session as session_crud
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,15 +32,15 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||
|
||||
async with SessionLocal() as db:
|
||||
try:
|
||||
# Use repository method to cleanup
|
||||
count = await session_repo.cleanup_expired(db, keep_days=keep_days)
|
||||
# Use CRUD method to cleanup
|
||||
count = await session_crud.cleanup_expired(db, keep_days=keep_days)
|
||||
|
||||
logger.info("Session cleanup complete: %s sessions deleted", count)
|
||||
logger.info(f"Session cleanup complete: {count} sessions deleted")
|
||||
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error during session cleanup: %s", e)
|
||||
logger.error(f"Error during session cleanup: {e!s}", exc_info=True)
|
||||
return 0
|
||||
|
||||
|
||||
@@ -79,10 +79,10 @@ async def get_session_statistics() -> dict:
|
||||
"expired": expired_sessions,
|
||||
}
|
||||
|
||||
logger.info("Session statistics: %s", stats)
|
||||
logger.info(f"Session statistics: {stats}")
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting session statistics: %s", e)
|
||||
logger.error(f"Error getting session statistics: {e!s}", exc_info=True)
|
||||
return {}
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
# app/services/session_service.py
|
||||
"""Service layer for session operations — delegates to SessionRepository."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from app.repositories.session import SessionRepository, session_repo
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""Service for user session management operations."""
|
||||
|
||||
def __init__(self, session_repository: SessionRepository | None = None) -> None:
|
||||
self._repo = session_repository or session_repo
|
||||
|
||||
async def create_session(
|
||||
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""Create a new session record."""
|
||||
return await self._repo.create_session(db, obj_in=obj_in)
|
||||
|
||||
async def get_session(
|
||||
self, db: AsyncSession, session_id: str
|
||||
) -> UserSession | None:
|
||||
"""Get session by ID."""
|
||||
return await self._repo.get(db, id=session_id)
|
||||
|
||||
async def get_user_sessions(
|
||||
self, db: AsyncSession, *, user_id: str, active_only: bool = True
|
||||
) -> list[UserSession]:
|
||||
"""Get all sessions for a user."""
|
||||
return await self._repo.get_user_sessions(
|
||||
db, user_id=user_id, active_only=active_only
|
||||
)
|
||||
|
||||
async def get_active_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> UserSession | None:
|
||||
"""Get active session by refresh token JTI."""
|
||||
return await self._repo.get_active_by_jti(db, jti=jti)
|
||||
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||
"""Get session by refresh token JTI (active or inactive)."""
|
||||
return await self._repo.get_by_jti(db, jti=jti)
|
||||
|
||||
async def deactivate(
|
||||
self, db: AsyncSession, *, session_id: str
|
||||
) -> UserSession | None:
|
||||
"""Deactivate a session (logout from device)."""
|
||||
return await self._repo.deactivate(db, session_id=session_id)
|
||||
|
||||
async def deactivate_all_user_sessions(
|
||||
self, db: AsyncSession, *, user_id: str
|
||||
) -> int:
|
||||
"""Deactivate all sessions for a user. Returns count deactivated."""
|
||||
return await self._repo.deactivate_all_user_sessions(db, user_id=user_id)
|
||||
|
||||
async def update_refresh_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
new_expires_at: datetime,
|
||||
) -> UserSession:
|
||||
"""Update session with a rotated refresh token."""
|
||||
return await self._repo.update_refresh_token(
|
||||
db, session=session, new_jti=new_jti, new_expires_at=new_expires_at
|
||||
)
|
||||
|
||||
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""Remove expired sessions for a user. Returns count removed."""
|
||||
return await self._repo.cleanup_expired_for_user(db, user_id=user_id)
|
||||
|
||||
async def get_all_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
active_only: bool = True,
|
||||
with_user: bool = True,
|
||||
) -> tuple[list[UserSession], int]:
|
||||
"""Get all sessions with pagination (admin only)."""
|
||||
return await self._repo.get_all_sessions(
|
||||
db, skip=skip, limit=limit, active_only=active_only, with_user=with_user
|
||||
)
|
||||
|
||||
|
||||
# Default singleton
|
||||
session_service = SessionService()
|
||||
@@ -1,120 +0,0 @@
|
||||
# app/services/user_service.py
|
||||
"""Service layer for user operations — delegates to UserRepository."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import NotFoundError
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository, user_repo
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserService:
|
||||
"""Service for user management operations."""
|
||||
|
||||
def __init__(self, user_repository: UserRepository | None = None) -> None:
|
||||
self._repo = user_repository or user_repo
|
||||
|
||||
async def get_user(self, db: AsyncSession, user_id: str) -> User:
|
||||
"""Get user by ID, raising NotFoundError if not found."""
|
||||
user = await self._repo.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(f"User {user_id} not found")
|
||||
return user
|
||||
|
||||
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
|
||||
"""Get user by email address."""
|
||||
return await self._repo.get_by_email(db, email=email)
|
||||
|
||||
async def create_user(self, db: AsyncSession, user_data: UserCreate) -> User:
|
||||
"""Create a new user."""
|
||||
return await self._repo.create(db, obj_in=user_data)
|
||||
|
||||
async def update_user(
|
||||
self, db: AsyncSession, *, user: User, obj_in: UserUpdate | dict[str, Any]
|
||||
) -> User:
|
||||
"""Update an existing user."""
|
||||
return await self._repo.update(db, db_obj=user, obj_in=obj_in)
|
||||
|
||||
async def soft_delete_user(self, db: AsyncSession, user_id: str) -> None:
|
||||
"""Soft-delete a user by ID."""
|
||||
await self._repo.soft_delete(db, id=user_id)
|
||||
|
||||
async def list_users(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: dict[str, Any] | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[User], int]:
|
||||
"""List users with pagination, sorting, filtering, and search."""
|
||||
return await self._repo.get_multi_with_total(
|
||||
db,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
filters=filters,
|
||||
search=search,
|
||||
)
|
||||
|
||||
async def bulk_update_status(
|
||||
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||
) -> int:
|
||||
"""Bulk update active status for multiple users. Returns count updated."""
|
||||
return await self._repo.bulk_update_status(
|
||||
db, user_ids=user_ids, is_active=is_active
|
||||
)
|
||||
|
||||
async def bulk_soft_delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: list[UUID],
|
||||
exclude_user_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""Bulk soft-delete multiple users. Returns count deleted."""
|
||||
return await self._repo.bulk_soft_delete(
|
||||
db, user_ids=user_ids, exclude_user_id=exclude_user_id
|
||||
)
|
||||
|
||||
async def get_stats(self, db: AsyncSession) -> dict[str, Any]:
|
||||
"""Return user stats needed for the admin dashboard."""
|
||||
from sqlalchemy import func, select
|
||||
|
||||
total_users = (
|
||||
await db.execute(select(func.count()).select_from(User))
|
||||
).scalar() or 0
|
||||
active_count = (
|
||||
await db.execute(
|
||||
select(func.count()).select_from(User).where(User.is_active)
|
||||
)
|
||||
).scalar() or 0
|
||||
inactive_count = (
|
||||
await db.execute(
|
||||
select(func.count()).select_from(User).where(User.is_active.is_(False))
|
||||
)
|
||||
).scalar() or 0
|
||||
all_users = list(
|
||||
(await db.execute(select(User).order_by(User.created_at))).scalars().all()
|
||||
)
|
||||
return {
|
||||
"total_users": total_users,
|
||||
"active_count": active_count,
|
||||
"inactive_count": inactive_count,
|
||||
"all_users": all_users,
|
||||
}
|
||||
|
||||
|
||||
# Default singleton
|
||||
user_service = UserService()
|
||||
23
backend/app/tasks/__init__.py
Normal file
23
backend/app/tasks/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# app/tasks/__init__.py
|
||||
"""
|
||||
Celery background tasks for Syndarix.
|
||||
|
||||
This package contains all Celery tasks organized by domain:
|
||||
|
||||
Modules:
|
||||
agent: Agent execution tasks (run_agent_step, spawn_agent, terminate_agent)
|
||||
git: Git operation tasks (clone, commit, branch, push, PR)
|
||||
sync: Issue synchronization tasks (incremental/full sync, webhooks)
|
||||
workflow: Workflow state management tasks
|
||||
cost: Cost tracking and budget monitoring tasks
|
||||
"""
|
||||
|
||||
from app.tasks import agent, cost, git, sync, workflow
|
||||
|
||||
__all__ = [
|
||||
"agent",
|
||||
"cost",
|
||||
"git",
|
||||
"sync",
|
||||
"workflow",
|
||||
]
|
||||
150
backend/app/tasks/agent.py
Normal file
150
backend/app/tasks/agent.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# app/tasks/agent.py
|
||||
"""
|
||||
Agent execution tasks for Syndarix.
|
||||
|
||||
These tasks handle the lifecycle of AI agent instances:
|
||||
- Spawning new agent instances from agent types
|
||||
- Executing agent steps (LLM calls, tool execution)
|
||||
- Terminating agent instances
|
||||
|
||||
Tasks are routed to the 'agent' queue for dedicated processing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.agent.run_agent_step")
|
||||
def run_agent_step(
|
||||
self,
|
||||
agent_instance_id: str,
|
||||
context: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a single step of an agent's workflow.
|
||||
|
||||
This task performs one iteration of the agent loop:
|
||||
1. Load agent instance state
|
||||
2. Call LLM with context and available tools
|
||||
3. Execute tool calls if any
|
||||
4. Update agent state
|
||||
5. Return result for next step or completion
|
||||
|
||||
Args:
|
||||
agent_instance_id: UUID of the agent instance
|
||||
context: Current execution context including:
|
||||
- messages: Conversation history
|
||||
- tools: Available tool definitions
|
||||
- state: Agent state data
|
||||
- metadata: Project/task metadata
|
||||
|
||||
Returns:
|
||||
dict with status and agent_instance_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Running agent step for instance {agent_instance_id} with context keys: {list(context.keys())}"
|
||||
)
|
||||
|
||||
# TODO: Implement actual agent step execution
|
||||
# This will involve:
|
||||
# 1. Loading agent instance from database
|
||||
# 2. Calling LLM provider (via litellm or anthropic SDK)
|
||||
# 3. Processing tool calls through MCP servers
|
||||
# 4. Updating agent state in database
|
||||
# 5. Scheduling next step if needed
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"agent_instance_id": agent_instance_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.agent.spawn_agent")
|
||||
def spawn_agent(
|
||||
self,
|
||||
agent_type_id: str,
|
||||
project_id: str,
|
||||
initial_context: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Spawn a new agent instance from an agent type.
|
||||
|
||||
This task creates a new agent instance:
|
||||
1. Load agent type configuration (model, expertise, personality)
|
||||
2. Create agent instance record in database
|
||||
3. Initialize agent state with project context
|
||||
4. Start first agent step
|
||||
|
||||
Args:
|
||||
agent_type_id: UUID of the agent type template
|
||||
project_id: UUID of the project this agent will work on
|
||||
initial_context: Starting context including:
|
||||
- goal: High-level objective
|
||||
- constraints: Any limitations or requirements
|
||||
- assigned_issues: Issues to work on
|
||||
- autonomy_level: FULL_CONTROL, MILESTONE, or AUTONOMOUS
|
||||
|
||||
Returns:
|
||||
dict with status, agent_type_id, and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Spawning agent of type {agent_type_id} for project {project_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement agent spawning
|
||||
# This will involve:
|
||||
# 1. Loading agent type from database
|
||||
# 2. Creating agent instance record
|
||||
# 3. Setting up MCP tool access
|
||||
# 4. Initializing agent state
|
||||
# 5. Kicking off first step
|
||||
|
||||
return {
|
||||
"status": "spawned",
|
||||
"agent_type_id": agent_type_id,
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.agent.terminate_agent")
|
||||
def terminate_agent(
|
||||
self,
|
||||
agent_instance_id: str,
|
||||
reason: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Terminate an agent instance.
|
||||
|
||||
This task gracefully shuts down an agent:
|
||||
1. Mark agent instance as terminated
|
||||
2. Save final state for audit
|
||||
3. Release any held resources
|
||||
4. Notify relevant subscribers
|
||||
|
||||
Args:
|
||||
agent_instance_id: UUID of the agent instance
|
||||
reason: Reason for termination (completion, error, manual, budget)
|
||||
|
||||
Returns:
|
||||
dict with status and agent_instance_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Terminating agent instance {agent_instance_id} with reason: {reason}"
|
||||
)
|
||||
|
||||
# TODO: Implement agent termination
|
||||
# This will involve:
|
||||
# 1. Loading agent instance
|
||||
# 2. Updating status to terminated
|
||||
# 3. Saving termination reason
|
||||
# 4. Cleaning up any pending tasks
|
||||
# 5. Sending termination event
|
||||
|
||||
return {
|
||||
"status": "terminated",
|
||||
"agent_instance_id": agent_instance_id,
|
||||
}
|
||||
201
backend/app/tasks/cost.py
Normal file
201
backend/app/tasks/cost.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# app/tasks/cost.py
|
||||
"""
|
||||
Cost tracking and budget management tasks for Syndarix.
|
||||
|
||||
These tasks implement multi-layered cost tracking per ADR-012:
|
||||
- Per-agent token usage tracking
|
||||
- Project budget monitoring
|
||||
- Daily cost aggregation
|
||||
- Budget threshold alerts
|
||||
- Cost reporting
|
||||
|
||||
Costs are tracked in real-time in Redis for speed,
|
||||
then aggregated to PostgreSQL for durability.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.cost.aggregate_daily_costs")
|
||||
def aggregate_daily_costs(self) -> dict[str, Any]:
|
||||
"""
|
||||
Aggregate daily costs from Redis to PostgreSQL.
|
||||
|
||||
This periodic task (runs daily):
|
||||
1. Read accumulated costs from Redis
|
||||
2. Aggregate by project, agent, and model
|
||||
3. Store in PostgreSQL cost_records table
|
||||
4. Clear Redis counters for new day
|
||||
|
||||
Returns:
|
||||
dict with status
|
||||
"""
|
||||
logger.info("Starting daily cost aggregation")
|
||||
|
||||
# TODO: Implement cost aggregation
|
||||
# This will involve:
|
||||
# 1. Fetching cost data from Redis
|
||||
# 2. Grouping by project_id, agent_id, model
|
||||
# 3. Inserting into PostgreSQL cost tables
|
||||
# 4. Resetting Redis counters
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.cost.check_budget_thresholds")
|
||||
def check_budget_thresholds(
|
||||
self,
|
||||
project_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Check if a project has exceeded budget thresholds.
|
||||
|
||||
This task checks budget limits:
|
||||
1. Get current spend from Redis counters
|
||||
2. Compare against project budget limits
|
||||
3. Send alerts if thresholds exceeded
|
||||
4. Pause agents if hard limit reached
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(f"Checking budget thresholds for project {project_id}")
|
||||
|
||||
# TODO: Implement budget checking
|
||||
# This will involve:
|
||||
# 1. Loading project budget configuration
|
||||
# 2. Getting current spend from Redis
|
||||
# 3. Comparing against soft/hard limits
|
||||
# 4. Sending alerts or pausing agents
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.cost.record_llm_usage")
|
||||
def record_llm_usage(
|
||||
self,
|
||||
agent_id: str,
|
||||
project_id: str,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cost_usd: float,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Record LLM usage from an agent call.
|
||||
|
||||
This task tracks each LLM API call:
|
||||
1. Increment Redis counters for real-time tracking
|
||||
2. Store raw usage event for audit
|
||||
3. Trigger budget check if threshold approaching
|
||||
|
||||
Args:
|
||||
agent_id: UUID of the agent instance
|
||||
project_id: UUID of the project
|
||||
model: Model identifier (e.g., claude-opus-4-5-20251101)
|
||||
prompt_tokens: Number of input tokens
|
||||
completion_tokens: Number of output tokens
|
||||
cost_usd: Calculated cost in USD
|
||||
|
||||
Returns:
|
||||
dict with status, agent_id, project_id, and cost_usd
|
||||
"""
|
||||
logger.debug(
|
||||
f"Recording LLM usage for model {model}: "
|
||||
f"{prompt_tokens} prompt + {completion_tokens} completion tokens = ${cost_usd}"
|
||||
)
|
||||
|
||||
# TODO: Implement usage recording
|
||||
# This will involve:
|
||||
# 1. Incrementing Redis counters
|
||||
# 2. Storing usage event
|
||||
# 3. Checking if near budget threshold
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"agent_id": agent_id,
|
||||
"project_id": project_id,
|
||||
"cost_usd": cost_usd,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.cost.generate_cost_report")
|
||||
def generate_cost_report(
|
||||
self,
|
||||
project_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a cost report for a project.
|
||||
|
||||
This task creates a detailed cost breakdown:
|
||||
1. Query cost records for date range
|
||||
2. Group by agent, model, and day
|
||||
3. Calculate totals and trends
|
||||
4. Format report for display
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
start_date: Report start date (YYYY-MM-DD)
|
||||
end_date: Report end date (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
dict with status, project_id, and date range
|
||||
"""
|
||||
logger.info(
|
||||
f"Generating cost report for project {project_id} from {start_date} to {end_date}"
|
||||
)
|
||||
|
||||
# TODO: Implement report generation
|
||||
# This will involve:
|
||||
# 1. Querying PostgreSQL for cost records
|
||||
# 2. Aggregating by various dimensions
|
||||
# 3. Calculating totals and averages
|
||||
# 4. Formatting report data
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.cost.reset_daily_budget_counters")
|
||||
def reset_daily_budget_counters(self) -> dict[str, Any]:
|
||||
"""
|
||||
Reset daily budget counters in Redis.
|
||||
|
||||
This periodic task (runs daily at midnight UTC):
|
||||
1. Archive current day's counters
|
||||
2. Reset all daily budget counters
|
||||
3. Prepare for new day's tracking
|
||||
|
||||
Returns:
|
||||
dict with status
|
||||
"""
|
||||
logger.info("Resetting daily budget counters")
|
||||
|
||||
# TODO: Implement counter reset
|
||||
# This will involve:
|
||||
# 1. Getting all daily counter keys from Redis
|
||||
# 2. Archiving current values
|
||||
# 3. Resetting counters to zero
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
}
|
||||
225
backend/app/tasks/git.py
Normal file
225
backend/app/tasks/git.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# app/tasks/git.py
|
||||
"""
|
||||
Git operation tasks for Syndarix.
|
||||
|
||||
These tasks handle Git operations for projects:
|
||||
- Cloning repositories
|
||||
- Creating branches
|
||||
- Committing changes
|
||||
- Pushing to remotes
|
||||
- Creating pull requests
|
||||
|
||||
Tasks are routed to the 'git' queue for dedicated processing.
|
||||
All operations are scoped by project_id for multi-tenancy.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.git.clone_repository")
|
||||
def clone_repository(
|
||||
self,
|
||||
project_id: str,
|
||||
repo_url: str,
|
||||
branch: str = "main",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Clone a repository for a project.
|
||||
|
||||
This task clones a Git repository to the project workspace:
|
||||
1. Prepare workspace directory
|
||||
2. Clone repository with credentials
|
||||
3. Checkout specified branch
|
||||
4. Update project metadata
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
repo_url: Git repository URL (HTTPS or SSH)
|
||||
branch: Branch to checkout (default: main)
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Cloning repository {repo_url} for project {project_id} on branch {branch}"
|
||||
)
|
||||
|
||||
# TODO: Implement repository cloning
|
||||
# This will involve:
|
||||
# 1. Getting project credentials from secrets store
|
||||
# 2. Creating workspace directory
|
||||
# 3. Running git clone with proper auth
|
||||
# 4. Checking out the target branch
|
||||
# 5. Updating project record with clone status
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.git.commit_changes")
|
||||
def commit_changes(
|
||||
self,
|
||||
project_id: str,
|
||||
message: str,
|
||||
files: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Commit changes in a project repository.
|
||||
|
||||
This task creates a Git commit:
|
||||
1. Stage specified files (or all if None)
|
||||
2. Create commit with message
|
||||
3. Update commit history record
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
message: Commit message (follows conventional commits)
|
||||
files: List of files to stage, or None for all staged
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Committing changes for project {project_id}: {message}"
|
||||
)
|
||||
|
||||
# TODO: Implement commit operation
|
||||
# This will involve:
|
||||
# 1. Loading project workspace path
|
||||
# 2. Running git add for specified files
|
||||
# 3. Running git commit with message
|
||||
# 4. Recording commit hash in database
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.git.create_branch")
|
||||
def create_branch(
|
||||
self,
|
||||
project_id: str,
|
||||
branch_name: str,
|
||||
from_ref: str = "HEAD",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a new branch in a project repository.
|
||||
|
||||
This task creates a Git branch:
|
||||
1. Checkout from reference
|
||||
2. Create new branch
|
||||
3. Update branch tracking
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
branch_name: Name of the new branch (e.g., feature/123-description)
|
||||
from_ref: Reference to branch from (default: HEAD)
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating branch {branch_name} from {from_ref} for project {project_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement branch creation
|
||||
# This will involve:
|
||||
# 1. Loading project workspace
|
||||
# 2. Running git checkout -b from_ref
|
||||
# 3. Recording branch in database
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.git.create_pull_request")
|
||||
def create_pull_request(
|
||||
self,
|
||||
project_id: str,
|
||||
title: str,
|
||||
body: str,
|
||||
head_branch: str,
|
||||
base_branch: str = "main",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a pull request for a project.
|
||||
|
||||
This task creates a PR on the external Git provider:
|
||||
1. Push branch if needed
|
||||
2. Create PR via API (Gitea, GitHub, GitLab)
|
||||
3. Store PR reference
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
title: PR title
|
||||
body: PR description (markdown)
|
||||
head_branch: Branch with changes
|
||||
base_branch: Target branch (default: main)
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating PR '{title}' from {head_branch} to {base_branch} for project {project_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement PR creation
|
||||
# This will involve:
|
||||
# 1. Loading project and Git provider config
|
||||
# 2. Ensuring head_branch is pushed
|
||||
# 3. Calling provider API to create PR
|
||||
# 4. Storing PR URL and number
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.git.push_changes")
|
||||
def push_changes(
|
||||
self,
|
||||
project_id: str,
|
||||
branch: str,
|
||||
force: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Push changes to remote repository.
|
||||
|
||||
This task pushes commits to the remote:
|
||||
1. Verify authentication
|
||||
2. Push branch to remote
|
||||
3. Handle push failures
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
branch: Branch to push
|
||||
force: Whether to force push (use with caution)
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Pushing branch {branch} for project {project_id} (force={force})"
|
||||
)
|
||||
|
||||
# TODO: Implement push operation
|
||||
# This will involve:
|
||||
# 1. Loading project credentials
|
||||
# 2. Running git push (with --force if specified)
|
||||
# 3. Handling authentication and conflicts
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
198
backend/app/tasks/sync.py
Normal file
198
backend/app/tasks/sync.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# app/tasks/sync.py
|
||||
"""
|
||||
Issue synchronization tasks for Syndarix.
|
||||
|
||||
These tasks handle bidirectional issue synchronization:
|
||||
- Incremental sync (polling for recent changes)
|
||||
- Full reconciliation (daily comprehensive sync)
|
||||
- Webhook event processing
|
||||
- Pushing local changes to external trackers
|
||||
|
||||
Tasks are routed to the 'sync' queue for dedicated processing.
|
||||
Per ADR-011, sync follows a master/replica model with configurable direction.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.sync.sync_issues_incremental")
|
||||
def sync_issues_incremental(self) -> dict[str, Any]:
|
||||
"""
|
||||
Perform incremental issue synchronization across all projects.
|
||||
|
||||
This periodic task (runs every 5 minutes):
|
||||
1. Query each project's external tracker for recent changes
|
||||
2. Compare with local issue cache
|
||||
3. Apply updates to local database
|
||||
4. Handle conflicts based on sync direction config
|
||||
|
||||
Returns:
|
||||
dict with status and type
|
||||
"""
|
||||
logger.info("Starting incremental issue sync across all projects")
|
||||
|
||||
# TODO: Implement incremental sync
|
||||
# This will involve:
|
||||
# 1. Loading all active projects with sync enabled
|
||||
# 2. For each project, querying external tracker since last_sync_at
|
||||
# 3. Upserting issues into local database
|
||||
# 4. Updating last_sync_at timestamp
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"type": "incremental",
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.sync.sync_issues_full")
|
||||
def sync_issues_full(self) -> dict[str, Any]:
|
||||
"""
|
||||
Perform full issue reconciliation across all projects.
|
||||
|
||||
This periodic task (runs daily):
|
||||
1. Fetch all issues from external trackers
|
||||
2. Compare with local database
|
||||
3. Handle orphaned issues
|
||||
4. Resolve any drift between systems
|
||||
|
||||
Returns:
|
||||
dict with status and type
|
||||
"""
|
||||
logger.info("Starting full issue reconciliation across all projects")
|
||||
|
||||
# TODO: Implement full sync
|
||||
# This will involve:
|
||||
# 1. Loading all active projects
|
||||
# 2. Fetching complete issue lists from external trackers
|
||||
# 3. Comparing with local database
|
||||
# 4. Handling deletes and orphans
|
||||
# 5. Resolving conflicts based on sync config
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"type": "full",
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.sync.process_webhook_event")
|
||||
def process_webhook_event(
|
||||
self,
|
||||
provider: str,
|
||||
event_type: str,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Process a webhook event from an external Git provider.
|
||||
|
||||
This task handles real-time updates from:
|
||||
- Gitea: issue.created, issue.updated, pull_request.*, etc.
|
||||
- GitHub: issues, pull_request, push, etc.
|
||||
- GitLab: issue events, merge request events, etc.
|
||||
|
||||
Args:
|
||||
provider: Git provider name (gitea, github, gitlab)
|
||||
event_type: Event type from provider
|
||||
payload: Raw webhook payload
|
||||
|
||||
Returns:
|
||||
dict with status, provider, and event_type
|
||||
"""
|
||||
logger.info(f"Processing webhook event from {provider}: {event_type}")
|
||||
|
||||
# TODO: Implement webhook processing
|
||||
# This will involve:
|
||||
# 1. Validating webhook signature
|
||||
# 2. Parsing provider-specific payload
|
||||
# 3. Mapping to internal event format
|
||||
# 4. Updating local database
|
||||
# 5. Triggering any dependent workflows
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"provider": provider,
|
||||
"event_type": event_type,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.sync.sync_project_issues")
|
||||
def sync_project_issues(
|
||||
self,
|
||||
project_id: str,
|
||||
full: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Synchronize issues for a specific project.
|
||||
|
||||
This task can be triggered manually or by webhooks:
|
||||
1. Connect to project's external tracker
|
||||
2. Fetch issues (incremental or full)
|
||||
3. Update local database
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
full: Whether to do full sync or incremental
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Syncing issues for project {project_id} (full={full})"
|
||||
)
|
||||
|
||||
# TODO: Implement project-specific sync
|
||||
# This will involve:
|
||||
# 1. Loading project configuration
|
||||
# 2. Connecting to external tracker
|
||||
# 3. Fetching issues based on full flag
|
||||
# 4. Upserting to database
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.sync.push_issue_to_external")
|
||||
def push_issue_to_external(
|
||||
self,
|
||||
project_id: str,
|
||||
issue_id: str,
|
||||
operation: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Push a local issue change to the external tracker.
|
||||
|
||||
This task handles outbound sync when Syndarix is the master:
|
||||
- create: Create new issue in external tracker
|
||||
- update: Update existing issue
|
||||
- close: Close issue in external tracker
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
issue_id: UUID of the local issue
|
||||
operation: Operation type (create, update, close)
|
||||
|
||||
Returns:
|
||||
dict with status, issue_id, and operation
|
||||
"""
|
||||
logger.info(
|
||||
f"Pushing {operation} for issue {issue_id} in project {project_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement outbound sync
|
||||
# This will involve:
|
||||
# 1. Loading issue and project config
|
||||
# 2. Mapping to external tracker format
|
||||
# 3. Calling provider API
|
||||
# 4. Updating external_id mapping
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"issue_id": issue_id,
|
||||
"operation": operation,
|
||||
}
|
||||
213
backend/app/tasks/workflow.py
Normal file
213
backend/app/tasks/workflow.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# app/tasks/workflow.py
|
||||
"""
|
||||
Workflow state management tasks for Syndarix.
|
||||
|
||||
These tasks manage workflow execution and state transitions:
|
||||
- Sprint workflows (planning -> implementation -> review -> done)
|
||||
- Story workflows (todo -> in_progress -> review -> done)
|
||||
- Approval checkpoints for autonomy levels
|
||||
- Stale workflow recovery
|
||||
|
||||
Per ADR-007 and ADR-010, workflow state is durable in PostgreSQL
|
||||
with defined state transitions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.workflow.recover_stale_workflows")
|
||||
def recover_stale_workflows(self) -> dict[str, Any]:
|
||||
"""
|
||||
Recover workflows that have become stale.
|
||||
|
||||
This periodic task (runs every 5 minutes):
|
||||
1. Find workflows stuck in intermediate states
|
||||
2. Check for timed-out agent operations
|
||||
3. Retry or escalate based on configuration
|
||||
4. Notify relevant users if needed
|
||||
|
||||
Returns:
|
||||
dict with status and recovered count
|
||||
"""
|
||||
logger.info("Checking for stale workflows to recover")
|
||||
|
||||
# TODO: Implement stale workflow recovery
|
||||
# This will involve:
|
||||
# 1. Querying for workflows with last_updated > threshold
|
||||
# 2. Checking if associated agents are still running
|
||||
# 3. Retrying or resetting stuck workflows
|
||||
# 4. Sending notifications for manual intervention
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"recovered": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.workflow.execute_workflow_step")
|
||||
def execute_workflow_step(
|
||||
self,
|
||||
workflow_id: str,
|
||||
transition: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a state transition for a workflow.
|
||||
|
||||
This task applies a transition to a workflow:
|
||||
1. Validate transition is allowed from current state
|
||||
2. Execute any pre-transition hooks
|
||||
3. Update workflow state
|
||||
4. Execute any post-transition hooks
|
||||
5. Trigger follow-up tasks
|
||||
|
||||
Args:
|
||||
workflow_id: UUID of the workflow
|
||||
transition: Transition to execute (start, approve, reject, etc.)
|
||||
|
||||
Returns:
|
||||
dict with status, workflow_id, and transition
|
||||
"""
|
||||
logger.info(
|
||||
f"Executing transition '{transition}' for workflow {workflow_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement workflow transition
|
||||
# This will involve:
|
||||
# 1. Loading workflow from database
|
||||
# 2. Validating transition from current state
|
||||
# 3. Running pre-transition hooks
|
||||
# 4. Updating state in database
|
||||
# 5. Running post-transition hooks
|
||||
# 6. Scheduling follow-up tasks
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"workflow_id": workflow_id,
|
||||
"transition": transition,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.workflow.handle_approval_response")
|
||||
def handle_approval_response(
|
||||
self,
|
||||
workflow_id: str,
|
||||
approved: bool,
|
||||
comment: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Handle a user approval response for a workflow checkpoint.
|
||||
|
||||
This task processes approval decisions:
|
||||
1. Record approval decision with timestamp
|
||||
2. Update workflow state accordingly
|
||||
3. Resume or halt workflow execution
|
||||
4. Notify relevant parties
|
||||
|
||||
Args:
|
||||
workflow_id: UUID of the workflow
|
||||
approved: Whether the checkpoint was approved
|
||||
comment: Optional comment from approver
|
||||
|
||||
Returns:
|
||||
dict with status, workflow_id, and approved flag
|
||||
"""
|
||||
logger.info(
|
||||
f"Handling approval response for workflow {workflow_id}: approved={approved}"
|
||||
)
|
||||
|
||||
# TODO: Implement approval handling
|
||||
# This will involve:
|
||||
# 1. Loading workflow and approval checkpoint
|
||||
# 2. Recording decision with user and timestamp
|
||||
# 3. Transitioning workflow state
|
||||
# 4. Resuming or stopping execution
|
||||
# 5. Sending notifications
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"workflow_id": workflow_id,
|
||||
"approved": approved,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.workflow.start_sprint_workflow")
|
||||
def start_sprint_workflow(
|
||||
self,
|
||||
project_id: str,
|
||||
sprint_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Start a new sprint workflow.
|
||||
|
||||
This task initializes sprint execution:
|
||||
1. Create sprint workflow record
|
||||
2. Set up sprint planning phase
|
||||
3. Spawn Product Owner agent for planning
|
||||
4. Begin story assignment
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
sprint_id: UUID of the sprint
|
||||
|
||||
Returns:
|
||||
dict with status and sprint_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Starting sprint workflow for sprint {sprint_id} in project {project_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement sprint workflow initialization
|
||||
# This will involve:
|
||||
# 1. Creating workflow record for sprint
|
||||
# 2. Setting initial state to PLANNING
|
||||
# 3. Spawning PO agent for sprint planning
|
||||
# 4. Setting up monitoring and checkpoints
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"sprint_id": sprint_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.workflow.start_story_workflow")
|
||||
def start_story_workflow(
|
||||
self,
|
||||
project_id: str,
|
||||
story_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Start a new story workflow.
|
||||
|
||||
This task initializes story execution:
|
||||
1. Create story workflow record
|
||||
2. Spawn appropriate developer agent
|
||||
3. Set up implementation tracking
|
||||
4. Configure approval checkpoints based on autonomy level
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
story_id: UUID of the story/issue
|
||||
|
||||
Returns:
|
||||
dict with status and story_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Starting story workflow for story {story_id} in project {project_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement story workflow initialization
|
||||
# This will involve:
|
||||
# 1. Creating workflow record for story
|
||||
# 2. Determining appropriate agent type
|
||||
# 3. Spawning developer agent
|
||||
# 4. Setting up checkpoints based on autonomy level
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"story_id": story_id,
|
||||
}
|
||||
@@ -65,10 +65,10 @@ async def setup_async_test_db():
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
AsyncTestingSessionLocal = sessionmaker( # pyright: ignore[reportCallIssue]
|
||||
AsyncTestingSessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=test_engine, # pyright: ignore[reportArgumentType]
|
||||
bind=test_engine,
|
||||
expire_on_commit=False,
|
||||
class_=AsyncSession,
|
||||
)
|
||||
|
||||
@@ -79,13 +79,12 @@ This FastAPI backend application follows a **clean layered architecture** patter
|
||||
|
||||
### Authentication & Security
|
||||
|
||||
- **PyJWT**: JWT token generation and validation
|
||||
- Cryptographic signing (HS256, RS256)
|
||||
- **python-jose**: JWT token generation and validation
|
||||
- Cryptographic signing
|
||||
- Token expiration handling
|
||||
- Claims validation
|
||||
- JWK support for Google ID token verification
|
||||
|
||||
- **bcrypt**: Password hashing
|
||||
- **passlib + bcrypt**: Password hashing
|
||||
- Industry-standard bcrypt algorithm
|
||||
- Configurable cost factor
|
||||
- Salt generation
|
||||
@@ -118,8 +117,7 @@ backend/
|
||||
│ ├── api/ # API layer
|
||||
│ │ ├── dependencies/ # Dependency injection
|
||||
│ │ │ ├── auth.py # Authentication dependencies
|
||||
│ │ │ ├── permissions.py # Authorization dependencies
|
||||
│ │ │ └── services.py # Service singleton injection
|
||||
│ │ │ └── permissions.py # Authorization dependencies
|
||||
│ │ ├── routes/ # API endpoints
|
||||
│ │ │ ├── auth.py # Authentication routes
|
||||
│ │ │ ├── users.py # User management routes
|
||||
@@ -133,14 +131,13 @@ backend/
|
||||
│ │ ├── config.py # Application configuration
|
||||
│ │ ├── database.py # Database connection
|
||||
│ │ ├── exceptions.py # Custom exception classes
|
||||
│ │ ├── repository_exceptions.py # Repository-level exception hierarchy
|
||||
│ │ └── middleware.py # Custom middleware
|
||||
│ │
|
||||
│ ├── repositories/ # Data access layer
|
||||
│ │ ├── base.py # Generic repository base class
|
||||
│ │ ├── user.py # User repository
|
||||
│ │ ├── session.py # Session repository
|
||||
│ │ └── organization.py # Organization repository
|
||||
│ ├── crud/ # Database operations
|
||||
│ │ ├── base.py # Generic CRUD base class
|
||||
│ │ ├── user.py # User CRUD operations
|
||||
│ │ ├── session.py # Session CRUD operations
|
||||
│ │ └── organization.py # Organization CRUD
|
||||
│ │
|
||||
│ ├── models/ # SQLAlchemy models
|
||||
│ │ ├── base.py # Base model with mixins
|
||||
@@ -156,11 +153,8 @@ backend/
|
||||
│ │ ├── sessions.py # Session schemas
|
||||
│ │ └── organizations.py # Organization schemas
|
||||
│ │
|
||||
│ ├── services/ # Business logic layer
|
||||
│ ├── services/ # Business logic
|
||||
│ │ ├── auth_service.py # Authentication service
|
||||
│ │ ├── user_service.py # User management service
|
||||
│ │ ├── session_service.py # Session management service
|
||||
│ │ ├── organization_service.py # Organization service
|
||||
│ │ ├── email_service.py # Email service
|
||||
│ │ └── session_cleanup.py # Background cleanup
|
||||
│ │
|
||||
@@ -174,25 +168,20 @@ backend/
|
||||
│
|
||||
├── tests/ # Test suite
|
||||
│ ├── api/ # Integration tests
|
||||
│ ├── repositories/ # Repository unit tests
|
||||
│ ├── services/ # Service unit tests
|
||||
│ ├── crud/ # CRUD tests
|
||||
│ ├── models/ # Model tests
|
||||
│ ├── services/ # Service tests
|
||||
│ └── conftest.py # Test configuration
|
||||
│
|
||||
├── docs/ # Documentation
|
||||
│ ├── ARCHITECTURE.md # This file
|
||||
│ ├── CODING_STANDARDS.md # Coding standards
|
||||
│ ├── COMMON_PITFALLS.md # Common mistakes to avoid
|
||||
│ ├── E2E_TESTING.md # E2E testing guide
|
||||
│ └── FEATURE_EXAMPLE.md # Feature implementation guide
|
||||
│
|
||||
├── pyproject.toml # Dependencies, tool configs (Ruff, pytest, coverage, Pyright)
|
||||
├── uv.lock # Locked dependency versions (commit to git)
|
||||
├── Makefile # Development commands (quality, security, testing)
|
||||
├── .pre-commit-config.yaml # Pre-commit hook configuration
|
||||
├── .secrets.baseline # detect-secrets baseline (known false positives)
|
||||
├── alembic.ini # Alembic configuration
|
||||
└── migrate.py # Migration helper script
|
||||
├── requirements.txt # Python dependencies
|
||||
├── pytest.ini # Pytest configuration
|
||||
├── .coveragerc # Coverage configuration
|
||||
└── alembic.ini # Alembic configuration
|
||||
```
|
||||
|
||||
## Layered Architecture
|
||||
@@ -225,11 +214,11 @@ The application follows a strict 5-layer architecture:
|
||||
└──────────────────────────┬──────────────────────────────────┘
|
||||
│ calls
|
||||
┌──────────────────────────▼──────────────────────────────────┐
|
||||
│ Repository Layer (repositories/) │
|
||||
│ CRUD Layer (crud/) │
|
||||
│ - Database operations │
|
||||
│ - Query building │
|
||||
│ - Custom repository exceptions │
|
||||
│ - No business logic │
|
||||
│ - Transaction management │
|
||||
│ - Error handling │
|
||||
└──────────────────────────┬──────────────────────────────────┘
|
||||
│ uses
|
||||
┌──────────────────────────▼──────────────────────────────────┐
|
||||
@@ -273,7 +262,7 @@ async def get_current_user_info(
|
||||
|
||||
**Rules**:
|
||||
- Should NOT contain business logic
|
||||
- Should NOT directly call repositories (use services injected via `dependencies/services.py`)
|
||||
- Should NOT directly perform database operations (use CRUD or services)
|
||||
- Must validate all input via Pydantic schemas
|
||||
- Must specify response models
|
||||
- Should apply appropriate rate limits
|
||||
@@ -290,9 +279,9 @@ async def get_current_user_info(
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
async def get_current_user(
|
||||
def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
Extract and validate user from JWT token.
|
||||
@@ -306,7 +295,7 @@ async def get_current_user(
|
||||
except Exception:
|
||||
raise AuthenticationError("Invalid authentication credentials")
|
||||
|
||||
user = await user_repo.get(db, id=user_id)
|
||||
user = user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
@@ -324,7 +313,7 @@ async def get_current_user(
|
||||
**Responsibility**: Implement complex business logic
|
||||
|
||||
**Key Functions**:
|
||||
- Orchestrate multiple repository operations
|
||||
- Orchestrate multiple CRUD operations
|
||||
- Implement business rules
|
||||
- Handle external service integration
|
||||
- Coordinate transactions
|
||||
@@ -334,9 +323,9 @@ async def get_current_user(
|
||||
class AuthService:
|
||||
"""Authentication service with business logic."""
|
||||
|
||||
async def login(
|
||||
def login(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
db: Session,
|
||||
email: str,
|
||||
password: str,
|
||||
request: Request
|
||||
@@ -350,8 +339,8 @@ class AuthService:
|
||||
3. Generate tokens
|
||||
4. Return tokens and user info
|
||||
"""
|
||||
# Validate credentials via repository
|
||||
user = await user_repo.get_by_email(db, email=email)
|
||||
# Validate credentials
|
||||
user = user_crud.get_by_email(db, email=email)
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
raise AuthenticationError("Invalid credentials")
|
||||
|
||||
@@ -361,10 +350,11 @@ class AuthService:
|
||||
# Extract device info
|
||||
device_info = extract_device_info(request)
|
||||
|
||||
# Create session via repository
|
||||
session = await session_repo.create(
|
||||
# Create session
|
||||
session = session_crud.create_session(
|
||||
db,
|
||||
obj_in=SessionCreate(user_id=user.id, **device_info)
|
||||
user_id=user.id,
|
||||
device_info=device_info
|
||||
)
|
||||
|
||||
# Generate tokens
|
||||
@@ -383,60 +373,75 @@ class AuthService:
|
||||
|
||||
**Rules**:
|
||||
- Contains business logic, not just data operations
|
||||
- Can call multiple repository operations
|
||||
- Can call multiple CRUD operations
|
||||
- Should handle complex workflows
|
||||
- Must maintain data consistency
|
||||
- Should use transactions when needed
|
||||
|
||||
#### 4. Repository Layer (`app/repositories/`)
|
||||
#### 4. CRUD Layer (`app/crud/`)
|
||||
|
||||
**Responsibility**: Database operations and queries — no business logic
|
||||
**Responsibility**: Database operations and queries
|
||||
|
||||
**Key Functions**:
|
||||
- Create, read, update, delete operations
|
||||
- Build database queries
|
||||
- Raise custom repository exceptions (`DuplicateEntryError`, `IntegrityConstraintError`)
|
||||
- Handle database errors
|
||||
- Manage soft deletes
|
||||
- Implement pagination and filtering
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""Repository for user sessions — database operations only."""
|
||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""CRUD operations for user sessions."""
|
||||
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||
def get_by_jti(self, db: Session, jti: UUID) -> Optional[UserSession]:
|
||||
"""Get session by refresh token JTI."""
|
||||
result = await db.execute(
|
||||
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
try:
|
||||
return (
|
||||
db.query(UserSession)
|
||||
.filter(UserSession.refresh_token_jti == jti)
|
||||
.first()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session by JTI: {str(e)}")
|
||||
return None
|
||||
|
||||
async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> bool:
|
||||
def get_active_by_jti(
|
||||
self,
|
||||
db: Session,
|
||||
jti: UUID
|
||||
) -> Optional[UserSession]:
|
||||
"""Get active session by refresh token JTI."""
|
||||
session = self.get_by_jti(db, jti=jti)
|
||||
if session and session.is_active and not session.is_expired:
|
||||
return session
|
||||
return None
|
||||
|
||||
def deactivate(self, db: Session, session_id: UUID) -> bool:
|
||||
"""Deactivate a session (logout)."""
|
||||
try:
|
||||
session = await self.get(db, id=session_id)
|
||||
session = self.get(db, id=session_id)
|
||||
if not session:
|
||||
return False
|
||||
|
||||
session.is_active = False
|
||||
await db.commit()
|
||||
db.commit()
|
||||
logger.info(f"Session {session_id} deactivated")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
db.rollback()
|
||||
logger.error(f"Error deactivating session: {str(e)}")
|
||||
return False
|
||||
```
|
||||
|
||||
**Rules**:
|
||||
- Should NOT contain business logic
|
||||
- Must raise custom repository exceptions (not raw `ValueError`/`IntegrityError`)
|
||||
- Must use async SQLAlchemy 2.0 `select()` API (never `db.query()`)
|
||||
- Must handle database exceptions
|
||||
- Must use parameterized queries (SQLAlchemy does this)
|
||||
- Should log all database errors
|
||||
- Must rollback on errors
|
||||
- Should use soft deletes when possible
|
||||
- **Never imported directly by routes** — always called through services
|
||||
|
||||
#### 5. Data Layer (`app/models/` + `app/schemas/`)
|
||||
|
||||
@@ -541,23 +546,51 @@ SessionLocal = sessionmaker(
|
||||
#### Dependency Injection Pattern
|
||||
|
||||
```python
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Async database session dependency for FastAPI routes.
|
||||
Database session dependency for FastAPI routes.
|
||||
|
||||
The session is passed to service methods; commit/rollback is
|
||||
managed inside service or repository methods.
|
||||
Automatically commits on success, rolls back on error.
|
||||
"""
|
||||
async with AsyncSessionLocal() as db:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Usage in routes — always through a service, never direct repository
|
||||
# Usage in routes
|
||||
@router.get("/users")
|
||||
async def list_users(
|
||||
user_service: UserService = Depends(get_user_service),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await user_service.get_users(db)
|
||||
def list_users(db: Session = Depends(get_db)):
|
||||
return user_crud.get_multi(db)
|
||||
```
|
||||
|
||||
#### Context Manager Pattern
|
||||
|
||||
```python
|
||||
@contextmanager
|
||||
def transaction_scope() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Context manager for database transactions.
|
||||
|
||||
Use for complex operations requiring multiple steps.
|
||||
Automatically commits on success, rolls back on error.
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Usage in services
|
||||
def complex_operation():
|
||||
with transaction_scope() as db:
|
||||
user = user_crud.create(db, obj_in=user_data)
|
||||
session = session_crud.create(db, session_data)
|
||||
return user, session
|
||||
```
|
||||
|
||||
### Model Mixins
|
||||
@@ -749,15 +782,22 @@ def get_profile(
|
||||
|
||||
```python
|
||||
@router.delete("/sessions/{session_id}")
|
||||
async def revoke_session(
|
||||
def revoke_session(
|
||||
session_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
session_service: SessionService = Depends(get_session_service),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Users can only revoke their own sessions."""
|
||||
# SessionService verifies ownership and raises NotFoundError / AuthorizationError
|
||||
await session_service.revoke_session(db, session_id=session_id, user_id=current_user.id)
|
||||
session = session_crud.get(db, id=session_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError("Session not found")
|
||||
|
||||
# Check ownership
|
||||
if session.user_id != current_user.id:
|
||||
raise AuthorizationError("You can only revoke your own sessions")
|
||||
|
||||
session_crud.deactivate(db, session_id=session_id)
|
||||
return MessageResponse(success=True, message="Session revoked")
|
||||
```
|
||||
|
||||
@@ -1021,27 +1061,23 @@ from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan context manager."""
|
||||
# Startup
|
||||
if os.getenv("IS_TEST", "False") != "True":
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Start background jobs on application startup."""
|
||||
if not settings.IS_TEST: # Don't run in tests
|
||||
scheduler.add_job(
|
||||
cleanup_expired_sessions,
|
||||
"cron",
|
||||
hour=2, # Run at 2 AM daily
|
||||
id="cleanup_expired_sessions",
|
||||
replace_existing=True,
|
||||
id="cleanup_expired_sessions"
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("Background jobs started")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
if os.getenv("IS_TEST", "False") != "True":
|
||||
scheduler.shutdown()
|
||||
await close_async_db() # Dispose database engine connections
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Stop background jobs on application shutdown."""
|
||||
scheduler.shutdown()
|
||||
```
|
||||
|
||||
### Job Implementation
|
||||
@@ -1056,8 +1092,8 @@ async def cleanup_expired_sessions():
|
||||
Runs daily at 2 AM. Removes sessions expired for more than 30 days.
|
||||
"""
|
||||
try:
|
||||
async with AsyncSessionLocal() as db:
|
||||
count = await session_repo.cleanup_expired(db, keep_days=30)
|
||||
with transaction_scope() as db:
|
||||
count = session_crud.cleanup_expired(db, keep_days=30)
|
||||
logger.info(f"Cleaned up {count} expired sessions")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True)
|
||||
@@ -1074,7 +1110,7 @@ async def cleanup_expired_sessions():
|
||||
│Integration │ ← API endpoint tests
|
||||
│ Tests │
|
||||
├─────────────┤
|
||||
│ Unit │ ← repositories, services, utilities
|
||||
│ Unit │ ← CRUD, services, utilities
|
||||
│ Tests │
|
||||
└─────────────┘
|
||||
```
|
||||
@@ -1169,8 +1205,6 @@ app.add_middleware(
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
> 📖 For the full benchmarking guide (how to run, read results, write new benchmarks, and manage baselines), see **[BENCHMARKS.md](BENCHMARKS.md)**.
|
||||
|
||||
### Database Connection Pooling
|
||||
|
||||
- Pool size: 20 connections
|
||||
|
||||
@@ -1,311 +0,0 @@
|
||||
# Performance Benchmarks Guide
|
||||
|
||||
Automated performance benchmarking infrastructure using **pytest-benchmark** to detect latency regressions in critical API endpoints.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Why Benchmark?](#why-benchmark)
|
||||
- [Quick Start](#quick-start)
|
||||
- [How It Works](#how-it-works)
|
||||
- [Understanding Results](#understanding-results)
|
||||
- [Test Organization](#test-organization)
|
||||
- [Writing Benchmark Tests](#writing-benchmark-tests)
|
||||
- [Baseline Management](#baseline-management)
|
||||
- [CI/CD Integration](#cicd-integration)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
---
|
||||
|
||||
## Why Benchmark?
|
||||
|
||||
Performance regressions are silent bugs — they don't break tests or cause errors, but they degrade the user experience over time. Common causes include:
|
||||
|
||||
- **Unintended N+1 queries** after adding a relationship
|
||||
- **Heavier serialization** after adding new fields to a response model
|
||||
- **Middleware overhead** from new security headers or logging
|
||||
- **Dependency upgrades** that introduce slower code paths
|
||||
|
||||
Without automated benchmarks, these regressions go unnoticed until users complain. Performance benchmarks serve as an **early warning system** — they measure endpoint latency on every run and flag significant deviations from an established baseline.
|
||||
|
||||
### What benchmarks give you
|
||||
|
||||
| Benefit | Description |
|
||||
|---------|-------------|
|
||||
| **Regression detection** | Automatically flags when an endpoint becomes significantly slower |
|
||||
| **Baseline tracking** | Stores known-good performance numbers for comparison |
|
||||
| **Confidence in refactors** | Verify that code changes don't degrade response times |
|
||||
| **Visibility** | Makes performance a first-class, measurable quality attribute |
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Run benchmarks (no comparison, just see current numbers)
|
||||
make benchmark
|
||||
|
||||
# Save current results as the baseline
|
||||
make benchmark-save
|
||||
|
||||
# Run benchmarks and compare against the saved baseline
|
||||
make benchmark-check
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## How It Works
|
||||
|
||||
The benchmarking system has three layers:
|
||||
|
||||
### 1. pytest-benchmark integration
|
||||
|
||||
[pytest-benchmark](https://pytest-benchmark.readthedocs.io/) is a pytest plugin that provides a `benchmark` fixture. It handles:
|
||||
|
||||
- **Calibration**: Automatically determines how many iterations to run for statistical significance
|
||||
- **Timing**: Uses `time.perf_counter` for high-resolution measurements
|
||||
- **Statistics**: Computes min, max, mean, median, standard deviation, IQR, and outlier detection
|
||||
- **Comparison**: Compares current results against saved baselines and flags regressions
|
||||
|
||||
### 2. Benchmark types
|
||||
|
||||
The test suite includes two categories of performance tests:
|
||||
|
||||
| Type | How it works | Examples |
|
||||
|------|-------------|----------|
|
||||
| **pytest-benchmark tests** | Uses the `benchmark` fixture for precise, multi-round timing | `test_health_endpoint_performance`, `test_openapi_schema_performance`, `test_password_hashing_performance`, `test_password_verification_performance`, `test_access_token_creation_performance`, `test_refresh_token_creation_performance`, `test_token_decode_performance` |
|
||||
| **Manual latency tests** | Uses `time.perf_counter` with explicit thresholds (for async endpoints that pytest-benchmark doesn't support natively) | `test_login_latency`, `test_get_current_user_latency`, `test_register_latency`, `test_token_refresh_latency`, `test_sessions_list_latency`, `test_user_profile_update_latency` |
|
||||
|
||||
### 3. Regression detection
|
||||
|
||||
When running `make benchmark-check`, the system:
|
||||
|
||||
1. Runs all benchmark tests
|
||||
2. Compares results against the saved baseline (`.benchmarks/` directory)
|
||||
3. **Fails the build** if any test's mean time exceeds **200%** of the baseline (i.e., 3× slower)
|
||||
|
||||
The `200%` threshold in `--benchmark-compare-fail=mean:200%` means "fail if the mean increased by more than 200% relative to the baseline." This is deliberately generous to avoid false positives from normal run-to-run variance while still catching real regressions.
|
||||
|
||||
---
|
||||
|
||||
## Understanding Results
|
||||
|
||||
A typical benchmark output looks like this:
|
||||
|
||||
```
|
||||
--------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------
|
||||
Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
|
||||
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
test_health_endpoint_performance 0.9841 (1.0) 1.5513 (1.0) 1.1390 (1.0) 0.1098 (1.0) 1.1151 (1.0) 0.1672 (1.0) 39;2 877.9666 (1.0) 133 1
|
||||
test_openapi_schema_performance 1.6523 (1.68) 2.0892 (1.35) 1.7843 (1.57) 0.1553 (1.41) 1.7200 (1.54) 0.1727 (1.03) 2;0 560.4471 (0.64) 10 1
|
||||
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
```
|
||||
|
||||
### Column reference
|
||||
|
||||
| Column | Meaning |
|
||||
|--------|---------|
|
||||
| **Min** | Fastest single execution |
|
||||
| **Max** | Slowest single execution |
|
||||
| **Mean** | Average across all rounds — the primary metric for regression detection |
|
||||
| **StdDev** | How much results vary between rounds (lower = more stable) |
|
||||
| **Median** | Middle value, less sensitive to outliers than mean |
|
||||
| **IQR** | Interquartile range — spread of the middle 50% of results |
|
||||
| **Outliers** | Format `A;B` — A = within 1 StdDev, B = within 1.5 IQR from quartiles |
|
||||
| **OPS** | Operations per second (`1 / Mean`) |
|
||||
| **Rounds** | How many times the test was executed (auto-calibrated) |
|
||||
| **Iterations** | Iterations per round (usually 1 for ms-scale tests) |
|
||||
|
||||
### The ratio numbers `(1.0)`, `(1.68)`, etc.
|
||||
|
||||
These show how each test compares **to the best result in that column**. The fastest test is always `(1.0)`, and others show their relative factor. For example, `(1.68)` means "1.68× slower than the fastest."
|
||||
|
||||
### Color coding
|
||||
|
||||
- **Green**: The fastest (best) value in each column
|
||||
- **Red**: The slowest (worst) value in each column
|
||||
|
||||
This is a **relative ranking within the current run** — red does NOT mean the test failed or that performance is bad. It simply highlights which endpoint is the slower one in the group.
|
||||
|
||||
### What's "normal"?
|
||||
|
||||
For this project's current endpoints:
|
||||
|
||||
| Test | Expected range | Why |
|
||||
|------|---------------|-----|
|
||||
| `GET /health` | ~1–1.5ms | Minimal logic, mocked DB check |
|
||||
| `GET /api/v1/openapi.json` | ~1.5–2.5ms | Serializes entire API schema |
|
||||
| `get_password_hash` | ~200ms | CPU-bound bcrypt hashing |
|
||||
| `verify_password` | ~200ms | CPU-bound bcrypt verification |
|
||||
| `create_access_token` | ~17–20µs | JWT encoding with HMAC-SHA256 |
|
||||
| `create_refresh_token` | ~17–20µs | JWT encoding with HMAC-SHA256 |
|
||||
| `decode_token` | ~20–25µs | JWT decoding and claim validation |
|
||||
| `POST /api/v1/auth/login` | < 500ms threshold | Includes bcrypt password verification |
|
||||
| `POST /api/v1/auth/register` | < 500ms threshold | Includes bcrypt password hashing |
|
||||
| `POST /api/v1/auth/refresh` | < 200ms threshold | Token rotation + DB session update |
|
||||
| `GET /api/v1/users/me` | < 200ms threshold | DB lookup + token validation |
|
||||
| `GET /api/v1/sessions/me` | < 200ms threshold | Session list query + token validation |
|
||||
| `PATCH /api/v1/users/me` | < 200ms threshold | DB update + token validation |
|
||||
|
||||
---
|
||||
|
||||
## Test Organization
|
||||
|
||||
```
|
||||
backend/tests/
|
||||
├── benchmarks/
|
||||
│ └── test_endpoint_performance.py # All performance benchmark tests
|
||||
│
|
||||
backend/.benchmarks/ # Saved baselines (auto-generated)
|
||||
└── Linux-CPython-3.12-64bit/
|
||||
└── 0001_baseline.json # Platform-specific baseline file
|
||||
```
|
||||
|
||||
### Test markers
|
||||
|
||||
All benchmark tests use the `@pytest.mark.benchmark` marker. The `--benchmark-only` flag ensures that only tests using the `benchmark` fixture are executed during benchmark runs, while manual latency tests (async) are skipped.
|
||||
|
||||
---
|
||||
|
||||
## Writing Benchmark Tests
|
||||
|
||||
### Stateless endpoint (using pytest-benchmark fixture)
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
def test_my_endpoint_performance(sync_client, benchmark):
|
||||
"""Benchmark: GET /my-endpoint should respond within acceptable latency."""
|
||||
result = benchmark(sync_client.get, "/my-endpoint")
|
||||
assert result.status_code == 200
|
||||
```
|
||||
|
||||
The `benchmark` fixture handles all timing, calibration, and statistics automatically. Just pass it the callable and arguments.
|
||||
|
||||
### Async / DB-dependent endpoint (manual timing)
|
||||
|
||||
For async endpoints that require database access, use manual timing with an explicit threshold:
|
||||
|
||||
```python
|
||||
import time
|
||||
import pytest
|
||||
|
||||
MAX_RESPONSE_MS = 300
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_my_async_endpoint_latency(client, setup_fixture):
|
||||
"""Performance: endpoint must respond under threshold."""
|
||||
iterations = 5
|
||||
total_ms = 0.0
|
||||
|
||||
for _ in range(iterations):
|
||||
start = time.perf_counter()
|
||||
response = await client.get("/api/v1/my-endpoint")
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
total_ms += elapsed_ms
|
||||
assert response.status_code == 200
|
||||
|
||||
mean_ms = total_ms / iterations
|
||||
assert mean_ms < MAX_RESPONSE_MS, (
|
||||
f"Latency regression: {mean_ms:.1f}ms exceeds {MAX_RESPONSE_MS}ms threshold"
|
||||
)
|
||||
```
|
||||
|
||||
### Guidelines for new benchmarks
|
||||
|
||||
1. **Benchmark critical paths** — endpoints users hit frequently or where latency matters most
|
||||
2. **Mock external dependencies** for stateless tests to isolate endpoint overhead
|
||||
3. **Set generous thresholds** for manual tests — account for CI variability
|
||||
4. **Keep benchmarks fast** — they run on every check, so avoid heavy setup
|
||||
|
||||
---
|
||||
|
||||
## Baseline Management
|
||||
|
||||
### Saving a baseline
|
||||
|
||||
```bash
|
||||
make benchmark-save
|
||||
```
|
||||
|
||||
This runs all benchmarks and saves results to `.benchmarks/<platform>/0001_baseline.json`. The baseline captures:
|
||||
- Mean, min, max, median, stddev for each test
|
||||
- Machine info (CPU, OS, Python version)
|
||||
- Timestamp
|
||||
|
||||
### Comparing against baseline
|
||||
|
||||
```bash
|
||||
make benchmark-check
|
||||
```
|
||||
|
||||
If no baseline exists, this command automatically creates one and prints a warning. On subsequent runs, it compares current results against the saved baseline.
|
||||
|
||||
### When to update the baseline
|
||||
|
||||
- **After intentional performance changes** (e.g., you optimized an endpoint — save the new, faster baseline)
|
||||
- **After infrastructure changes** (e.g., new CI runner, different hardware)
|
||||
- **After adding new benchmark tests** (the new tests need a baseline entry)
|
||||
|
||||
```bash
|
||||
# Update the baseline after intentional changes
|
||||
make benchmark-save
|
||||
```
|
||||
|
||||
### Version control
|
||||
|
||||
The `.benchmarks/` directory can be committed to version control so that CI pipelines can compare against a known-good baseline. However, since benchmark results are machine-specific, you may prefer to generate baselines in CI rather than committing local results.
|
||||
|
||||
---
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
Add benchmark checking to your CI pipeline to catch regressions on every PR:
|
||||
|
||||
```yaml
|
||||
# Example GitHub Actions step
|
||||
- name: Performance regression check
|
||||
run: |
|
||||
cd backend
|
||||
make benchmark-save # Create baseline from main branch
|
||||
# ... apply PR changes ...
|
||||
make benchmark-check # Compare PR against baseline
|
||||
```
|
||||
|
||||
A more robust approach:
|
||||
1. Save the baseline on the `main` branch after each merge
|
||||
2. On PR branches, run `make benchmark-check` against the `main` baseline
|
||||
3. The pipeline fails if any endpoint regresses beyond the 200% threshold
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "No benchmark baseline found" warning
|
||||
|
||||
```
|
||||
⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one.
|
||||
```
|
||||
|
||||
This means no baseline file exists yet. The command will auto-create one. Future runs of `make benchmark-check` will compare against it.
|
||||
|
||||
### Machine info mismatch warning
|
||||
|
||||
```
|
||||
WARNING: benchmark machine_info is different
|
||||
```
|
||||
|
||||
This is expected when comparing baselines generated on a different machine or OS. The comparison still works, but absolute numbers may differ. Re-save the baseline on the current machine if needed.
|
||||
|
||||
### High variance (large StdDev)
|
||||
|
||||
If StdDev is high relative to the Mean, results may be unreliable. Common causes:
|
||||
- System under load during benchmark run
|
||||
- Garbage collection interference
|
||||
- Thermal throttling
|
||||
|
||||
Try running benchmarks on an idle system or increasing `min_rounds` in `pyproject.toml`.
|
||||
|
||||
### Only 7 of 13 tests run
|
||||
|
||||
The async tests (`test_login_latency`, `test_get_current_user_latency`, `test_register_latency`, `test_token_refresh_latency`, `test_sessions_list_latency`, `test_user_profile_update_latency`) are skipped during `--benchmark-only` runs because they don't use the `benchmark` fixture. They run as part of the normal test suite (`make test`) with manual threshold assertions.
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user