forked from cardosofelipe/pragma-stack
Compare commits
133 Commits
5c35702caf
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ad3d20cf2 | ||
|
|
8623eb56f5 | ||
|
|
3cb6c8d13b | ||
|
|
8e16e2645e | ||
|
|
82c3a6ba47 | ||
|
|
b6c38cac88 | ||
|
|
51404216ae | ||
|
|
3f23bc3db3 | ||
|
|
a0ec5fa2cc | ||
|
|
f262d08be2 | ||
|
|
b3f371e0a3 | ||
|
|
93cc37224c | ||
|
|
5717bffd63 | ||
|
|
9339ea30a1 | ||
|
|
79cb6bfd7b | ||
|
|
45025bb2f1 | ||
|
|
3c6b14d2bf | ||
|
|
6b21a6fadd | ||
|
|
600657adc4 | ||
|
|
c9d0d079b3 | ||
|
|
4c8f81368c | ||
|
|
efbe91ce14 | ||
|
|
5d646779c9 | ||
|
|
5a4d93df26 | ||
|
|
7ef217be39 | ||
|
|
20159c5865 | ||
|
|
f9a72fcb34 | ||
|
|
fcb0a5f86a | ||
|
|
92782bcb05 | ||
|
|
1dcf99ee38 | ||
|
|
70009676a3 | ||
|
|
192237e69b | ||
|
|
3edce9cd26 | ||
|
|
35aea2d73a | ||
|
|
d0f32d04f7 | ||
|
|
da85a8aba8 | ||
|
|
f8bd1011e9 | ||
|
|
f057c2f0b6 | ||
|
|
33ec889fc4 | ||
|
|
74b8c65741 | ||
|
|
b232298c61 | ||
|
|
cf6291ac8e | ||
|
|
e3fe0439fd | ||
|
|
57680c3772 | ||
|
|
997cfaa03a | ||
|
|
6954774e36 | ||
|
|
30e5c68304 | ||
|
|
0b24d4c6cc | ||
|
|
1670e05e0d | ||
|
|
999b7ac03f | ||
|
|
48ecb40f18 | ||
|
|
b818f17418 | ||
|
|
e946787a61 | ||
|
|
3554efe66a | ||
|
|
bd988f76b0 | ||
|
|
4974233169 | ||
|
|
c9d8c0835c | ||
|
|
085a748929 | ||
|
|
4b149b8a52 | ||
|
|
ad0c06851d | ||
|
|
49359b1416 | ||
|
|
911d950c15 | ||
|
|
b2a3ac60e0 | ||
|
|
dea092e1bb | ||
|
|
4154dd5268 | ||
|
|
db12937495 | ||
|
|
81e1456631 | ||
|
|
58e78d8700 | ||
|
|
5e80139afa | ||
|
|
60ebeaa582 | ||
|
|
758052dcff | ||
|
|
1628eacf2b | ||
|
|
2bea057fb1 | ||
|
|
9e54f16e56 | ||
|
|
96e6400bd8 | ||
|
|
6c7b72f130 | ||
|
|
027ebfc332 | ||
|
|
c2466ab401 | ||
|
|
7828d35e06 | ||
|
|
6b07e62f00 | ||
|
|
0d2005ddcb | ||
|
|
dfa75e682e | ||
|
|
22ecb5e989 | ||
|
|
2ab69f8561 | ||
|
|
95342cc94d | ||
|
|
f6194b3e19 | ||
|
|
6bb376a336 | ||
|
|
cd7a9ccbdf | ||
|
|
953af52d0e | ||
|
|
e6e98d4ed1 | ||
|
|
ca5f5e3383 | ||
|
|
d0fc7f37ff | ||
|
|
18d717e996 | ||
|
|
f482559e15 | ||
|
|
6e8b0b022a | ||
|
|
746fb7b181 | ||
|
|
caf283bed2 | ||
|
|
520c06175e | ||
|
|
065e43c5a9 | ||
|
|
c8b88dadc3 | ||
|
|
015f2de6c6 | ||
|
|
f36bfb3781 | ||
|
|
ef659cd72d | ||
|
|
728edd1453 | ||
|
|
498c0a0e94 | ||
|
|
e5975fa5d0 | ||
|
|
731a188a76 | ||
|
|
fe2104822e | ||
|
|
664415111a | ||
|
|
acd18ff694 | ||
|
|
da5affd613 | ||
|
|
a79d923dc1 | ||
|
|
c72f6aa2f9 | ||
|
|
4f24cebf11 | ||
|
|
e0739a786c | ||
|
|
64576da7dc | ||
|
|
4a55bd63a3 | ||
|
|
a78b903f5a | ||
|
|
c7b2c82700 | ||
|
|
50b865b23b | ||
|
|
6f5dd58b54 | ||
|
|
0ceee8545e | ||
|
|
62aea06e0d | ||
|
|
24f1cc637e | ||
|
|
8b6cca5d4d | ||
|
|
c9700f760e | ||
|
|
6f509e71ce | ||
|
|
f5a86953c6 | ||
|
|
246d2a6752 | ||
|
|
36ab7069cf | ||
|
|
a4c91cb8c3 | ||
|
|
a7ba0f9bd8 | ||
|
|
f3fb4ecbeb |
61
.githooks/pre-commit
Executable file
61
.githooks/pre-commit
Executable file
@@ -0,0 +1,61 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Pre-commit hook to enforce validation before commits on protected branches
|
||||||
|
# Install: git config core.hooksPath .githooks
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Get the current branch name
|
||||||
|
BRANCH=$(git rev-parse --abbrev-ref HEAD)
|
||||||
|
|
||||||
|
# Protected branches that require validation
|
||||||
|
PROTECTED_BRANCHES="main dev"
|
||||||
|
|
||||||
|
# Check if we're on a protected branch
|
||||||
|
is_protected() {
|
||||||
|
for branch in $PROTECTED_BRANCHES; do
|
||||||
|
if [ "$BRANCH" = "$branch" ]; then
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if is_protected; then
|
||||||
|
echo "🔒 Committing to protected branch '$BRANCH' - running validation..."
|
||||||
|
|
||||||
|
# Check if we have backend changes
|
||||||
|
if git diff --cached --name-only | grep -q "^backend/"; then
|
||||||
|
echo "📦 Backend changes detected - running make validate..."
|
||||||
|
cd backend
|
||||||
|
if ! make validate; then
|
||||||
|
echo ""
|
||||||
|
echo "❌ Backend validation failed!"
|
||||||
|
echo " Please fix the issues and try again."
|
||||||
|
echo " Run 'cd backend && make validate' to see errors."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
cd ..
|
||||||
|
echo "✅ Backend validation passed!"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if we have frontend changes
|
||||||
|
if git diff --cached --name-only | grep -q "^frontend/"; then
|
||||||
|
echo "🎨 Frontend changes detected - running npm run validate..."
|
||||||
|
cd frontend
|
||||||
|
if ! npm run validate 2>/dev/null; then
|
||||||
|
echo ""
|
||||||
|
echo "❌ Frontend validation failed!"
|
||||||
|
echo " Please fix the issues and try again."
|
||||||
|
echo " Run 'cd frontend && npm run validate' to see errors."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
cd ..
|
||||||
|
echo "✅ Frontend validation passed!"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "🎉 All validations passed! Proceeding with commit..."
|
||||||
|
else
|
||||||
|
echo "📝 Committing to feature branch '$BRANCH' - skipping validation (run manually if needed)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
exit 0
|
||||||
408
CLAUDE.md
408
CLAUDE.md
@@ -9,9 +9,11 @@ Claude Code context for **Syndarix** - AI-Powered Software Consulting Agency.
|
|||||||
## Syndarix Project Context
|
## Syndarix Project Context
|
||||||
|
|
||||||
### Vision
|
### Vision
|
||||||
|
|
||||||
Syndarix is an autonomous platform that orchestrates specialized AI agents to deliver complete software solutions with minimal human intervention. It acts as a virtual consulting agency with AI agents playing roles like Product Owner, Architect, Engineers, QA, etc.
|
Syndarix is an autonomous platform that orchestrates specialized AI agents to deliver complete software solutions with minimal human intervention. It acts as a virtual consulting agency with AI agents playing roles like Product Owner, Architect, Engineers, QA, etc.
|
||||||
|
|
||||||
### Repository
|
### Repository
|
||||||
|
|
||||||
- **URL:** https://gitea.pragmazest.com/cardosofelipe/syndarix
|
- **URL:** https://gitea.pragmazest.com/cardosofelipe/syndarix
|
||||||
- **Issue Tracker:** Gitea Issues (primary)
|
- **Issue Tracker:** Gitea Issues (primary)
|
||||||
- **CI/CD:** Gitea Actions
|
- **CI/CD:** Gitea Actions
|
||||||
@@ -43,9 +45,11 @@ search_knowledge(project_id="proj-123", query="auth flow")
|
|||||||
create_issue(project_id="proj-123", title="Add login")
|
create_issue(project_id="proj-123", title="Add login")
|
||||||
```
|
```
|
||||||
|
|
||||||
### Syndarix-Specific Directories
|
### Directory Structure
|
||||||
|
|
||||||
```
|
```
|
||||||
docs/
|
docs/
|
||||||
|
├── development/ # Workflow and coding standards
|
||||||
├── requirements/ # Requirements documents
|
├── requirements/ # Requirements documents
|
||||||
├── architecture/ # Architecture documentation
|
├── architecture/ # Architecture documentation
|
||||||
├── adrs/ # Architecture Decision Records
|
├── adrs/ # Architecture Decision Records
|
||||||
@@ -53,97 +57,127 @@ docs/
|
|||||||
```
|
```
|
||||||
|
|
||||||
### Current Phase
|
### Current Phase
|
||||||
|
|
||||||
**Backlog Population** - Creating detailed issues for Phase 0-1 implementation.
|
**Backlog Population** - Creating detailed issues for Phase 0-1 implementation.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Development Workflow & Standards
|
## Development Standards
|
||||||
|
|
||||||
**CRITICAL: These rules are mandatory for all development work.**
|
**CRITICAL: These rules are mandatory. See linked docs for full details.**
|
||||||
|
|
||||||
### 1. Issue-Driven Development
|
### Quick Reference
|
||||||
|
|
||||||
**Every piece of work MUST have an issue in the Gitea tracker first.**
|
| Topic | Documentation |
|
||||||
|
|-------|---------------|
|
||||||
|
| **Workflow & Branching** | [docs/development/WORKFLOW.md](./docs/development/WORKFLOW.md) |
|
||||||
|
| **Coding Standards** | [docs/development/CODING_STANDARDS.md](./docs/development/CODING_STANDARDS.md) |
|
||||||
|
| **Design System** | [frontend/docs/design-system/](./frontend/docs/design-system/) |
|
||||||
|
| **Backend E2E Testing** | [backend/docs/E2E_TESTING.md](./backend/docs/E2E_TESTING.md) |
|
||||||
|
| **Demo Mode** | [frontend/docs/DEMO_MODE.md](./frontend/docs/DEMO_MODE.md) |
|
||||||
|
|
||||||
- Issue tracker: https://gitea.pragmazest.com/cardosofelipe/syndarix/issues
|
### Essential Rules Summary
|
||||||
- Create detailed, well-scoped issues before starting work
|
|
||||||
- Structure issues to enable parallel work by multiple agents
|
|
||||||
- Reference issues in commits and PRs
|
|
||||||
|
|
||||||
### 2. Git Hygiene
|
1. **Issue-Driven Development**: Every piece of work MUST have an issue first
|
||||||
|
2. **Branch per Feature**: `feature/<issue-number>-<description>`, single branch for design+implementation
|
||||||
|
3. **Testing Required**: All code must be tested, aim for >90% coverage
|
||||||
|
4. **Code Review**: Must pass multi-agent review before merge
|
||||||
|
5. **No Direct Commits**: Never commit directly to `main` or `dev`
|
||||||
|
6. **Stack Verification**: ALWAYS run the full stack before considering work done (see below)
|
||||||
|
|
||||||
**Branch naming convention:** `feature/123-description`
|
### CRITICAL: Stack Verification Before Merge
|
||||||
|
|
||||||
- Every issue gets its own feature branch
|
**This is NON-NEGOTIABLE. A feature with 100% test coverage that crashes on startup is WORTHLESS.**
|
||||||
- No direct commits to `main` or `dev`
|
|
||||||
- Keep branches focused and small
|
|
||||||
- Delete branches after merge
|
|
||||||
|
|
||||||
**Workflow:**
|
Before considering ANY issue complete:
|
||||||
```
|
|
||||||
main (production-ready)
|
```bash
|
||||||
└── dev (integration branch)
|
# 1. Start the dev stack
|
||||||
└── feature/123-description (issue branch)
|
make dev
|
||||||
|
|
||||||
|
# 2. Wait for backend to be healthy, check logs
|
||||||
|
docker compose -f docker-compose.dev.yml logs backend --tail=100
|
||||||
|
|
||||||
|
# 3. Start frontend
|
||||||
|
cd frontend && npm run dev
|
||||||
|
|
||||||
|
# 4. Verify both are running without errors
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. Testing Requirements
|
**The issue is NOT done if:**
|
||||||
|
- Backend crashes on startup (import errors, missing dependencies)
|
||||||
|
- Frontend fails to compile or render
|
||||||
|
- Health checks fail
|
||||||
|
- Any error appears in logs
|
||||||
|
|
||||||
**All code must be tested. No exceptions.**
|
**Why this matters:**
|
||||||
|
- Tests run in isolation and may pass despite broken imports
|
||||||
|
- Docker builds cache layers and may hide dependency issues
|
||||||
|
- A single `ModuleNotFoundError` renders all test coverage meaningless
|
||||||
|
|
||||||
- **TDD preferred**: Write tests first when possible
|
### Common Commands
|
||||||
- **Test after**: If not TDD, write tests immediately after testable code
|
|
||||||
- **Coverage types**: Unit, integration, functional, E2E as appropriate
|
|
||||||
- **Minimum coverage**: Aim for >90% on new code
|
|
||||||
|
|
||||||
### 4. Code Review Process
|
```bash
|
||||||
|
# Backend
|
||||||
|
IS_TEST=True uv run pytest # Run tests
|
||||||
|
uv run ruff check src/ # Lint
|
||||||
|
uv run mypy src/ # Type check
|
||||||
|
python migrate.py auto "message" # Database migration
|
||||||
|
|
||||||
**Before merging any feature branch, code must pass multi-agent review:**
|
# Frontend
|
||||||
|
npm test # Unit tests
|
||||||
| Check | Description |
|
npm run lint # Lint
|
||||||
|-------|-------------|
|
npm run type-check # Type check
|
||||||
| Bug hunting | Logic errors, edge cases, race conditions |
|
npm run generate:api # Regenerate API client
|
||||||
| Linting | `ruff check` passes with no errors |
|
```
|
||||||
| Typing | `mypy` passes with no errors |
|
|
||||||
| Formatting | Code follows style guidelines |
|
|
||||||
| Performance | No obvious bottlenecks or N+1 queries |
|
|
||||||
| Security | No vulnerabilities (OWASP top 10) |
|
|
||||||
| Architecture | Follows established patterns and ADRs |
|
|
||||||
|
|
||||||
**Issue is NOT done until review passes with flying colors.**
|
|
||||||
|
|
||||||
### 5. QA Before Main
|
|
||||||
|
|
||||||
**Before merging `dev` into `main`:**
|
|
||||||
|
|
||||||
- Full test suite passes
|
|
||||||
- Manual QA verification
|
|
||||||
- Performance baseline check
|
|
||||||
- Security scan
|
|
||||||
- Code must be clean, functional, bug-free, well-architected, and secure
|
|
||||||
|
|
||||||
### 6. Implementation Plan Updates
|
|
||||||
|
|
||||||
- Keep `docs/architecture/IMPLEMENTATION_ROADMAP.md` updated
|
|
||||||
- Mark completed items as work progresses
|
|
||||||
- Add new items discovered during implementation
|
|
||||||
|
|
||||||
### 7. UI/UX Design Approval
|
|
||||||
|
|
||||||
**Frontend tasks involving UI/UX require design approval:**
|
|
||||||
|
|
||||||
1. **Design Issue**: Create issue with `design` label
|
|
||||||
2. **Prototype**: Build interactive React prototype (navigable demo)
|
|
||||||
3. **Review**: User inspects and provides feedback
|
|
||||||
4. **Approval**: User approves before implementation begins
|
|
||||||
5. **Implementation**: Follow approved design, respecting design system
|
|
||||||
|
|
||||||
**Design constraints:**
|
|
||||||
- Prototypes: Best effort to match design system (not required)
|
|
||||||
- Production code: MUST follow `frontend/docs/design-system/` strictly
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### Key Extensions to Add (from PragmaStack base)
|
## Claude Code-Specific Guidance
|
||||||
|
|
||||||
|
### Critical User Preferences
|
||||||
|
|
||||||
|
**File Operations:**
|
||||||
|
- ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF`
|
||||||
|
- Never use heredoc - it triggers manual approval dialogs
|
||||||
|
|
||||||
|
**Work Style:**
|
||||||
|
- User prefers autonomous operation without frequent interruptions
|
||||||
|
- Ask for batch permissions upfront for long work sessions
|
||||||
|
- Work independently, document decisions clearly
|
||||||
|
- Only use emojis if the user explicitly requests it
|
||||||
|
|
||||||
|
### Critical Pattern: Auth Store DI
|
||||||
|
|
||||||
|
**ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly!**
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// ❌ WRONG
|
||||||
|
import { useAuthStore } from '@/lib/stores/authStore';
|
||||||
|
|
||||||
|
// ✅ CORRECT
|
||||||
|
import { useAuth } from '@/lib/auth/AuthContext';
|
||||||
|
```
|
||||||
|
|
||||||
|
See [CODING_STANDARDS.md](./docs/development/CODING_STANDARDS.md#auth-store-dependency-injection) for details.
|
||||||
|
|
||||||
|
### Tool Usage Preferences
|
||||||
|
|
||||||
|
**Prefer specialized tools over bash:**
|
||||||
|
- Use Read/Write/Edit tools for file operations
|
||||||
|
- Use Task tool with `subagent_type=Explore` for codebase exploration
|
||||||
|
- Use Grep tool for code search, not bash `grep`
|
||||||
|
|
||||||
|
**Parallel tool calls for:**
|
||||||
|
- Independent git commands
|
||||||
|
- Reading multiple unrelated files
|
||||||
|
- Running multiple test suites
|
||||||
|
- Independent validation steps
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Extensions (from PragmaStack base)
|
||||||
|
|
||||||
- Celery + Redis for agent job queue
|
- Celery + Redis for agent job queue
|
||||||
- WebSocket/SSE for real-time updates
|
- WebSocket/SSE for real-time updates
|
||||||
- pgvector for RAG knowledge base
|
- pgvector for RAG knowledge base
|
||||||
@@ -151,244 +185,20 @@ main (production-ready)
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## PragmaStack Development Guidelines
|
|
||||||
|
|
||||||
*The following guidelines are inherited from PragmaStack and remain applicable.*
|
|
||||||
|
|
||||||
## Claude Code-Specific Guidance
|
|
||||||
|
|
||||||
### Critical User Preferences
|
|
||||||
|
|
||||||
#### File Operations - NEVER Use Heredoc/Cat Append
|
|
||||||
**ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF` commands.**
|
|
||||||
|
|
||||||
This triggers manual approval dialogs and disrupts workflow.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# WRONG ❌
|
|
||||||
cat >> file.txt << EOF
|
|
||||||
content
|
|
||||||
EOF
|
|
||||||
|
|
||||||
# CORRECT ✅ - Use Read, then Write tools
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Work Style
|
|
||||||
- User prefers autonomous operation without frequent interruptions
|
|
||||||
- Ask for batch permissions upfront for long work sessions
|
|
||||||
- Work independently, document decisions clearly
|
|
||||||
- Only use emojis if the user explicitly requests it
|
|
||||||
|
|
||||||
### When Working with This Stack
|
|
||||||
|
|
||||||
**Dependency Management:**
|
|
||||||
- Backend uses **uv** (modern Python package manager), not pip
|
|
||||||
- Always use `uv run` prefix: `IS_TEST=True uv run pytest`
|
|
||||||
- Or use Makefile commands: `make test`, `make install-dev`
|
|
||||||
- Add dependencies: `uv add <package>` or `uv add --dev <package>`
|
|
||||||
|
|
||||||
**Database Migrations:**
|
|
||||||
- Use the `migrate.py` helper script, not Alembic directly
|
|
||||||
- Generate + apply: `python migrate.py auto "message"`
|
|
||||||
- Never commit migrations without testing them first
|
|
||||||
- Check current state: `python migrate.py current`
|
|
||||||
|
|
||||||
**Frontend API Client Generation:**
|
|
||||||
- Run `npm run generate:api` after backend schema changes
|
|
||||||
- Client is auto-generated from OpenAPI spec
|
|
||||||
- Located in `frontend/src/lib/api/generated/`
|
|
||||||
- NEVER manually edit generated files
|
|
||||||
|
|
||||||
**Testing Commands:**
|
|
||||||
- Backend unit/integration: `IS_TEST=True uv run pytest` (always prefix with `IS_TEST=True`)
|
|
||||||
- Backend E2E (requires Docker): `make test-e2e`
|
|
||||||
- Frontend unit: `npm test`
|
|
||||||
- Frontend E2E: `npm run test:e2e`
|
|
||||||
- Use `make test` or `make test-cov` in backend for convenience
|
|
||||||
|
|
||||||
**Backend E2E Testing (requires Docker):**
|
|
||||||
- Install deps: `make install-e2e`
|
|
||||||
- Run all E2E tests: `make test-e2e`
|
|
||||||
- Run schema tests only: `make test-e2e-schema`
|
|
||||||
- Run all tests: `make test-all` (unit + E2E)
|
|
||||||
- Uses Testcontainers (real PostgreSQL) + Schemathesis (OpenAPI contract testing)
|
|
||||||
- Markers: `@pytest.mark.e2e`, `@pytest.mark.postgres`, `@pytest.mark.schemathesis`
|
|
||||||
- See: `backend/docs/E2E_TESTING.md` for complete guide
|
|
||||||
|
|
||||||
### 🔴 CRITICAL: Auth Store Dependency Injection Pattern
|
|
||||||
|
|
||||||
**ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly!**
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
// ❌ WRONG - Bypasses dependency injection
|
|
||||||
import { useAuthStore } from '@/lib/stores/authStore';
|
|
||||||
const { user, isAuthenticated } = useAuthStore();
|
|
||||||
|
|
||||||
// ✅ CORRECT - Uses dependency injection
|
|
||||||
import { useAuth } from '@/lib/auth/AuthContext';
|
|
||||||
const { user, isAuthenticated } = useAuth();
|
|
||||||
```
|
|
||||||
|
|
||||||
**Why This Matters:**
|
|
||||||
- E2E tests inject mock stores via `window.__TEST_AUTH_STORE__`
|
|
||||||
- Unit tests inject via `<AuthProvider store={mockStore}>`
|
|
||||||
- Direct `useAuthStore` imports bypass this injection → **tests fail**
|
|
||||||
- ESLint will catch violations (added Nov 2025)
|
|
||||||
|
|
||||||
**Exceptions:**
|
|
||||||
1. `AuthContext.tsx` - DI boundary, legitimately needs real store
|
|
||||||
2. `client.ts` - Non-React context, uses dynamic import + `__TEST_AUTH_STORE__` check
|
|
||||||
|
|
||||||
### E2E Test Best Practices
|
|
||||||
|
|
||||||
When writing or fixing Playwright tests:
|
|
||||||
|
|
||||||
**Navigation Pattern:**
|
|
||||||
```typescript
|
|
||||||
// ✅ CORRECT - Use Promise.all for Next.js Link clicks
|
|
||||||
await Promise.all([
|
|
||||||
page.waitForURL('/target', { timeout: 10000 }),
|
|
||||||
link.click()
|
|
||||||
]);
|
|
||||||
```
|
|
||||||
|
|
||||||
**Selectors:**
|
|
||||||
- Use ID-based selectors for validation errors: `#email-error`
|
|
||||||
- Error IDs use dashes not underscores: `#new-password-error`
|
|
||||||
- Target `.border-destructive[role="alert"]` to avoid Next.js route announcer conflicts
|
|
||||||
- Avoid generic `[role="alert"]` which matches multiple elements
|
|
||||||
|
|
||||||
**URL Assertions:**
|
|
||||||
```typescript
|
|
||||||
// ✅ Use regex to handle query params
|
|
||||||
await expect(page).toHaveURL(/\/auth\/login/);
|
|
||||||
|
|
||||||
// ❌ Don't use exact strings (fails with query params)
|
|
||||||
await expect(page).toHaveURL('/auth/login');
|
|
||||||
```
|
|
||||||
|
|
||||||
**Configuration:**
|
|
||||||
- Uses 12 workers in non-CI mode (`playwright.config.ts`)
|
|
||||||
- Reduces to 2 workers in CI for stability
|
|
||||||
- Tests are designed to be non-flaky with proper waits
|
|
||||||
|
|
||||||
### Important Implementation Details
|
|
||||||
|
|
||||||
**Authentication Testing:**
|
|
||||||
- Backend fixtures in `tests/conftest.py`:
|
|
||||||
- `async_test_db`: Fresh SQLite per test
|
|
||||||
- `async_test_user` / `async_test_superuser`: Pre-created users
|
|
||||||
- `user_token` / `superuser_token`: Access tokens for API calls
|
|
||||||
- Always use `@pytest.mark.asyncio` for async tests
|
|
||||||
- Use `@pytest_asyncio.fixture` for async fixtures
|
|
||||||
|
|
||||||
**Database Testing:**
|
|
||||||
```python
|
|
||||||
# Mock database exceptions correctly
|
|
||||||
from unittest.mock import patch, AsyncMock
|
|
||||||
|
|
||||||
async def mock_commit():
|
|
||||||
raise OperationalError("Connection lost", {}, Exception())
|
|
||||||
|
|
||||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
|
||||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
|
||||||
with pytest.raises(OperationalError):
|
|
||||||
await crud_method(session, obj_in=data)
|
|
||||||
mock_rollback.assert_called_once()
|
|
||||||
```
|
|
||||||
|
|
||||||
**Frontend Component Development:**
|
|
||||||
- Follow design system docs in `frontend/docs/design-system/`
|
|
||||||
- Read `08-ai-guidelines.md` for AI code generation rules
|
|
||||||
- Use parent-controlled spacing (see `04-spacing-philosophy.md`)
|
|
||||||
- WCAG AA compliance required (see `07-accessibility.md`)
|
|
||||||
|
|
||||||
**Security Considerations:**
|
|
||||||
- Backend has comprehensive security tests (JWT attacks, session hijacking)
|
|
||||||
- Never skip security headers in production
|
|
||||||
- Rate limiting is configured in route decorators: `@limiter.limit("10/minute")`
|
|
||||||
- Session revocation is database-backed, not just JWT expiry
|
|
||||||
|
|
||||||
### Common Workflows Guidance
|
|
||||||
|
|
||||||
**When Adding a New Feature:**
|
|
||||||
1. Start with backend schema and CRUD
|
|
||||||
2. Implement API route with proper authorization
|
|
||||||
3. Write backend tests (aim for >90% coverage)
|
|
||||||
4. Generate frontend API client: `npm run generate:api`
|
|
||||||
5. Implement frontend components
|
|
||||||
6. Write frontend unit tests
|
|
||||||
7. Add E2E tests for critical flows
|
|
||||||
8. Update relevant documentation
|
|
||||||
|
|
||||||
**When Fixing Tests:**
|
|
||||||
- Backend: Check test database isolation and async fixture usage
|
|
||||||
- Frontend unit: Verify mocking of `useAuth()` not `useAuthStore`
|
|
||||||
- E2E: Use `Promise.all()` pattern and regex URL assertions
|
|
||||||
|
|
||||||
**When Debugging:**
|
|
||||||
- Backend: Check `IS_TEST=True` environment variable is set
|
|
||||||
- Frontend: Run `npm run type-check` first
|
|
||||||
- E2E: Use `npm run test:e2e:debug` for step-by-step debugging
|
|
||||||
- Check logs: Backend has detailed error logging
|
|
||||||
|
|
||||||
**Demo Mode (Frontend-Only Showcase):**
|
|
||||||
- Enable: `echo "NEXT_PUBLIC_DEMO_MODE=true" > frontend/.env.local`
|
|
||||||
- Uses MSW (Mock Service Worker) to intercept API calls in browser
|
|
||||||
- Zero backend required - perfect for Vercel deployments
|
|
||||||
- **Fully Automated**: MSW handlers auto-generated from OpenAPI spec
|
|
||||||
- Run `npm run generate:api` → updates both API client AND MSW handlers
|
|
||||||
- No manual synchronization needed!
|
|
||||||
- Demo credentials (any password ≥8 chars works):
|
|
||||||
- User: `demo@example.com` / `DemoPass123`
|
|
||||||
- Admin: `admin@example.com` / `AdminPass123`
|
|
||||||
- **Safe**: MSW never runs during tests (Jest or Playwright)
|
|
||||||
- **Coverage**: Mock files excluded from linting and coverage
|
|
||||||
- **Documentation**: `frontend/docs/DEMO_MODE.md` for complete guide
|
|
||||||
|
|
||||||
### Tool Usage Preferences
|
|
||||||
|
|
||||||
**Prefer specialized tools over bash:**
|
|
||||||
- Use Read/Write/Edit tools for file operations
|
|
||||||
- Never use `cat`, `echo >`, or heredoc for file manipulation
|
|
||||||
- Use Task tool with `subagent_type=Explore` for codebase exploration
|
|
||||||
- Use Grep tool for code search, not bash `grep`
|
|
||||||
|
|
||||||
**When to use parallel tool calls:**
|
|
||||||
- Independent git commands: `git status`, `git diff`, `git log`
|
|
||||||
- Reading multiple unrelated files
|
|
||||||
- Running multiple test suites simultaneously
|
|
||||||
- Independent validation steps
|
|
||||||
|
|
||||||
## Custom Skills
|
|
||||||
|
|
||||||
No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill.
|
|
||||||
|
|
||||||
**Potential skill ideas for this project:**
|
|
||||||
- API endpoint generator workflow (schema → CRUD → route → tests → frontend client)
|
|
||||||
- Component generator with design system compliance
|
|
||||||
- Database migration troubleshooting helper
|
|
||||||
- Test coverage analyzer and improvement suggester
|
|
||||||
- E2E test generator for new features
|
|
||||||
|
|
||||||
## Additional Resources
|
## Additional Resources
|
||||||
|
|
||||||
**Comprehensive Documentation:**
|
**Documentation:**
|
||||||
- [AGENTS.md](./AGENTS.md) - Framework-agnostic AI assistant context
|
- [AGENTS.md](./AGENTS.md) - Framework-agnostic AI assistant context
|
||||||
- [README.md](./README.md) - User-facing project overview
|
- [README.md](./README.md) - User-facing project overview
|
||||||
- `backend/docs/` - Backend architecture, coding standards, common pitfalls
|
- [docs/development/](./docs/development/) - Development workflow and standards
|
||||||
- `frontend/docs/design-system/` - Complete design system guide
|
- [backend/docs/](./backend/docs/) - Backend architecture and guides
|
||||||
|
- [frontend/docs/design-system/](./frontend/docs/design-system/) - Complete design system
|
||||||
|
|
||||||
**API Documentation (when running):**
|
**API Documentation (when running):**
|
||||||
- Swagger UI: http://localhost:8000/docs
|
- Swagger UI: http://localhost:8000/docs
|
||||||
- ReDoc: http://localhost:8000/redoc
|
- ReDoc: http://localhost:8000/redoc
|
||||||
- OpenAPI JSON: http://localhost:8000/api/v1/openapi.json
|
- OpenAPI JSON: http://localhost:8000/api/v1/openapi.json
|
||||||
|
|
||||||
**Testing Documentation:**
|
|
||||||
- Backend tests: `backend/tests/` (97% coverage)
|
|
||||||
- Frontend E2E: `frontend/e2e/README.md`
|
|
||||||
- Design system: `frontend/docs/design-system/08-ai-guidelines.md`
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**For project architecture, development commands, and general context, see [AGENTS.md](./AGENTS.md).**
|
**For project architecture, development commands, and general context, see [AGENTS.md](./AGENTS.md).**
|
||||||
|
|||||||
110
Makefile
110
Makefile
@@ -1,18 +1,34 @@
|
|||||||
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
|
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
|
||||||
|
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all format-all
|
||||||
|
|
||||||
VERSION ?= latest
|
VERSION ?= latest
|
||||||
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
||||||
|
|
||||||
# Default target
|
# Default target
|
||||||
help:
|
help:
|
||||||
@echo "FastAPI + Next.js Full-Stack Template"
|
@echo "Syndarix - AI-Powered Software Consulting Agency"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "Development:"
|
@echo "Development:"
|
||||||
@echo " make dev - Start backend + db (frontend runs separately)"
|
@echo " make dev - Start backend + db + MCP servers (frontend runs separately)"
|
||||||
@echo " make dev-full - Start all services including frontend"
|
@echo " make dev-full - Start all services including frontend"
|
||||||
@echo " make down - Stop all services"
|
@echo " make down - Stop all services"
|
||||||
@echo " make logs-dev - Follow dev container logs"
|
@echo " make logs-dev - Follow dev container logs"
|
||||||
@echo ""
|
@echo ""
|
||||||
|
@echo "Testing:"
|
||||||
|
@echo " make test - Run all tests (backend + MCP servers)"
|
||||||
|
@echo " make test-backend - Run backend tests only"
|
||||||
|
@echo " make test-mcp - Run MCP server tests only"
|
||||||
|
@echo " make test-frontend - Run frontend tests only"
|
||||||
|
@echo " make test-cov - Run all tests with coverage reports"
|
||||||
|
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||||
|
@echo ""
|
||||||
|
@echo "Formatting:"
|
||||||
|
@echo " make format-all - Format code in backend + MCP servers + frontend"
|
||||||
|
@echo ""
|
||||||
|
@echo "Validation:"
|
||||||
|
@echo " make validate - Validate backend + MCP servers (lint, type-check, test)"
|
||||||
|
@echo " make validate-all - Validate everything including frontend"
|
||||||
|
@echo ""
|
||||||
@echo "Database:"
|
@echo "Database:"
|
||||||
@echo " make drop-db - Drop and recreate empty database"
|
@echo " make drop-db - Drop and recreate empty database"
|
||||||
@echo " make reset-db - Drop database and apply all migrations"
|
@echo " make reset-db - Drop database and apply all migrations"
|
||||||
@@ -29,6 +45,8 @@ help:
|
|||||||
@echo ""
|
@echo ""
|
||||||
@echo "Subdirectory commands:"
|
@echo "Subdirectory commands:"
|
||||||
@echo " cd backend && make help - Backend-specific commands"
|
@echo " cd backend && make help - Backend-specific commands"
|
||||||
|
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
|
||||||
|
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
|
||||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -99,3 +117,91 @@ clean:
|
|||||||
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
|
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
|
||||||
clean-slate:
|
clean-slate:
|
||||||
docker compose -f docker-compose.dev.yml down -v --remove-orphans
|
docker compose -f docker-compose.dev.yml down -v --remove-orphans
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Testing
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
test: test-backend test-mcp
|
||||||
|
@echo ""
|
||||||
|
@echo "All tests passed!"
|
||||||
|
|
||||||
|
test-backend:
|
||||||
|
@echo "Running backend tests..."
|
||||||
|
@cd backend && IS_TEST=True uv run pytest tests/ -v
|
||||||
|
|
||||||
|
test-mcp:
|
||||||
|
@echo "Running MCP server tests..."
|
||||||
|
@echo ""
|
||||||
|
@echo "=== LLM Gateway ==="
|
||||||
|
@cd mcp-servers/llm-gateway && uv run pytest tests/ -v
|
||||||
|
@echo ""
|
||||||
|
@echo "=== Knowledge Base ==="
|
||||||
|
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v
|
||||||
|
|
||||||
|
test-frontend:
|
||||||
|
@echo "Running frontend tests..."
|
||||||
|
@cd frontend && npm test
|
||||||
|
|
||||||
|
test-all: test test-frontend
|
||||||
|
@echo ""
|
||||||
|
@echo "All tests (backend + MCP + frontend) passed!"
|
||||||
|
|
||||||
|
test-cov:
|
||||||
|
@echo "Running all tests with coverage..."
|
||||||
|
@echo ""
|
||||||
|
@echo "=== Backend Coverage ==="
|
||||||
|
@cd backend && IS_TEST=True uv run pytest tests/ -v --cov=app --cov-report=term-missing
|
||||||
|
@echo ""
|
||||||
|
@echo "=== LLM Gateway Coverage ==="
|
||||||
|
@cd mcp-servers/llm-gateway && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||||
|
@echo ""
|
||||||
|
@echo "=== Knowledge Base Coverage ==="
|
||||||
|
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||||
|
|
||||||
|
test-integration:
|
||||||
|
@echo "Running MCP integration tests..."
|
||||||
|
@echo "Note: Requires running stack (make dev first)"
|
||||||
|
@cd backend && RUN_INTEGRATION_TESTS=true IS_TEST=True uv run pytest tests/integration/ -v
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Formatting
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
format-all:
|
||||||
|
@echo "Formatting backend..."
|
||||||
|
@cd backend && make format
|
||||||
|
@echo ""
|
||||||
|
@echo "Formatting LLM Gateway..."
|
||||||
|
@cd mcp-servers/llm-gateway && make format
|
||||||
|
@echo ""
|
||||||
|
@echo "Formatting Knowledge Base..."
|
||||||
|
@cd mcp-servers/knowledge-base && make format
|
||||||
|
@echo ""
|
||||||
|
@echo "Formatting frontend..."
|
||||||
|
@cd frontend && npm run format
|
||||||
|
@echo ""
|
||||||
|
@echo "All code formatted!"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Validation (lint + type-check + test)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
validate:
|
||||||
|
@echo "Validating backend..."
|
||||||
|
@cd backend && make validate
|
||||||
|
@echo ""
|
||||||
|
@echo "Validating LLM Gateway..."
|
||||||
|
@cd mcp-servers/llm-gateway && make validate
|
||||||
|
@echo ""
|
||||||
|
@echo "Validating Knowledge Base..."
|
||||||
|
@cd mcp-servers/knowledge-base && make validate
|
||||||
|
@echo ""
|
||||||
|
@echo "All validations passed!"
|
||||||
|
|
||||||
|
validate-all: validate
|
||||||
|
@echo ""
|
||||||
|
@echo "Validating frontend..."
|
||||||
|
@cd frontend && npm run validate
|
||||||
|
@echo ""
|
||||||
|
@echo "Full validation passed!"
|
||||||
|
|||||||
@@ -7,7 +7,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
|
|||||||
PYTHONPATH=/app \
|
PYTHONPATH=/app \
|
||||||
UV_COMPILE_BYTECODE=1 \
|
UV_COMPILE_BYTECODE=1 \
|
||||||
UV_LINK_MODE=copy \
|
UV_LINK_MODE=copy \
|
||||||
UV_NO_CACHE=1
|
UV_NO_CACHE=1 \
|
||||||
|
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||||
|
VIRTUAL_ENV=/opt/venv \
|
||||||
|
PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
# Install system dependencies and uv
|
# Install system dependencies and uv
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
@@ -20,7 +23,7 @@ RUN apt-get update && \
|
|||||||
# Copy dependency files
|
# Copy dependency files
|
||||||
COPY pyproject.toml uv.lock ./
|
COPY pyproject.toml uv.lock ./
|
||||||
|
|
||||||
# Install dependencies using uv (development mode with dev dependencies)
|
# Install dependencies using uv into /opt/venv (outside /app to survive bind mounts)
|
||||||
RUN uv sync --extra dev --frozen
|
RUN uv sync --extra dev --frozen
|
||||||
|
|
||||||
# Copy application code
|
# Copy application code
|
||||||
@@ -45,7 +48,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
|
|||||||
PYTHONPATH=/app \
|
PYTHONPATH=/app \
|
||||||
UV_COMPILE_BYTECODE=1 \
|
UV_COMPILE_BYTECODE=1 \
|
||||||
UV_LINK_MODE=copy \
|
UV_LINK_MODE=copy \
|
||||||
UV_NO_CACHE=1
|
UV_NO_CACHE=1 \
|
||||||
|
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||||
|
VIRTUAL_ENV=/opt/venv \
|
||||||
|
PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
# Install system dependencies and uv
|
# Install system dependencies and uv
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
@@ -58,7 +64,7 @@ RUN apt-get update && \
|
|||||||
# Copy dependency files
|
# Copy dependency files
|
||||||
COPY pyproject.toml uv.lock ./
|
COPY pyproject.toml uv.lock ./
|
||||||
|
|
||||||
# Install only production dependencies using uv (no dev dependencies)
|
# Install only production dependencies using uv into /opt/venv
|
||||||
RUN uv sync --frozen --no-dev
|
RUN uv sync --frozen --no-dev
|
||||||
|
|
||||||
# Copy application code
|
# Copy application code
|
||||||
@@ -67,7 +73,7 @@ COPY entrypoint.sh /usr/local/bin/
|
|||||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||||
|
|
||||||
# Set ownership to non-root user
|
# Set ownership to non-root user
|
||||||
RUN chown -R appuser:appuser /app
|
RUN chown -R appuser:appuser /app /opt/venv
|
||||||
|
|
||||||
# Switch to non-root user
|
# Switch to non-root user
|
||||||
USER appuser
|
USER appuser
|
||||||
@@ -77,4 +83,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
|||||||
CMD curl -f http://localhost:8000/health || exit 1
|
CMD curl -f http://localhost:8000/health || exit 1
|
||||||
|
|
||||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||||
CMD ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all
|
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all test-integration
|
||||||
|
|
||||||
# Default target
|
# Default target
|
||||||
help:
|
help:
|
||||||
@@ -22,6 +22,7 @@ help:
|
|||||||
@echo " make test-cov - Run pytest with coverage report"
|
@echo " make test-cov - Run pytest with coverage report"
|
||||||
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
|
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
|
||||||
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
|
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
|
||||||
|
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||||
@echo " make test-all - Run all tests (unit + E2E)"
|
@echo " make test-all - Run all tests (unit + E2E)"
|
||||||
@echo " make check-docker - Check if Docker is available"
|
@echo " make check-docker - Check if Docker is available"
|
||||||
@echo ""
|
@echo ""
|
||||||
@@ -79,9 +80,18 @@ test:
|
|||||||
|
|
||||||
test-cov:
|
test-cov:
|
||||||
@echo "🧪 Running tests with coverage..."
|
@echo "🧪 Running tests with coverage..."
|
||||||
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
|
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 20
|
||||||
@echo "📊 Coverage report generated in htmlcov/index.html"
|
@echo "📊 Coverage report generated in htmlcov/index.html"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Integration Testing (requires running stack: make dev)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
test-integration:
|
||||||
|
@echo "🧪 Running MCP integration tests..."
|
||||||
|
@echo "Note: Requires running stack (make dev from project root)"
|
||||||
|
@RUN_INTEGRATION_TESTS=true IS_TEST=True PYTHONPATH=. uv run pytest tests/integration/ -v
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# E2E Testing (requires Docker)
|
# E2E Testing (requires Docker)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ def include_object(object, name, type_, reflected, compare_to):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
# Interpret the config file for Python logging.
|
# Interpret the config file for Python logging.
|
||||||
# This line sets up loggers basically.
|
# This line sets up loggers basically.
|
||||||
if config.config_file_name is not None:
|
if config.config_file_name is not None:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ Revises:
|
|||||||
Create Date: 2025-11-27 09:08:09.464506
|
Create Date: 2025-11-27 09:08:09.464506
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
@@ -12,7 +13,7 @@ from alembic import op
|
|||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = '0001'
|
revision: str = "0001"
|
||||||
down_revision: str | None = None
|
down_revision: str | None = None
|
||||||
branch_labels: str | Sequence[str] | None = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: str | Sequence[str] | None = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
@@ -20,243 +21,426 @@ depends_on: str | Sequence[str] | None = None
|
|||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.create_table('oauth_states',
|
op.create_table(
|
||||||
sa.Column('state', sa.String(length=255), nullable=False),
|
"oauth_states",
|
||||||
sa.Column('code_verifier', sa.String(length=128), nullable=True),
|
sa.Column("state", sa.String(length=255), nullable=False),
|
||||||
sa.Column('nonce', sa.String(length=255), nullable=True),
|
sa.Column("code_verifier", sa.String(length=128), nullable=True),
|
||||||
sa.Column('provider', sa.String(length=50), nullable=False),
|
sa.Column("nonce", sa.String(length=255), nullable=True),
|
||||||
sa.Column('redirect_uri', sa.String(length=500), nullable=True),
|
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||||
sa.Column('user_id', sa.UUID(), nullable=True),
|
sa.Column("redirect_uri", sa.String(length=500), nullable=True),
|
||||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.PrimaryKeyConstraint('id')
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_oauth_states_state'), 'oauth_states', ['state'], unique=True)
|
op.create_index(
|
||||||
op.create_table('organizations',
|
op.f("ix_oauth_states_state"), "oauth_states", ["state"], unique=True
|
||||||
sa.Column('name', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('slug', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('description', sa.Text(), nullable=True),
|
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('settings', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_organizations_is_active'), 'organizations', ['is_active'], unique=False)
|
op.create_table(
|
||||||
op.create_index(op.f('ix_organizations_name'), 'organizations', ['name'], unique=False)
|
"organizations",
|
||||||
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'], unique=False)
|
sa.Column("name", sa.String(length=255), nullable=False),
|
||||||
op.create_index(op.f('ix_organizations_slug'), 'organizations', ['slug'], unique=True)
|
sa.Column("slug", sa.String(length=255), nullable=False),
|
||||||
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'], unique=False)
|
sa.Column("description", sa.Text(), nullable=True),
|
||||||
op.create_table('users',
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
sa.Column('email', sa.String(length=255), nullable=False),
|
sa.Column("settings", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||||
sa.Column('password_hash', sa.String(length=255), nullable=True),
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
sa.Column('first_name', sa.String(length=100), nullable=False),
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.Column('last_name', sa.String(length=100), nullable=True),
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.Column('phone_number', sa.String(length=20), nullable=True),
|
sa.PrimaryKeyConstraint("id"),
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('preferences', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
|
||||||
sa.Column('locale', sa.String(length=10), nullable=True),
|
|
||||||
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_users_deleted_at'), 'users', ['deleted_at'], unique=False)
|
op.create_index(
|
||||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
op.f("ix_organizations_is_active"), "organizations", ["is_active"], unique=False
|
||||||
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
|
|
||||||
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
|
|
||||||
op.create_index(op.f('ix_users_locale'), 'users', ['locale'], unique=False)
|
|
||||||
op.create_table('oauth_accounts',
|
|
||||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('provider', sa.String(length=50), nullable=False),
|
|
||||||
sa.Column('provider_user_id', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('provider_email', sa.String(length=255), nullable=True),
|
|
||||||
sa.Column('access_token_encrypted', sa.String(length=2048), nullable=True),
|
|
||||||
sa.Column('refresh_token_encrypted', sa.String(length=2048), nullable=True),
|
|
||||||
sa.Column('token_expires_at', sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id'),
|
|
||||||
sa.UniqueConstraint('provider', 'provider_user_id', name='uq_oauth_provider_user')
|
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_oauth_accounts_provider'), 'oauth_accounts', ['provider'], unique=False)
|
op.create_index(
|
||||||
op.create_index(op.f('ix_oauth_accounts_provider_email'), 'oauth_accounts', ['provider_email'], unique=False)
|
op.f("ix_organizations_name"), "organizations", ["name"], unique=False
|
||||||
op.create_index(op.f('ix_oauth_accounts_user_id'), 'oauth_accounts', ['user_id'], unique=False)
|
|
||||||
op.create_index('ix_oauth_accounts_user_provider', 'oauth_accounts', ['user_id', 'provider'], unique=False)
|
|
||||||
op.create_table('oauth_clients',
|
|
||||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
|
||||||
sa.Column('client_secret_hash', sa.String(length=255), nullable=True),
|
|
||||||
sa.Column('client_name', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('client_description', sa.String(length=1000), nullable=True),
|
|
||||||
sa.Column('client_type', sa.String(length=20), nullable=False),
|
|
||||||
sa.Column('redirect_uris', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
|
||||||
sa.Column('allowed_scopes', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
|
||||||
sa.Column('access_token_lifetime', sa.String(length=10), nullable=False),
|
|
||||||
sa.Column('refresh_token_lifetime', sa.String(length=10), nullable=False),
|
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('owner_user_id', sa.UUID(), nullable=True),
|
|
||||||
sa.Column('mcp_server_url', sa.String(length=2048), nullable=True),
|
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['owner_user_id'], ['users.id'], ondelete='SET NULL'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_oauth_clients_client_id'), 'oauth_clients', ['client_id'], unique=True)
|
op.create_index(
|
||||||
op.create_index(op.f('ix_oauth_clients_is_active'), 'oauth_clients', ['is_active'], unique=False)
|
"ix_organizations_name_active",
|
||||||
op.create_table('user_organizations',
|
"organizations",
|
||||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
["name", "is_active"],
|
||||||
sa.Column('organization_id', sa.UUID(), nullable=False),
|
unique=False,
|
||||||
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False),
|
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('custom_permissions', sa.String(length=500), nullable=True),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('user_id', 'organization_id')
|
|
||||||
)
|
)
|
||||||
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'], unique=False)
|
op.create_index(
|
||||||
op.create_index('ix_user_org_role', 'user_organizations', ['role'], unique=False)
|
op.f("ix_organizations_slug"), "organizations", ["slug"], unique=True
|
||||||
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'], unique=False)
|
|
||||||
op.create_index(op.f('ix_user_organizations_is_active'), 'user_organizations', ['is_active'], unique=False)
|
|
||||||
op.create_table('user_sessions',
|
|
||||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
|
|
||||||
sa.Column('device_name', sa.String(length=255), nullable=True),
|
|
||||||
sa.Column('device_id', sa.String(length=255), nullable=True),
|
|
||||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
|
||||||
sa.Column('user_agent', sa.String(length=500), nullable=True),
|
|
||||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('location_city', sa.String(length=100), nullable=True),
|
|
||||||
sa.Column('location_country', sa.String(length=100), nullable=True),
|
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
)
|
||||||
op.create_index(op.f('ix_user_sessions_is_active'), 'user_sessions', ['is_active'], unique=False)
|
op.create_index(
|
||||||
op.create_index('ix_user_sessions_jti_active', 'user_sessions', ['refresh_token_jti', 'is_active'], unique=False)
|
"ix_organizations_slug_active",
|
||||||
op.create_index(op.f('ix_user_sessions_refresh_token_jti'), 'user_sessions', ['refresh_token_jti'], unique=True)
|
"organizations",
|
||||||
op.create_index('ix_user_sessions_user_active', 'user_sessions', ['user_id', 'is_active'], unique=False)
|
["slug", "is_active"],
|
||||||
op.create_index(op.f('ix_user_sessions_user_id'), 'user_sessions', ['user_id'], unique=False)
|
unique=False,
|
||||||
op.create_table('oauth_authorization_codes',
|
|
||||||
sa.Column('code', sa.String(length=128), nullable=False),
|
|
||||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
|
||||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('redirect_uri', sa.String(length=2048), nullable=False),
|
|
||||||
sa.Column('scope', sa.String(length=1000), nullable=False),
|
|
||||||
sa.Column('code_challenge', sa.String(length=128), nullable=True),
|
|
||||||
sa.Column('code_challenge_method', sa.String(length=10), nullable=True),
|
|
||||||
sa.Column('state', sa.String(length=256), nullable=True),
|
|
||||||
sa.Column('nonce', sa.String(length=256), nullable=True),
|
|
||||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('used', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
)
|
||||||
op.create_index('ix_oauth_authorization_codes_client_user', 'oauth_authorization_codes', ['client_id', 'user_id'], unique=False)
|
op.create_table(
|
||||||
op.create_index(op.f('ix_oauth_authorization_codes_code'), 'oauth_authorization_codes', ['code'], unique=True)
|
"users",
|
||||||
op.create_index('ix_oauth_authorization_codes_expires_at', 'oauth_authorization_codes', ['expires_at'], unique=False)
|
sa.Column("email", sa.String(length=255), nullable=False),
|
||||||
op.create_table('oauth_consents',
|
sa.Column("password_hash", sa.String(length=255), nullable=True),
|
||||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
sa.Column("first_name", sa.String(length=100), nullable=False),
|
||||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
sa.Column("last_name", sa.String(length=100), nullable=True),
|
||||||
sa.Column('granted_scopes', sa.String(length=1000), nullable=False),
|
sa.Column("phone_number", sa.String(length=20), nullable=True),
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column(
|
||||||
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
|
"preferences", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
),
|
||||||
sa.PrimaryKeyConstraint('id')
|
sa.Column("locale", sa.String(length=10), nullable=True),
|
||||||
|
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
)
|
)
|
||||||
op.create_index('ix_oauth_consents_user_client', 'oauth_consents', ['user_id', 'client_id'], unique=True)
|
op.create_index(op.f("ix_users_deleted_at"), "users", ["deleted_at"], unique=False)
|
||||||
op.create_table('oauth_provider_refresh_tokens',
|
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||||
sa.Column('token_hash', sa.String(length=64), nullable=False),
|
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
|
||||||
sa.Column('jti', sa.String(length=64), nullable=False),
|
op.create_index(
|
||||||
sa.Column('client_id', sa.String(length=64), nullable=False),
|
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
|
||||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
)
|
||||||
sa.Column('scope', sa.String(length=1000), nullable=False),
|
op.create_index(op.f("ix_users_locale"), "users", ["locale"], unique=False)
|
||||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
op.create_table(
|
||||||
sa.Column('revoked', sa.Boolean(), nullable=False),
|
"oauth_accounts",
|
||||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True),
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
sa.Column('device_info', sa.String(length=500), nullable=True),
|
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
sa.Column("provider_user_id", sa.String(length=255), nullable=False),
|
||||||
sa.Column('id', sa.UUID(), nullable=False),
|
sa.Column("provider_email", sa.String(length=255), nullable=True),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
|
||||||
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
|
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
sa.PrimaryKeyConstraint('id')
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"provider", "provider_user_id", name="uq_oauth_provider_user"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_accounts_provider"), "oauth_accounts", ["provider"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_accounts_provider_email"),
|
||||||
|
"oauth_accounts",
|
||||||
|
["provider_email"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_accounts_user_id"), "oauth_accounts", ["user_id"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_accounts_user_provider",
|
||||||
|
"oauth_accounts",
|
||||||
|
["user_id", "provider"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_clients",
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("client_name", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("client_description", sa.String(length=1000), nullable=True),
|
||||||
|
sa.Column("client_type", sa.String(length=20), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"redirect_uris", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"allowed_scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||||
|
),
|
||||||
|
sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
|
||||||
|
sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("owner_user_id", sa.UUID(), nullable=True),
|
||||||
|
sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["owner_user_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=True
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_clients_is_active"), "oauth_clients", ["is_active"], unique=False
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"user_organizations",
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("organization_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"role",
|
||||||
|
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("user_id", "organization_id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_org_org_active",
|
||||||
|
"user_organizations",
|
||||||
|
["organization_id", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index("ix_user_org_role", "user_organizations", ["role"], unique=False)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_org_user_active",
|
||||||
|
"user_organizations",
|
||||||
|
["user_id", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_organizations_is_active"),
|
||||||
|
"user_organizations",
|
||||||
|
["is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"user_sessions",
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("device_name", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("device_id", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||||
|
sa.Column("user_agent", sa.String(length=500), nullable=True),
|
||||||
|
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("location_city", sa.String(length=100), nullable=True),
|
||||||
|
sa.Column("location_country", sa.String(length=100), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_sessions_is_active"), "user_sessions", ["is_active"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_sessions_jti_active",
|
||||||
|
"user_sessions",
|
||||||
|
["refresh_token_jti", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_sessions_refresh_token_jti"),
|
||||||
|
"user_sessions",
|
||||||
|
["refresh_token_jti"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_sessions_user_active",
|
||||||
|
"user_sessions",
|
||||||
|
["user_id", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_sessions_user_id"), "user_sessions", ["user_id"], unique=False
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
sa.Column("code", sa.String(length=128), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("redirect_uri", sa.String(length=2048), nullable=False),
|
||||||
|
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||||
|
sa.Column("code_challenge", sa.String(length=128), nullable=True),
|
||||||
|
sa.Column("code_challenge_method", sa.String(length=10), nullable=True),
|
||||||
|
sa.Column("state", sa.String(length=256), nullable=True),
|
||||||
|
sa.Column("nonce", sa.String(length=256), nullable=True),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("used", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_authorization_codes_client_user",
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
["client_id", "user_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_authorization_codes_code"),
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
["code"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_authorization_codes_expires_at",
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
["expires_at"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_consents",
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("granted_scopes", sa.String(length=1000), nullable=False),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_consents_user_client",
|
||||||
|
"oauth_consents",
|
||||||
|
["user_id", "client_id"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
sa.Column("token_hash", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("jti", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("revoked", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("device_info", sa.String(length=500), nullable=True),
|
||||||
|
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_client_user",
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["client_id", "user_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["expires_at"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["jti"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["revoked"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["token_hash"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["user_id", "revoked"],
|
||||||
|
unique=False,
|
||||||
)
|
)
|
||||||
op.create_index('ix_oauth_provider_refresh_tokens_client_user', 'oauth_provider_refresh_tokens', ['client_id', 'user_id'], unique=False)
|
|
||||||
op.create_index('ix_oauth_provider_refresh_tokens_expires_at', 'oauth_provider_refresh_tokens', ['expires_at'], unique=False)
|
|
||||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_jti'), 'oauth_provider_refresh_tokens', ['jti'], unique=True)
|
|
||||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), 'oauth_provider_refresh_tokens', ['revoked'], unique=False)
|
|
||||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), 'oauth_provider_refresh_tokens', ['token_hash'], unique=True)
|
|
||||||
op.create_index('ix_oauth_provider_refresh_tokens_user_revoked', 'oauth_provider_refresh_tokens', ['user_id', 'revoked'], unique=False)
|
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_index('ix_oauth_provider_refresh_tokens_user_revoked', table_name='oauth_provider_refresh_tokens')
|
op.drop_index(
|
||||||
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), table_name='oauth_provider_refresh_tokens')
|
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||||
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), table_name='oauth_provider_refresh_tokens')
|
table_name="oauth_provider_refresh_tokens",
|
||||||
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_jti'), table_name='oauth_provider_refresh_tokens')
|
)
|
||||||
op.drop_index('ix_oauth_provider_refresh_tokens_expires_at', table_name='oauth_provider_refresh_tokens')
|
op.drop_index(
|
||||||
op.drop_index('ix_oauth_provider_refresh_tokens_client_user', table_name='oauth_provider_refresh_tokens')
|
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||||
op.drop_table('oauth_provider_refresh_tokens')
|
table_name="oauth_provider_refresh_tokens",
|
||||||
op.drop_index('ix_oauth_consents_user_client', table_name='oauth_consents')
|
)
|
||||||
op.drop_table('oauth_consents')
|
op.drop_index(
|
||||||
op.drop_index('ix_oauth_authorization_codes_expires_at', table_name='oauth_authorization_codes')
|
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||||
op.drop_index(op.f('ix_oauth_authorization_codes_code'), table_name='oauth_authorization_codes')
|
table_name="oauth_provider_refresh_tokens",
|
||||||
op.drop_index('ix_oauth_authorization_codes_client_user', table_name='oauth_authorization_codes')
|
)
|
||||||
op.drop_table('oauth_authorization_codes')
|
op.drop_index(
|
||||||
op.drop_index(op.f('ix_user_sessions_user_id'), table_name='user_sessions')
|
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||||
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions')
|
table_name="oauth_provider_refresh_tokens",
|
||||||
op.drop_index(op.f('ix_user_sessions_refresh_token_jti'), table_name='user_sessions')
|
)
|
||||||
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions')
|
op.drop_index(
|
||||||
op.drop_index(op.f('ix_user_sessions_is_active'), table_name='user_sessions')
|
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||||
op.drop_table('user_sessions')
|
table_name="oauth_provider_refresh_tokens",
|
||||||
op.drop_index(op.f('ix_user_organizations_is_active'), table_name='user_organizations')
|
)
|
||||||
op.drop_index('ix_user_org_user_active', table_name='user_organizations')
|
op.drop_index(
|
||||||
op.drop_index('ix_user_org_role', table_name='user_organizations')
|
"ix_oauth_provider_refresh_tokens_client_user",
|
||||||
op.drop_index('ix_user_org_org_active', table_name='user_organizations')
|
table_name="oauth_provider_refresh_tokens",
|
||||||
op.drop_table('user_organizations')
|
)
|
||||||
op.drop_index(op.f('ix_oauth_clients_is_active'), table_name='oauth_clients')
|
op.drop_table("oauth_provider_refresh_tokens")
|
||||||
op.drop_index(op.f('ix_oauth_clients_client_id'), table_name='oauth_clients')
|
op.drop_index("ix_oauth_consents_user_client", table_name="oauth_consents")
|
||||||
op.drop_table('oauth_clients')
|
op.drop_table("oauth_consents")
|
||||||
op.drop_index('ix_oauth_accounts_user_provider', table_name='oauth_accounts')
|
op.drop_index(
|
||||||
op.drop_index(op.f('ix_oauth_accounts_user_id'), table_name='oauth_accounts')
|
"ix_oauth_authorization_codes_expires_at",
|
||||||
op.drop_index(op.f('ix_oauth_accounts_provider_email'), table_name='oauth_accounts')
|
table_name="oauth_authorization_codes",
|
||||||
op.drop_index(op.f('ix_oauth_accounts_provider'), table_name='oauth_accounts')
|
)
|
||||||
op.drop_table('oauth_accounts')
|
op.drop_index(
|
||||||
op.drop_index(op.f('ix_users_locale'), table_name='users')
|
op.f("ix_oauth_authorization_codes_code"),
|
||||||
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
|
table_name="oauth_authorization_codes",
|
||||||
op.drop_index(op.f('ix_users_is_active'), table_name='users')
|
)
|
||||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
op.drop_index(
|
||||||
op.drop_index(op.f('ix_users_deleted_at'), table_name='users')
|
"ix_oauth_authorization_codes_client_user",
|
||||||
op.drop_table('users')
|
table_name="oauth_authorization_codes",
|
||||||
op.drop_index('ix_organizations_slug_active', table_name='organizations')
|
)
|
||||||
op.drop_index(op.f('ix_organizations_slug'), table_name='organizations')
|
op.drop_table("oauth_authorization_codes")
|
||||||
op.drop_index('ix_organizations_name_active', table_name='organizations')
|
op.drop_index(op.f("ix_user_sessions_user_id"), table_name="user_sessions")
|
||||||
op.drop_index(op.f('ix_organizations_name'), table_name='organizations')
|
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
|
||||||
op.drop_index(op.f('ix_organizations_is_active'), table_name='organizations')
|
op.drop_index(
|
||||||
op.drop_table('organizations')
|
op.f("ix_user_sessions_refresh_token_jti"), table_name="user_sessions"
|
||||||
op.drop_index(op.f('ix_oauth_states_state'), table_name='oauth_states')
|
)
|
||||||
op.drop_table('oauth_states')
|
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
|
||||||
|
op.drop_index(op.f("ix_user_sessions_is_active"), table_name="user_sessions")
|
||||||
|
op.drop_table("user_sessions")
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_user_organizations_is_active"), table_name="user_organizations"
|
||||||
|
)
|
||||||
|
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
|
||||||
|
op.drop_index("ix_user_org_role", table_name="user_organizations")
|
||||||
|
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
|
||||||
|
op.drop_table("user_organizations")
|
||||||
|
op.drop_index(op.f("ix_oauth_clients_is_active"), table_name="oauth_clients")
|
||||||
|
op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients")
|
||||||
|
op.drop_table("oauth_clients")
|
||||||
|
op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
|
||||||
|
op.drop_index(op.f("ix_oauth_accounts_user_id"), table_name="oauth_accounts")
|
||||||
|
op.drop_index(op.f("ix_oauth_accounts_provider_email"), table_name="oauth_accounts")
|
||||||
|
op.drop_index(op.f("ix_oauth_accounts_provider"), table_name="oauth_accounts")
|
||||||
|
op.drop_table("oauth_accounts")
|
||||||
|
op.drop_index(op.f("ix_users_locale"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_is_active"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_deleted_at"), table_name="users")
|
||||||
|
op.drop_table("users")
|
||||||
|
op.drop_index("ix_organizations_slug_active", table_name="organizations")
|
||||||
|
op.drop_index(op.f("ix_organizations_slug"), table_name="organizations")
|
||||||
|
op.drop_index("ix_organizations_name_active", table_name="organizations")
|
||||||
|
op.drop_index(op.f("ix_organizations_name"), table_name="organizations")
|
||||||
|
op.drop_index(op.f("ix_organizations_is_active"), table_name="organizations")
|
||||||
|
op.drop_table("organizations")
|
||||||
|
op.drop_index(op.f("ix_oauth_states_state"), table_name="oauth_states")
|
||||||
|
op.drop_table("oauth_states")
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|||||||
@@ -114,8 +114,13 @@ def upgrade() -> None:
|
|||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# Drop indexes in reverse order
|
# Drop indexes in reverse order
|
||||||
op.drop_index("ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes")
|
op.drop_index(
|
||||||
op.drop_index("ix_perf_oauth_refresh_tokens_expires", table_name="oauth_provider_refresh_tokens")
|
"ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes"
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_perf_oauth_refresh_tokens_expires",
|
||||||
|
table_name="oauth_provider_refresh_tokens",
|
||||||
|
)
|
||||||
op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
|
op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
|
||||||
op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
|
op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
|
||||||
op.drop_index("ix_perf_users_active", table_name="users")
|
op.drop_index("ix_perf_users_active", table_name="users")
|
||||||
|
|||||||
@@ -28,82 +28,9 @@ depends_on: str | Sequence[str] | None = None
|
|||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
"""Create Syndarix domain tables."""
|
"""Create Syndarix domain tables."""
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Create ENUM types
|
|
||||||
# =========================================================================
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
CREATE TYPE autonomy_level AS ENUM (
|
|
||||||
'full_control', 'milestone', 'autonomous'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
CREATE TYPE project_status AS ENUM (
|
|
||||||
'active', 'paused', 'completed', 'archived'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
CREATE TYPE project_complexity AS ENUM (
|
|
||||||
'script', 'simple', 'medium', 'complex'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
CREATE TYPE client_mode AS ENUM (
|
|
||||||
'technical', 'auto'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
CREATE TYPE agent_status AS ENUM (
|
|
||||||
'idle', 'working', 'waiting', 'paused', 'terminated'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
CREATE TYPE issue_type AS ENUM (
|
|
||||||
'epic', 'story', 'task', 'bug'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
CREATE TYPE issue_status AS ENUM (
|
|
||||||
'open', 'in_progress', 'in_review', 'blocked', 'closed'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
CREATE TYPE issue_priority AS ENUM (
|
|
||||||
'low', 'medium', 'high', 'critical'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
CREATE TYPE sync_status AS ENUM (
|
|
||||||
'synced', 'pending', 'conflict', 'error'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
CREATE TYPE sprint_status AS ENUM (
|
|
||||||
'planned', 'active', 'in_review', 'completed', 'cancelled'
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Create projects table
|
# Create projects table
|
||||||
|
# Note: ENUM types are created automatically by sa.Enum() during table creation
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"projects",
|
"projects",
|
||||||
@@ -118,7 +45,6 @@ def upgrade() -> None:
|
|||||||
"milestone",
|
"milestone",
|
||||||
"autonomous",
|
"autonomous",
|
||||||
name="autonomy_level",
|
name="autonomy_level",
|
||||||
create_type=False,
|
|
||||||
),
|
),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="milestone",
|
server_default="milestone",
|
||||||
@@ -131,7 +57,6 @@ def upgrade() -> None:
|
|||||||
"completed",
|
"completed",
|
||||||
"archived",
|
"archived",
|
||||||
name="project_status",
|
name="project_status",
|
||||||
create_type=False,
|
|
||||||
),
|
),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="active",
|
server_default="active",
|
||||||
@@ -144,14 +69,13 @@ def upgrade() -> None:
|
|||||||
"medium",
|
"medium",
|
||||||
"complex",
|
"complex",
|
||||||
name="project_complexity",
|
name="project_complexity",
|
||||||
create_type=False,
|
|
||||||
),
|
),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="medium",
|
server_default="medium",
|
||||||
),
|
),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"client_mode",
|
"client_mode",
|
||||||
sa.Enum("technical", "auto", name="client_mode", create_type=False),
|
sa.Enum("technical", "auto", name="client_mode"),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="auto",
|
server_default="auto",
|
||||||
),
|
),
|
||||||
@@ -285,7 +209,6 @@ def upgrade() -> None:
|
|||||||
"paused",
|
"paused",
|
||||||
"terminated",
|
"terminated",
|
||||||
name="agent_status",
|
name="agent_status",
|
||||||
create_type=False,
|
|
||||||
),
|
),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="idle",
|
server_default="idle",
|
||||||
@@ -384,7 +307,6 @@ def upgrade() -> None:
|
|||||||
"completed",
|
"completed",
|
||||||
"cancelled",
|
"cancelled",
|
||||||
name="sprint_status",
|
name="sprint_status",
|
||||||
create_type=False,
|
|
||||||
),
|
),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="planned",
|
server_default="planned",
|
||||||
@@ -435,7 +357,6 @@ def upgrade() -> None:
|
|||||||
"task",
|
"task",
|
||||||
"bug",
|
"bug",
|
||||||
name="issue_type",
|
name="issue_type",
|
||||||
create_type=False,
|
|
||||||
),
|
),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="task",
|
server_default="task",
|
||||||
@@ -455,7 +376,6 @@ def upgrade() -> None:
|
|||||||
"blocked",
|
"blocked",
|
||||||
"closed",
|
"closed",
|
||||||
name="issue_status",
|
name="issue_status",
|
||||||
create_type=False,
|
|
||||||
),
|
),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="open",
|
server_default="open",
|
||||||
@@ -468,7 +388,6 @@ def upgrade() -> None:
|
|||||||
"high",
|
"high",
|
||||||
"critical",
|
"critical",
|
||||||
name="issue_priority",
|
name="issue_priority",
|
||||||
create_type=False,
|
|
||||||
),
|
),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="medium",
|
server_default="medium",
|
||||||
@@ -502,7 +421,6 @@ def upgrade() -> None:
|
|||||||
"conflict",
|
"conflict",
|
||||||
"error",
|
"error",
|
||||||
name="sync_status",
|
name="sync_status",
|
||||||
create_type=False,
|
|
||||||
),
|
),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="synced",
|
server_default="synced",
|
||||||
@@ -525,9 +443,7 @@ def upgrade() -> None:
|
|||||||
),
|
),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint("id"),
|
||||||
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
|
sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
|
||||||
sa.ForeignKeyConstraint(
|
sa.ForeignKeyConstraint(["parent_id"], ["issues.id"], ondelete="CASCADE"),
|
||||||
["parent_id"], ["issues.id"], ondelete="CASCADE"
|
|
||||||
),
|
|
||||||
sa.ForeignKeyConstraint(["sprint_id"], ["sprints.id"], ondelete="SET NULL"),
|
sa.ForeignKeyConstraint(["sprint_id"], ["sprints.id"], ondelete="SET NULL"),
|
||||||
sa.ForeignKeyConstraint(
|
sa.ForeignKeyConstraint(
|
||||||
["assigned_agent_id"], ["agent_instances.id"], ondelete="SET NULL"
|
["assigned_agent_id"], ["agent_instances.id"], ondelete="SET NULL"
|
||||||
@@ -544,7 +460,9 @@ def upgrade() -> None:
|
|||||||
op.create_index("ix_issues_human_assignee", "issues", ["human_assignee"])
|
op.create_index("ix_issues_human_assignee", "issues", ["human_assignee"])
|
||||||
op.create_index("ix_issues_sprint_id", "issues", ["sprint_id"])
|
op.create_index("ix_issues_sprint_id", "issues", ["sprint_id"])
|
||||||
op.create_index("ix_issues_due_date", "issues", ["due_date"])
|
op.create_index("ix_issues_due_date", "issues", ["due_date"])
|
||||||
op.create_index("ix_issues_external_tracker_type", "issues", ["external_tracker_type"])
|
op.create_index(
|
||||||
|
"ix_issues_external_tracker_type", "issues", ["external_tracker_type"]
|
||||||
|
)
|
||||||
op.create_index("ix_issues_sync_status", "issues", ["sync_status"])
|
op.create_index("ix_issues_sync_status", "issues", ["sync_status"])
|
||||||
op.create_index("ix_issues_closed_at", "issues", ["closed_at"])
|
op.create_index("ix_issues_closed_at", "issues", ["closed_at"])
|
||||||
# Composite indexes
|
# Composite indexes
|
||||||
@@ -552,7 +470,9 @@ def upgrade() -> None:
|
|||||||
op.create_index("ix_issues_project_priority", "issues", ["project_id", "priority"])
|
op.create_index("ix_issues_project_priority", "issues", ["project_id", "priority"])
|
||||||
op.create_index("ix_issues_project_sprint", "issues", ["project_id", "sprint_id"])
|
op.create_index("ix_issues_project_sprint", "issues", ["project_id", "sprint_id"])
|
||||||
op.create_index("ix_issues_project_type", "issues", ["project_id", "type"])
|
op.create_index("ix_issues_project_type", "issues", ["project_id", "type"])
|
||||||
op.create_index("ix_issues_project_agent", "issues", ["project_id", "assigned_agent_id"])
|
op.create_index(
|
||||||
|
"ix_issues_project_agent", "issues", ["project_id", "assigned_agent_id"]
|
||||||
|
)
|
||||||
op.create_index(
|
op.create_index(
|
||||||
"ix_issues_project_status_priority",
|
"ix_issues_project_status_priority",
|
||||||
"issues",
|
"issues",
|
||||||
|
|||||||
512
backend/app/alembic/versions/0005_add_memory_system_tables.py
Normal file
512
backend/app/alembic/versions/0005_add_memory_system_tables.py
Normal file
@@ -0,0 +1,512 @@
|
|||||||
|
"""Add Agent Memory System tables
|
||||||
|
|
||||||
|
Revision ID: 0005
|
||||||
|
Revises: 0004
|
||||||
|
Create Date: 2025-01-05
|
||||||
|
|
||||||
|
This migration creates the Agent Memory System tables:
|
||||||
|
- working_memory: Key-value storage with TTL for active sessions
|
||||||
|
- episodes: Experiential memories from task executions
|
||||||
|
- facts: Semantic knowledge triples with confidence scores
|
||||||
|
- procedures: Learned skills and procedures
|
||||||
|
- memory_consolidation_log: Tracks consolidation jobs
|
||||||
|
|
||||||
|
See Issue #88: Database Schema & Storage Layer
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0005"
|
||||||
|
down_revision: str | None = "0004"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Create Agent Memory System tables."""
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Create ENUM types for memory system
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
# Scope type enum
|
||||||
|
scope_type_enum = postgresql.ENUM(
|
||||||
|
"global",
|
||||||
|
"project",
|
||||||
|
"agent_type",
|
||||||
|
"agent_instance",
|
||||||
|
"session",
|
||||||
|
name="scope_type",
|
||||||
|
create_type=False,
|
||||||
|
)
|
||||||
|
scope_type_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
|
||||||
|
# Episode outcome enum
|
||||||
|
episode_outcome_enum = postgresql.ENUM(
|
||||||
|
"success",
|
||||||
|
"failure",
|
||||||
|
"partial",
|
||||||
|
name="episode_outcome",
|
||||||
|
create_type=False,
|
||||||
|
)
|
||||||
|
episode_outcome_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
|
||||||
|
# Consolidation type enum
|
||||||
|
consolidation_type_enum = postgresql.ENUM(
|
||||||
|
"working_to_episodic",
|
||||||
|
"episodic_to_semantic",
|
||||||
|
"episodic_to_procedural",
|
||||||
|
"pruning",
|
||||||
|
name="consolidation_type",
|
||||||
|
create_type=False,
|
||||||
|
)
|
||||||
|
consolidation_type_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
|
||||||
|
# Consolidation status enum
|
||||||
|
consolidation_status_enum = postgresql.ENUM(
|
||||||
|
"pending",
|
||||||
|
"running",
|
||||||
|
"completed",
|
||||||
|
"failed",
|
||||||
|
name="consolidation_status",
|
||||||
|
create_type=False,
|
||||||
|
)
|
||||||
|
consolidation_status_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Create working_memory table
|
||||||
|
# Key-value storage with TTL for active sessions
|
||||||
|
# =========================================================================
|
||||||
|
op.create_table(
|
||||||
|
"working_memory",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"scope_type",
|
||||||
|
scope_type_enum,
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("scope_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("key", sa.String(255), nullable=False),
|
||||||
|
sa.Column("value", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Working memory indexes
|
||||||
|
op.create_index(
|
||||||
|
"ix_working_memory_scope_type",
|
||||||
|
"working_memory",
|
||||||
|
["scope_type"],
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_working_memory_scope_id",
|
||||||
|
"working_memory",
|
||||||
|
["scope_id"],
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_working_memory_scope_key",
|
||||||
|
"working_memory",
|
||||||
|
["scope_type", "scope_id", "key"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_working_memory_expires",
|
||||||
|
"working_memory",
|
||||||
|
["expires_at"],
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_working_memory_scope_list",
|
||||||
|
"working_memory",
|
||||||
|
["scope_type", "scope_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Create episodes table
|
||||||
|
# Experiential memories from task executions
|
||||||
|
# =========================================================================
|
||||||
|
op.create_table(
|
||||||
|
"episodes",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("agent_instance_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||||
|
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||||
|
sa.Column("session_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("task_type", sa.String(100), nullable=False),
|
||||||
|
sa.Column("task_description", sa.Text(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"actions",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="[]",
|
||||||
|
),
|
||||||
|
sa.Column("context_summary", sa.Text(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"outcome",
|
||||||
|
episode_outcome_enum,
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("outcome_details", sa.Text(), nullable=True),
|
||||||
|
sa.Column("duration_seconds", sa.Float(), nullable=False, server_default="0.0"),
|
||||||
|
sa.Column("tokens_used", sa.BigInteger(), nullable=False, server_default="0"),
|
||||||
|
sa.Column(
|
||||||
|
"lessons_learned",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="[]",
|
||||||
|
),
|
||||||
|
sa.Column("importance_score", sa.Float(), nullable=False, server_default="0.5"),
|
||||||
|
# Vector embedding - using TEXT as fallback, will be VECTOR(1536) when pgvector is available
|
||||||
|
sa.Column("embedding", sa.Text(), nullable=True),
|
||||||
|
sa.Column("occurred_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["project_id"],
|
||||||
|
["projects.id"],
|
||||||
|
name="fk_episodes_project",
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["agent_instance_id"],
|
||||||
|
["agent_instances.id"],
|
||||||
|
name="fk_episodes_agent_instance",
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["agent_type_id"],
|
||||||
|
["agent_types.id"],
|
||||||
|
name="fk_episodes_agent_type",
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Episode indexes
|
||||||
|
op.create_index("ix_episodes_project_id", "episodes", ["project_id"])
|
||||||
|
op.create_index("ix_episodes_agent_instance_id", "episodes", ["agent_instance_id"])
|
||||||
|
op.create_index("ix_episodes_agent_type_id", "episodes", ["agent_type_id"])
|
||||||
|
op.create_index("ix_episodes_session_id", "episodes", ["session_id"])
|
||||||
|
op.create_index("ix_episodes_task_type", "episodes", ["task_type"])
|
||||||
|
op.create_index("ix_episodes_outcome", "episodes", ["outcome"])
|
||||||
|
op.create_index("ix_episodes_importance_score", "episodes", ["importance_score"])
|
||||||
|
op.create_index("ix_episodes_occurred_at", "episodes", ["occurred_at"])
|
||||||
|
op.create_index("ix_episodes_project_task", "episodes", ["project_id", "task_type"])
|
||||||
|
op.create_index(
|
||||||
|
"ix_episodes_project_outcome", "episodes", ["project_id", "outcome"]
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_episodes_agent_task", "episodes", ["agent_instance_id", "task_type"]
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_episodes_project_time", "episodes", ["project_id", "occurred_at"]
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_episodes_importance_time",
|
||||||
|
"episodes",
|
||||||
|
["importance_score", "occurred_at"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Create facts table
|
||||||
|
# Semantic knowledge triples with confidence scores
|
||||||
|
# =========================================================================
|
||||||
|
op.create_table(
|
||||||
|
"facts",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"project_id", postgresql.UUID(as_uuid=True), nullable=True
|
||||||
|
), # NULL for global facts
|
||||||
|
sa.Column("subject", sa.String(500), nullable=False),
|
||||||
|
sa.Column("predicate", sa.String(255), nullable=False),
|
||||||
|
sa.Column("object", sa.Text(), nullable=False),
|
||||||
|
sa.Column("confidence", sa.Float(), nullable=False, server_default="0.8"),
|
||||||
|
# Source episode IDs stored as JSON array of UUID strings for cross-db compatibility
|
||||||
|
sa.Column(
|
||||||
|
"source_episode_ids",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="[]",
|
||||||
|
),
|
||||||
|
sa.Column("first_learned", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("last_reinforced", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"reinforcement_count", sa.Integer(), nullable=False, server_default="1"
|
||||||
|
),
|
||||||
|
# Vector embedding
|
||||||
|
sa.Column("embedding", sa.Text(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["project_id"],
|
||||||
|
["projects.id"],
|
||||||
|
name="fk_facts_project",
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fact indexes
|
||||||
|
op.create_index("ix_facts_project_id", "facts", ["project_id"])
|
||||||
|
op.create_index("ix_facts_subject", "facts", ["subject"])
|
||||||
|
op.create_index("ix_facts_predicate", "facts", ["predicate"])
|
||||||
|
op.create_index("ix_facts_confidence", "facts", ["confidence"])
|
||||||
|
op.create_index("ix_facts_subject_predicate", "facts", ["subject", "predicate"])
|
||||||
|
op.create_index("ix_facts_project_subject", "facts", ["project_id", "subject"])
|
||||||
|
op.create_index(
|
||||||
|
"ix_facts_confidence_time", "facts", ["confidence", "last_reinforced"]
|
||||||
|
)
|
||||||
|
# Unique constraint for triples within project scope
|
||||||
|
op.create_index(
|
||||||
|
"ix_facts_unique_triple",
|
||||||
|
"facts",
|
||||||
|
["project_id", "subject", "predicate", "object"],
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=sa.text("project_id IS NOT NULL"),
|
||||||
|
)
|
||||||
|
# Unique constraint for global facts (project_id IS NULL)
|
||||||
|
op.create_index(
|
||||||
|
"ix_facts_unique_triple_global",
|
||||||
|
"facts",
|
||||||
|
["subject", "predicate", "object"],
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=sa.text("project_id IS NULL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Create procedures table
|
||||||
|
# Learned skills and procedures
|
||||||
|
# =========================================================================
|
||||||
|
op.create_table(
|
||||||
|
"procedures",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||||
|
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("trigger_pattern", sa.Text(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"steps",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="[]",
|
||||||
|
),
|
||||||
|
sa.Column("success_count", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
sa.Column("failure_count", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
sa.Column("last_used", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
# Vector embedding
|
||||||
|
sa.Column("embedding", sa.Text(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["project_id"],
|
||||||
|
["projects.id"],
|
||||||
|
name="fk_procedures_project",
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["agent_type_id"],
|
||||||
|
["agent_types.id"],
|
||||||
|
name="fk_procedures_agent_type",
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Procedure indexes
|
||||||
|
op.create_index("ix_procedures_project_id", "procedures", ["project_id"])
|
||||||
|
op.create_index("ix_procedures_agent_type_id", "procedures", ["agent_type_id"])
|
||||||
|
op.create_index("ix_procedures_name", "procedures", ["name"])
|
||||||
|
op.create_index("ix_procedures_last_used", "procedures", ["last_used"])
|
||||||
|
op.create_index(
|
||||||
|
"ix_procedures_unique_name",
|
||||||
|
"procedures",
|
||||||
|
["project_id", "agent_type_id", "name"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index("ix_procedures_project_name", "procedures", ["project_id", "name"])
|
||||||
|
# Note: agent_type_id already indexed via ix_procedures_agent_type_id (line 354)
|
||||||
|
op.create_index(
|
||||||
|
"ix_procedures_success_rate",
|
||||||
|
"procedures",
|
||||||
|
["success_count", "failure_count"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Add check constraints for data integrity
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
# Episode constraints
|
||||||
|
op.create_check_constraint(
|
||||||
|
"ck_episodes_importance_range",
|
||||||
|
"episodes",
|
||||||
|
"importance_score >= 0.0 AND importance_score <= 1.0",
|
||||||
|
)
|
||||||
|
op.create_check_constraint(
|
||||||
|
"ck_episodes_duration_positive",
|
||||||
|
"episodes",
|
||||||
|
"duration_seconds >= 0.0",
|
||||||
|
)
|
||||||
|
op.create_check_constraint(
|
||||||
|
"ck_episodes_tokens_positive",
|
||||||
|
"episodes",
|
||||||
|
"tokens_used >= 0",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fact constraints
|
||||||
|
op.create_check_constraint(
|
||||||
|
"ck_facts_confidence_range",
|
||||||
|
"facts",
|
||||||
|
"confidence >= 0.0 AND confidence <= 1.0",
|
||||||
|
)
|
||||||
|
op.create_check_constraint(
|
||||||
|
"ck_facts_reinforcement_positive",
|
||||||
|
"facts",
|
||||||
|
"reinforcement_count >= 1",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Procedure constraints
|
||||||
|
op.create_check_constraint(
|
||||||
|
"ck_procedures_success_positive",
|
||||||
|
"procedures",
|
||||||
|
"success_count >= 0",
|
||||||
|
)
|
||||||
|
op.create_check_constraint(
|
||||||
|
"ck_procedures_failure_positive",
|
||||||
|
"procedures",
|
||||||
|
"failure_count >= 0",
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Create memory_consolidation_log table
|
||||||
|
# Tracks consolidation jobs
|
||||||
|
# =========================================================================
|
||||||
|
op.create_table(
|
||||||
|
"memory_consolidation_log",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"consolidation_type",
|
||||||
|
consolidation_type_enum,
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("source_count", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
sa.Column("result_count", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
sa.Column("started_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"status",
|
||||||
|
consolidation_status_enum,
|
||||||
|
nullable=False,
|
||||||
|
server_default="pending",
|
||||||
|
),
|
||||||
|
sa.Column("error", sa.Text(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Consolidation log indexes
|
||||||
|
op.create_index(
|
||||||
|
"ix_consolidation_type",
|
||||||
|
"memory_consolidation_log",
|
||||||
|
["consolidation_type"],
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_consolidation_status",
|
||||||
|
"memory_consolidation_log",
|
||||||
|
["status"],
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_consolidation_type_status",
|
||||||
|
"memory_consolidation_log",
|
||||||
|
["consolidation_type", "status"],
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_consolidation_started",
|
||||||
|
"memory_consolidation_log",
|
||||||
|
["started_at"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Drop Agent Memory System tables."""
|
||||||
|
|
||||||
|
# Drop check constraints first
|
||||||
|
op.drop_constraint("ck_procedures_failure_positive", "procedures", type_="check")
|
||||||
|
op.drop_constraint("ck_procedures_success_positive", "procedures", type_="check")
|
||||||
|
op.drop_constraint("ck_facts_reinforcement_positive", "facts", type_="check")
|
||||||
|
op.drop_constraint("ck_facts_confidence_range", "facts", type_="check")
|
||||||
|
op.drop_constraint("ck_episodes_tokens_positive", "episodes", type_="check")
|
||||||
|
op.drop_constraint("ck_episodes_duration_positive", "episodes", type_="check")
|
||||||
|
op.drop_constraint("ck_episodes_importance_range", "episodes", type_="check")
|
||||||
|
|
||||||
|
# Drop unique indexes for global facts
|
||||||
|
op.drop_index("ix_facts_unique_triple_global", "facts")
|
||||||
|
|
||||||
|
# Drop tables in reverse order (dependencies first)
|
||||||
|
op.drop_table("memory_consolidation_log")
|
||||||
|
op.drop_table("procedures")
|
||||||
|
op.drop_table("facts")
|
||||||
|
op.drop_table("episodes")
|
||||||
|
op.drop_table("working_memory")
|
||||||
|
|
||||||
|
# Drop ENUM types
|
||||||
|
op.execute("DROP TYPE IF EXISTS consolidation_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS consolidation_type")
|
||||||
|
op.execute("DROP TYPE IF EXISTS episode_outcome")
|
||||||
|
op.execute("DROP TYPE IF EXISTS scope_type")
|
||||||
52
backend/app/alembic/versions/0006_add_abandoned_outcome.py
Normal file
52
backend/app/alembic/versions/0006_add_abandoned_outcome.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Add ABANDONED to episode_outcome enum
|
||||||
|
|
||||||
|
Revision ID: 0006
|
||||||
|
Revises: 0005
|
||||||
|
Create Date: 2025-01-06
|
||||||
|
|
||||||
|
This migration adds the 'abandoned' value to the episode_outcome enum type.
|
||||||
|
This allows episodes to track when a task was abandoned (not completed,
|
||||||
|
but not necessarily a failure either - e.g., user cancelled, session timeout).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0006"
|
||||||
|
down_revision: str | None = "0005"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Add 'abandoned' value to episode_outcome enum."""
|
||||||
|
# PostgreSQL ALTER TYPE ADD VALUE is safe and non-blocking
|
||||||
|
op.execute("ALTER TYPE episode_outcome ADD VALUE IF NOT EXISTS 'abandoned'")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Remove 'abandoned' from episode_outcome enum.
|
||||||
|
|
||||||
|
Note: PostgreSQL doesn't support removing values from enums directly.
|
||||||
|
This downgrade converts any 'abandoned' episodes to 'failure' and
|
||||||
|
recreates the enum without 'abandoned'.
|
||||||
|
"""
|
||||||
|
# Convert any abandoned episodes to failure first
|
||||||
|
op.execute("""
|
||||||
|
UPDATE episodes
|
||||||
|
SET outcome = 'failure'
|
||||||
|
WHERE outcome = 'abandoned'
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Recreate the enum without abandoned
|
||||||
|
# This is complex in PostgreSQL - requires creating new type, updating columns, dropping old
|
||||||
|
op.execute("ALTER TYPE episode_outcome RENAME TO episode_outcome_old")
|
||||||
|
op.execute("CREATE TYPE episode_outcome AS ENUM ('success', 'failure', 'partial')")
|
||||||
|
op.execute("""
|
||||||
|
ALTER TABLE episodes
|
||||||
|
ALTER COLUMN outcome TYPE episode_outcome
|
||||||
|
USING outcome::text::episode_outcome
|
||||||
|
""")
|
||||||
|
op.execute("DROP TYPE episode_outcome_old")
|
||||||
@@ -0,0 +1,90 @@
|
|||||||
|
"""Add category and display fields to agent_types table
|
||||||
|
|
||||||
|
Revision ID: 0007
|
||||||
|
Revises: 0006
|
||||||
|
Create Date: 2026-01-06
|
||||||
|
|
||||||
|
This migration adds:
|
||||||
|
- category: String(50) for grouping agents by role type
|
||||||
|
- icon: String(50) for Lucide icon identifier
|
||||||
|
- color: String(7) for hex color code
|
||||||
|
- sort_order: Integer for display ordering within categories
|
||||||
|
- typical_tasks: JSONB list of tasks this agent excels at
|
||||||
|
- collaboration_hints: JSONB list of agent slugs that work well together
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0007"
|
||||||
|
down_revision: str | None = "0006"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Add category and display fields to agent_types table."""
|
||||||
|
# Add new columns
|
||||||
|
op.add_column(
|
||||||
|
"agent_types",
|
||||||
|
sa.Column("category", sa.String(length=50), nullable=True),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"agent_types",
|
||||||
|
sa.Column("icon", sa.String(length=50), nullable=True, server_default="bot"),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"agent_types",
|
||||||
|
sa.Column(
|
||||||
|
"color", sa.String(length=7), nullable=True, server_default="#3B82F6"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"agent_types",
|
||||||
|
sa.Column("sort_order", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"agent_types",
|
||||||
|
sa.Column(
|
||||||
|
"typical_tasks",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="[]",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"agent_types",
|
||||||
|
sa.Column(
|
||||||
|
"collaboration_hints",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=False,
|
||||||
|
server_default="[]",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add indexes for category and sort_order
|
||||||
|
op.create_index("ix_agent_types_category", "agent_types", ["category"])
|
||||||
|
op.create_index("ix_agent_types_sort_order", "agent_types", ["sort_order"])
|
||||||
|
op.create_index(
|
||||||
|
"ix_agent_types_category_sort", "agent_types", ["category", "sort_order"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Remove category and display fields from agent_types table."""
|
||||||
|
# Drop indexes
|
||||||
|
op.drop_index("ix_agent_types_category_sort", table_name="agent_types")
|
||||||
|
op.drop_index("ix_agent_types_sort_order", table_name="agent_types")
|
||||||
|
op.drop_index("ix_agent_types_category", table_name="agent_types")
|
||||||
|
|
||||||
|
# Drop columns
|
||||||
|
op.drop_column("agent_types", "collaboration_hints")
|
||||||
|
op.drop_column("agent_types", "typical_tasks")
|
||||||
|
op.drop_column("agent_types", "sort_order")
|
||||||
|
op.drop_column("agent_types", "color")
|
||||||
|
op.drop_column("agent_types", "icon")
|
||||||
|
op.drop_column("agent_types", "category")
|
||||||
@@ -5,8 +5,10 @@ from app.api.routes import (
|
|||||||
agent_types,
|
agent_types,
|
||||||
agents,
|
agents,
|
||||||
auth,
|
auth,
|
||||||
|
context,
|
||||||
events,
|
events,
|
||||||
issues,
|
issues,
|
||||||
|
mcp,
|
||||||
oauth,
|
oauth,
|
||||||
oauth_provider,
|
oauth_provider,
|
||||||
organizations,
|
organizations,
|
||||||
@@ -31,10 +33,14 @@ api_router.include_router(
|
|||||||
# SSE events router - no prefix, routes define full paths
|
# SSE events router - no prefix, routes define full paths
|
||||||
api_router.include_router(events.router, tags=["Events"])
|
api_router.include_router(events.router, tags=["Events"])
|
||||||
|
|
||||||
|
# MCP (Model Context Protocol) router
|
||||||
|
api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"])
|
||||||
|
|
||||||
|
# Context Management Engine router
|
||||||
|
api_router.include_router(context.router, prefix="/context", tags=["Context"])
|
||||||
|
|
||||||
# Syndarix domain routers
|
# Syndarix domain routers
|
||||||
api_router.include_router(
|
api_router.include_router(projects.router, prefix="/projects", tags=["Projects"])
|
||||||
projects.router, prefix="/projects", tags=["Projects"]
|
|
||||||
)
|
|
||||||
api_router.include_router(
|
api_router.include_router(
|
||||||
agent_types.router, prefix="/agent-types", tags=["Agent Types"]
|
agent_types.router, prefix="/agent-types", tags=["Agent Types"]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -81,6 +81,13 @@ def _build_agent_type_response(
|
|||||||
mcp_servers=agent_type.mcp_servers,
|
mcp_servers=agent_type.mcp_servers,
|
||||||
tool_permissions=agent_type.tool_permissions,
|
tool_permissions=agent_type.tool_permissions,
|
||||||
is_active=agent_type.is_active,
|
is_active=agent_type.is_active,
|
||||||
|
# Category and display fields
|
||||||
|
category=agent_type.category,
|
||||||
|
icon=agent_type.icon,
|
||||||
|
color=agent_type.color,
|
||||||
|
sort_order=agent_type.sort_order,
|
||||||
|
typical_tasks=agent_type.typical_tasks or [],
|
||||||
|
collaboration_hints=agent_type.collaboration_hints or [],
|
||||||
created_at=agent_type.created_at,
|
created_at=agent_type.created_at,
|
||||||
updated_at=agent_type.updated_at,
|
updated_at=agent_type.updated_at,
|
||||||
instance_count=instance_count,
|
instance_count=instance_count,
|
||||||
@@ -300,6 +307,7 @@ async def list_agent_types(
|
|||||||
request: Request,
|
request: Request,
|
||||||
pagination: PaginationParams = Depends(),
|
pagination: PaginationParams = Depends(),
|
||||||
is_active: bool = Query(True, description="Filter by active status"),
|
is_active: bool = Query(True, description="Filter by active status"),
|
||||||
|
category: str | None = Query(None, description="Filter by category"),
|
||||||
search: str | None = Query(None, description="Search by name, slug, description"),
|
search: str | None = Query(None, description="Search by name, slug, description"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
@@ -314,6 +322,7 @@ async def list_agent_types(
|
|||||||
request: FastAPI request object
|
request: FastAPI request object
|
||||||
pagination: Pagination parameters (page, limit)
|
pagination: Pagination parameters (page, limit)
|
||||||
is_active: Filter by active status (default: True)
|
is_active: Filter by active status (default: True)
|
||||||
|
category: Filter by category (e.g., "development", "design")
|
||||||
search: Optional search term for name, slug, description
|
search: Optional search term for name, slug, description
|
||||||
current_user: Authenticated user
|
current_user: Authenticated user
|
||||||
db: Database session
|
db: Database session
|
||||||
@@ -328,6 +337,7 @@ async def list_agent_types(
|
|||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
limit=pagination.limit,
|
limit=pagination.limit,
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
|
category=category,
|
||||||
search=search,
|
search=search,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -354,6 +364,51 @@ async def list_agent_types(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/grouped",
|
||||||
|
response_model=dict[str, list[AgentTypeResponse]],
|
||||||
|
summary="List Agent Types Grouped by Category",
|
||||||
|
description="Get all agent types organized by category",
|
||||||
|
operation_id="list_agent_types_grouped",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def list_agent_types_grouped(
|
||||||
|
request: Request,
|
||||||
|
is_active: bool = Query(True, description="Filter by active status"),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get agent types grouped by category.
|
||||||
|
|
||||||
|
Returns a dictionary where keys are category names and values
|
||||||
|
are lists of agent types, sorted by sort_order within each category.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
is_active: Filter by active status (default: True)
|
||||||
|
current_user: Authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping category to list of agent types
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
grouped = await agent_type_crud.get_grouped_by_category(db, is_active=is_active)
|
||||||
|
|
||||||
|
# Transform to response objects
|
||||||
|
result: dict[str, list[AgentTypeResponse]] = {}
|
||||||
|
for category, types in grouped.items():
|
||||||
|
result[category] = [
|
||||||
|
_build_agent_type_response(t, instance_count=0) for t in types
|
||||||
|
]
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting grouped agent types: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/{agent_type_id}",
|
"/{agent_type_id}",
|
||||||
response_model=AgentTypeResponse,
|
response_model=AgentTypeResponse,
|
||||||
|
|||||||
@@ -57,8 +57,18 @@ RATE_MULTIPLIER = 100 if IS_TEST else 1
|
|||||||
# Valid status transitions for agent lifecycle management
|
# Valid status transitions for agent lifecycle management
|
||||||
VALID_STATUS_TRANSITIONS: dict[AgentStatus, set[AgentStatus]] = {
|
VALID_STATUS_TRANSITIONS: dict[AgentStatus, set[AgentStatus]] = {
|
||||||
AgentStatus.IDLE: {AgentStatus.WORKING, AgentStatus.PAUSED, AgentStatus.TERMINATED},
|
AgentStatus.IDLE: {AgentStatus.WORKING, AgentStatus.PAUSED, AgentStatus.TERMINATED},
|
||||||
AgentStatus.WORKING: {AgentStatus.IDLE, AgentStatus.WAITING, AgentStatus.PAUSED, AgentStatus.TERMINATED},
|
AgentStatus.WORKING: {
|
||||||
AgentStatus.WAITING: {AgentStatus.IDLE, AgentStatus.WORKING, AgentStatus.PAUSED, AgentStatus.TERMINATED},
|
AgentStatus.IDLE,
|
||||||
|
AgentStatus.WAITING,
|
||||||
|
AgentStatus.PAUSED,
|
||||||
|
AgentStatus.TERMINATED,
|
||||||
|
},
|
||||||
|
AgentStatus.WAITING: {
|
||||||
|
AgentStatus.IDLE,
|
||||||
|
AgentStatus.WORKING,
|
||||||
|
AgentStatus.PAUSED,
|
||||||
|
AgentStatus.TERMINATED,
|
||||||
|
},
|
||||||
AgentStatus.PAUSED: {AgentStatus.IDLE, AgentStatus.TERMINATED},
|
AgentStatus.PAUSED: {AgentStatus.IDLE, AgentStatus.TERMINATED},
|
||||||
AgentStatus.TERMINATED: set(), # Terminal state, no transitions allowed
|
AgentStatus.TERMINATED: set(), # Terminal state, no transitions allowed
|
||||||
}
|
}
|
||||||
@@ -870,9 +880,7 @@ async def terminate_agent(
|
|||||||
agent_name = agent.name
|
agent_name = agent.name
|
||||||
|
|
||||||
# Terminate the agent
|
# Terminate the agent
|
||||||
terminated_agent = await agent_instance_crud.terminate(
|
terminated_agent = await agent_instance_crud.terminate(db, instance_id=agent_id)
|
||||||
db, instance_id=agent_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not terminated_agent:
|
if not terminated_agent:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
@@ -881,8 +889,7 @@ async def terminate_agent(
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.email} terminated agent {agent_name} "
|
f"User {current_user.email} terminated agent {agent_name} (id={agent_id})"
|
||||||
f"(id={agent_id})"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
|
|||||||
411
backend/app/api/routes/context.py
Normal file
411
backend/app/api/routes/context.py
Normal file
@@ -0,0 +1,411 @@
|
|||||||
|
"""
|
||||||
|
Context Management API Endpoints.
|
||||||
|
|
||||||
|
Provides REST endpoints for context assembly and optimization
|
||||||
|
for LLM requests using the ContextEngine.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.api.dependencies.permissions import require_superuser
|
||||||
|
from app.models.user import User
|
||||||
|
from app.services.context import (
|
||||||
|
AssemblyTimeoutError,
|
||||||
|
BudgetExceededError,
|
||||||
|
ContextEngine,
|
||||||
|
ContextSettings,
|
||||||
|
create_context_engine,
|
||||||
|
get_context_settings,
|
||||||
|
)
|
||||||
|
from app.services.mcp import MCPClientManager, get_mcp_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Singleton Engine Management
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
_context_engine: ContextEngine | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_or_create_engine(
|
||||||
|
mcp: MCPClientManager,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
) -> ContextEngine:
|
||||||
|
"""Get or create the singleton ContextEngine."""
|
||||||
|
global _context_engine
|
||||||
|
if _context_engine is None:
|
||||||
|
_context_engine = create_context_engine(
|
||||||
|
mcp_manager=mcp,
|
||||||
|
redis=None, # Optional: add Redis caching later
|
||||||
|
settings=settings or get_context_settings(),
|
||||||
|
)
|
||||||
|
logger.info("ContextEngine initialized")
|
||||||
|
else:
|
||||||
|
# Ensure MCP manager is up to date
|
||||||
|
_context_engine.set_mcp_manager(mcp)
|
||||||
|
return _context_engine
|
||||||
|
|
||||||
|
|
||||||
|
async def get_context_engine(
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> ContextEngine:
|
||||||
|
"""FastAPI dependency to get the ContextEngine."""
|
||||||
|
return _get_or_create_engine(mcp)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Request/Response Schemas
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationTurn(BaseModel):
|
||||||
|
"""A single conversation turn."""
|
||||||
|
|
||||||
|
role: str = Field(..., description="Role: 'user' or 'assistant'")
|
||||||
|
content: str = Field(..., description="Message content")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolResult(BaseModel):
|
||||||
|
"""A tool execution result."""
|
||||||
|
|
||||||
|
tool_name: str = Field(..., description="Name of the tool")
|
||||||
|
content: str | dict[str, Any] = Field(..., description="Tool result content")
|
||||||
|
status: str = Field(default="success", description="Execution status")
|
||||||
|
|
||||||
|
|
||||||
|
class AssembleContextRequest(BaseModel):
|
||||||
|
"""Request to assemble context for an LLM request."""
|
||||||
|
|
||||||
|
project_id: str = Field(..., description="Project identifier")
|
||||||
|
agent_id: str = Field(..., description="Agent identifier")
|
||||||
|
query: str = Field(..., description="User's query or current request")
|
||||||
|
model: str = Field(
|
||||||
|
default="claude-3-sonnet",
|
||||||
|
description="Target model name",
|
||||||
|
)
|
||||||
|
max_tokens: int | None = Field(
|
||||||
|
None,
|
||||||
|
description="Maximum context tokens (uses model default if None)",
|
||||||
|
)
|
||||||
|
system_prompt: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="System prompt/instructions",
|
||||||
|
)
|
||||||
|
task_description: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="Current task description",
|
||||||
|
)
|
||||||
|
knowledge_query: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="Query for knowledge base search",
|
||||||
|
)
|
||||||
|
knowledge_limit: int = Field(
|
||||||
|
default=10,
|
||||||
|
ge=1,
|
||||||
|
le=50,
|
||||||
|
description="Max number of knowledge results",
|
||||||
|
)
|
||||||
|
conversation_history: list[ConversationTurn] | None = Field(
|
||||||
|
None,
|
||||||
|
description="Previous conversation turns",
|
||||||
|
)
|
||||||
|
tool_results: list[ToolResult] | None = Field(
|
||||||
|
None,
|
||||||
|
description="Tool execution results to include",
|
||||||
|
)
|
||||||
|
compress: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Whether to apply compression",
|
||||||
|
)
|
||||||
|
use_cache: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Whether to use caching",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AssembledContextResponse(BaseModel):
|
||||||
|
"""Response containing assembled context."""
|
||||||
|
|
||||||
|
content: str = Field(..., description="Assembled context content")
|
||||||
|
total_tokens: int = Field(..., description="Total token count")
|
||||||
|
context_count: int = Field(..., description="Number of context items included")
|
||||||
|
compressed: bool = Field(..., description="Whether compression was applied")
|
||||||
|
budget_used_percent: float = Field(
|
||||||
|
...,
|
||||||
|
description="Percentage of token budget used",
|
||||||
|
)
|
||||||
|
metadata: dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Additional metadata",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCountRequest(BaseModel):
|
||||||
|
"""Request to count tokens in content."""
|
||||||
|
|
||||||
|
content: str = Field(..., description="Content to count tokens in")
|
||||||
|
model: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="Model for model-specific tokenization",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCountResponse(BaseModel):
|
||||||
|
"""Response containing token count."""
|
||||||
|
|
||||||
|
token_count: int = Field(..., description="Number of tokens")
|
||||||
|
model: str | None = Field(None, description="Model used for counting")
|
||||||
|
|
||||||
|
|
||||||
|
class BudgetInfoResponse(BaseModel):
|
||||||
|
"""Response containing budget information for a model."""
|
||||||
|
|
||||||
|
model: str = Field(..., description="Model name")
|
||||||
|
total_tokens: int = Field(..., description="Total token budget")
|
||||||
|
system_tokens: int = Field(..., description="Tokens reserved for system")
|
||||||
|
knowledge_tokens: int = Field(..., description="Tokens for knowledge")
|
||||||
|
conversation_tokens: int = Field(..., description="Tokens for conversation")
|
||||||
|
tool_tokens: int = Field(..., description="Tokens for tool results")
|
||||||
|
response_reserve: int = Field(..., description="Tokens reserved for response")
|
||||||
|
|
||||||
|
|
||||||
|
class ContextEngineStatsResponse(BaseModel):
|
||||||
|
"""Response containing engine statistics."""
|
||||||
|
|
||||||
|
cache: dict[str, Any] = Field(..., description="Cache statistics")
|
||||||
|
settings: dict[str, Any] = Field(..., description="Current settings")
|
||||||
|
|
||||||
|
|
||||||
|
class HealthResponse(BaseModel):
|
||||||
|
"""Health check response."""
|
||||||
|
|
||||||
|
status: str = Field(..., description="Health status")
|
||||||
|
mcp_connected: bool = Field(..., description="Whether MCP is connected")
|
||||||
|
cache_enabled: bool = Field(..., description="Whether caching is enabled")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Endpoints
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/health",
|
||||||
|
response_model=HealthResponse,
|
||||||
|
summary="Context Engine Health",
|
||||||
|
description="Check health status of the context engine.",
|
||||||
|
)
|
||||||
|
async def health_check(
|
||||||
|
engine: ContextEngine = Depends(get_context_engine),
|
||||||
|
) -> HealthResponse:
|
||||||
|
"""Check context engine health."""
|
||||||
|
stats = await engine.get_stats()
|
||||||
|
return HealthResponse(
|
||||||
|
status="healthy",
|
||||||
|
mcp_connected=engine._mcp is not None,
|
||||||
|
cache_enabled=stats.get("settings", {}).get("cache_enabled", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/assemble",
|
||||||
|
response_model=AssembledContextResponse,
|
||||||
|
summary="Assemble Context",
|
||||||
|
description="Assemble optimized context for an LLM request.",
|
||||||
|
)
|
||||||
|
async def assemble_context(
|
||||||
|
request: AssembleContextRequest,
|
||||||
|
current_user: User = Depends(require_superuser),
|
||||||
|
engine: ContextEngine = Depends(get_context_engine),
|
||||||
|
) -> AssembledContextResponse:
|
||||||
|
"""
|
||||||
|
Assemble optimized context for an LLM request.
|
||||||
|
|
||||||
|
This endpoint gathers context from various sources, scores and ranks them,
|
||||||
|
compresses if needed, and formats for the target model.
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"Context assembly for project=%s agent=%s by user=%s",
|
||||||
|
request.project_id,
|
||||||
|
request.agent_id,
|
||||||
|
current_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert conversation history to dict format
|
||||||
|
conversation_history = None
|
||||||
|
if request.conversation_history:
|
||||||
|
conversation_history = [
|
||||||
|
{"role": turn.role, "content": turn.content}
|
||||||
|
for turn in request.conversation_history
|
||||||
|
]
|
||||||
|
|
||||||
|
# Convert tool results to dict format
|
||||||
|
tool_results = None
|
||||||
|
if request.tool_results:
|
||||||
|
tool_results = [
|
||||||
|
{
|
||||||
|
"tool_name": tr.tool_name,
|
||||||
|
"content": tr.content,
|
||||||
|
"status": tr.status,
|
||||||
|
}
|
||||||
|
for tr in request.tool_results
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await engine.assemble_context(
|
||||||
|
project_id=request.project_id,
|
||||||
|
agent_id=request.agent_id,
|
||||||
|
query=request.query,
|
||||||
|
model=request.model,
|
||||||
|
max_tokens=request.max_tokens,
|
||||||
|
system_prompt=request.system_prompt,
|
||||||
|
task_description=request.task_description,
|
||||||
|
knowledge_query=request.knowledge_query,
|
||||||
|
knowledge_limit=request.knowledge_limit,
|
||||||
|
conversation_history=conversation_history,
|
||||||
|
tool_results=tool_results,
|
||||||
|
compress=request.compress,
|
||||||
|
use_cache=request.use_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate budget usage percentage
|
||||||
|
budget = await engine.get_budget_for_model(request.model, request.max_tokens)
|
||||||
|
budget_used_percent = (result.total_tokens / budget.total) * 100
|
||||||
|
|
||||||
|
# Check if compression was applied (from metadata if available)
|
||||||
|
was_compressed = result.metadata.get("compressed_contexts", 0) > 0
|
||||||
|
|
||||||
|
return AssembledContextResponse(
|
||||||
|
content=result.content,
|
||||||
|
total_tokens=result.total_tokens,
|
||||||
|
context_count=result.context_count,
|
||||||
|
compressed=was_compressed,
|
||||||
|
budget_used_percent=round(budget_used_percent, 2),
|
||||||
|
metadata={
|
||||||
|
"model": request.model,
|
||||||
|
"query": request.query,
|
||||||
|
"knowledge_included": bool(request.knowledge_query),
|
||||||
|
"conversation_turns": len(request.conversation_history or []),
|
||||||
|
"excluded_count": result.excluded_count,
|
||||||
|
"assembly_time_ms": result.assembly_time_ms,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except AssemblyTimeoutError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||||
|
detail=f"Context assembly timed out: {e}",
|
||||||
|
) from e
|
||||||
|
except BudgetExceededError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||||
|
detail=f"Token budget exceeded: {e}",
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Context assembly failed")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Context assembly failed: {e}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/count-tokens",
|
||||||
|
response_model=TokenCountResponse,
|
||||||
|
summary="Count Tokens",
|
||||||
|
description="Count tokens in content using the LLM Gateway.",
|
||||||
|
)
|
||||||
|
async def count_tokens(
|
||||||
|
request: TokenCountRequest,
|
||||||
|
engine: ContextEngine = Depends(get_context_engine),
|
||||||
|
) -> TokenCountResponse:
|
||||||
|
"""Count tokens in content."""
|
||||||
|
try:
|
||||||
|
count = await engine.count_tokens(
|
||||||
|
content=request.content,
|
||||||
|
model=request.model,
|
||||||
|
)
|
||||||
|
return TokenCountResponse(
|
||||||
|
token_count=count,
|
||||||
|
model=request.model,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Token counting failed: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Token counting failed: {e}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/budget/{model}",
|
||||||
|
response_model=BudgetInfoResponse,
|
||||||
|
summary="Get Token Budget",
|
||||||
|
description="Get token budget allocation for a specific model.",
|
||||||
|
)
|
||||||
|
async def get_budget(
|
||||||
|
model: str,
|
||||||
|
max_tokens: Annotated[int | None, Query(description="Custom max tokens")] = None,
|
||||||
|
engine: ContextEngine = Depends(get_context_engine),
|
||||||
|
) -> BudgetInfoResponse:
|
||||||
|
"""Get token budget information for a model."""
|
||||||
|
budget = await engine.get_budget_for_model(model, max_tokens)
|
||||||
|
return BudgetInfoResponse(
|
||||||
|
model=model,
|
||||||
|
total_tokens=budget.total,
|
||||||
|
system_tokens=budget.system,
|
||||||
|
knowledge_tokens=budget.knowledge,
|
||||||
|
conversation_tokens=budget.conversation,
|
||||||
|
tool_tokens=budget.tools,
|
||||||
|
response_reserve=budget.response_reserve,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/stats",
|
||||||
|
response_model=ContextEngineStatsResponse,
|
||||||
|
summary="Engine Statistics",
|
||||||
|
description="Get context engine statistics and configuration.",
|
||||||
|
)
|
||||||
|
async def get_stats(
|
||||||
|
current_user: User = Depends(require_superuser),
|
||||||
|
engine: ContextEngine = Depends(get_context_engine),
|
||||||
|
) -> ContextEngineStatsResponse:
|
||||||
|
"""Get engine statistics."""
|
||||||
|
stats = await engine.get_stats()
|
||||||
|
return ContextEngineStatsResponse(
|
||||||
|
cache=stats.get("cache", {}),
|
||||||
|
settings=stats.get("settings", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/cache/invalidate",
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
summary="Invalidate Cache (Admin Only)",
|
||||||
|
description="Invalidate context cache entries.",
|
||||||
|
)
|
||||||
|
async def invalidate_cache(
|
||||||
|
project_id: Annotated[
|
||||||
|
str | None, Query(description="Project to invalidate")
|
||||||
|
] = None,
|
||||||
|
pattern: Annotated[str | None, Query(description="Pattern to match")] = None,
|
||||||
|
current_user: User = Depends(require_superuser),
|
||||||
|
engine: ContextEngine = Depends(get_context_engine),
|
||||||
|
) -> None:
|
||||||
|
"""Invalidate cache entries."""
|
||||||
|
logger.info(
|
||||||
|
"Cache invalidation by user %s: project=%s pattern=%s",
|
||||||
|
current_user.id,
|
||||||
|
project_id,
|
||||||
|
pattern,
|
||||||
|
)
|
||||||
|
await engine.invalidate_cache(project_id=project_id, pattern=pattern)
|
||||||
@@ -199,7 +199,9 @@ async def stream_project_events(
|
|||||||
project_id: UUID,
|
project_id: UUID,
|
||||||
db: "AsyncSession" = Depends(get_db),
|
db: "AsyncSession" = Depends(get_db),
|
||||||
event_bus: EventBus = Depends(get_event_bus),
|
event_bus: EventBus = Depends(get_event_bus),
|
||||||
token: str | None = Query(None, description="Auth token (for EventSource compatibility)"),
|
token: str | None = Query(
|
||||||
|
None, description="Auth token (for EventSource compatibility)"
|
||||||
|
),
|
||||||
authorization: str | None = Header(None, alias="Authorization"),
|
authorization: str | None = Header(None, alias="Authorization"),
|
||||||
last_event_id: str | None = Header(None, alias="Last-Event-ID"),
|
last_event_id: str | None = Header(None, alias="Last-Event-ID"),
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -278,9 +278,7 @@ async def list_issues(
|
|||||||
assigned_agent_id: UUID | None = Query(
|
assigned_agent_id: UUID | None = Query(
|
||||||
None, description="Filter by assigned agent ID"
|
None, description="Filter by assigned agent ID"
|
||||||
),
|
),
|
||||||
sync_status: SyncStatus | None = Query(
|
sync_status: SyncStatus | None = Query(None, description="Filter by sync status"),
|
||||||
None, description="Filter by sync status"
|
|
||||||
),
|
|
||||||
search: str | None = Query(
|
search: str | None = Query(
|
||||||
None, min_length=1, max_length=100, description="Search in title and body"
|
None, min_length=1, max_length=100, description="Search in title and body"
|
||||||
),
|
),
|
||||||
@@ -783,9 +781,7 @@ async def assign_issue(
|
|||||||
updated_issue = await issue_crud.assign_to_agent(
|
updated_issue = await issue_crud.assign_to_agent(
|
||||||
db, issue_id=issue_id, agent_id=None
|
db, issue_id=issue_id, agent_id=None
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(f"User {current_user.email} unassigned issue {issue_id}")
|
||||||
f"User {current_user.email} unassigned issue {issue_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not updated_issue:
|
if not updated_issue:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
|
|||||||
446
backend/app/api/routes/mcp.py
Normal file
446
backend/app/api/routes/mcp.py
Normal file
@@ -0,0 +1,446 @@
|
|||||||
|
"""
|
||||||
|
MCP (Model Context Protocol) API Endpoints
|
||||||
|
|
||||||
|
Provides REST endpoints for managing MCP server connections
|
||||||
|
and executing tool calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Path, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.api.dependencies.permissions import require_superuser
|
||||||
|
from app.models.user import User
|
||||||
|
from app.services.mcp import (
|
||||||
|
MCPCircuitOpenError,
|
||||||
|
MCPClientManager,
|
||||||
|
MCPConnectionError,
|
||||||
|
MCPError,
|
||||||
|
MCPServerNotFoundError,
|
||||||
|
MCPTimeoutError,
|
||||||
|
MCPToolError,
|
||||||
|
MCPToolNotFoundError,
|
||||||
|
get_mcp_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# Server name validation pattern: alphanumeric, hyphens, underscores, 1-64 chars
|
||||||
|
SERVER_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
|
||||||
|
|
||||||
|
# Type alias for validated server name path parameter
|
||||||
|
ServerNamePath = Annotated[
|
||||||
|
str,
|
||||||
|
Path(
|
||||||
|
description="MCP server name",
|
||||||
|
min_length=1,
|
||||||
|
max_length=64,
|
||||||
|
pattern=r"^[a-zA-Z0-9_-]+$",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Request/Response Schemas
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ServerInfo(BaseModel):
|
||||||
|
"""Information about an MCP server."""
|
||||||
|
|
||||||
|
name: str = Field(..., description="Server name")
|
||||||
|
url: str = Field(..., description="Server URL")
|
||||||
|
enabled: bool = Field(..., description="Whether server is enabled")
|
||||||
|
timeout: int = Field(..., description="Request timeout in seconds")
|
||||||
|
transport: str = Field(..., description="Transport type (http, stdio, sse)")
|
||||||
|
description: str | None = Field(None, description="Server description")
|
||||||
|
|
||||||
|
|
||||||
|
class ServerListResponse(BaseModel):
|
||||||
|
"""Response containing list of MCP servers."""
|
||||||
|
|
||||||
|
servers: list[ServerInfo]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
class ToolInfoResponse(BaseModel):
|
||||||
|
"""Information about an MCP tool."""
|
||||||
|
|
||||||
|
name: str = Field(..., description="Tool name")
|
||||||
|
description: str | None = Field(None, description="Tool description")
|
||||||
|
server_name: str | None = Field(None, description="Server providing the tool")
|
||||||
|
input_schema: dict[str, Any] | None = Field(
|
||||||
|
None, description="JSON schema for input"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolListResponse(BaseModel):
|
||||||
|
"""Response containing list of tools."""
|
||||||
|
|
||||||
|
tools: list[ToolInfoResponse]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
class ServerHealthStatus(BaseModel):
|
||||||
|
"""Health status for a server."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
healthy: bool
|
||||||
|
state: str
|
||||||
|
url: str
|
||||||
|
error: str | None = None
|
||||||
|
tools_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class HealthCheckResponse(BaseModel):
|
||||||
|
"""Response containing health status of all servers."""
|
||||||
|
|
||||||
|
servers: dict[str, ServerHealthStatus]
|
||||||
|
healthy_count: int
|
||||||
|
unhealthy_count: int
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallRequest(BaseModel):
|
||||||
|
"""Request to execute a tool."""
|
||||||
|
|
||||||
|
server: str = Field(..., description="MCP server name")
|
||||||
|
tool: str = Field(..., description="Tool name to execute")
|
||||||
|
arguments: dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Tool arguments",
|
||||||
|
)
|
||||||
|
timeout: float | None = Field(
|
||||||
|
None,
|
||||||
|
description="Optional timeout override in seconds",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallResponse(BaseModel):
|
||||||
|
"""Response from tool execution."""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
data: Any | None = None
|
||||||
|
error: str | None = None
|
||||||
|
error_code: str | None = None
|
||||||
|
tool_name: str | None = None
|
||||||
|
server_name: str | None = None
|
||||||
|
execution_time_ms: float = 0.0
|
||||||
|
request_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitBreakerStatus(BaseModel):
|
||||||
|
"""Status of a circuit breaker."""
|
||||||
|
|
||||||
|
server_name: str
|
||||||
|
state: str
|
||||||
|
failure_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitBreakerListResponse(BaseModel):
|
||||||
|
"""Response containing circuit breaker statuses."""
|
||||||
|
|
||||||
|
circuit_breakers: list[CircuitBreakerStatus]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Endpoints
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/servers",
|
||||||
|
response_model=ServerListResponse,
|
||||||
|
summary="List MCP Servers",
|
||||||
|
description="Get list of all registered MCP servers with their configurations.",
|
||||||
|
)
|
||||||
|
async def list_servers(
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> ServerListResponse:
|
||||||
|
"""List all registered MCP servers."""
|
||||||
|
servers = []
|
||||||
|
|
||||||
|
for name in mcp.list_servers():
|
||||||
|
try:
|
||||||
|
config = mcp.get_server_config(name)
|
||||||
|
servers.append(
|
||||||
|
ServerInfo(
|
||||||
|
name=name,
|
||||||
|
url=config.url,
|
||||||
|
enabled=config.enabled,
|
||||||
|
timeout=config.timeout,
|
||||||
|
transport=config.transport.value,
|
||||||
|
description=config.description,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except MCPServerNotFoundError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return ServerListResponse(
|
||||||
|
servers=servers,
|
||||||
|
total=len(servers),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/servers/{server_name}/tools",
|
||||||
|
response_model=ToolListResponse,
|
||||||
|
summary="List Server Tools",
|
||||||
|
description="Get list of tools available on a specific MCP server.",
|
||||||
|
)
|
||||||
|
async def list_server_tools(
|
||||||
|
server_name: ServerNamePath,
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> ToolListResponse:
|
||||||
|
"""List all tools available on a specific server."""
|
||||||
|
try:
|
||||||
|
tools = await mcp.list_tools(server_name)
|
||||||
|
return ToolListResponse(
|
||||||
|
tools=[
|
||||||
|
ToolInfoResponse(
|
||||||
|
name=t.name,
|
||||||
|
description=t.description,
|
||||||
|
server_name=t.server_name,
|
||||||
|
input_schema=t.input_schema,
|
||||||
|
)
|
||||||
|
for t in tools
|
||||||
|
],
|
||||||
|
total=len(tools),
|
||||||
|
)
|
||||||
|
except MCPServerNotFoundError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Server not found: {server_name}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/tools",
|
||||||
|
response_model=ToolListResponse,
|
||||||
|
summary="List All Tools",
|
||||||
|
description="Get list of all tools from all MCP servers.",
|
||||||
|
)
|
||||||
|
async def list_all_tools(
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> ToolListResponse:
|
||||||
|
"""List all tools from all servers."""
|
||||||
|
tools = await mcp.list_all_tools()
|
||||||
|
return ToolListResponse(
|
||||||
|
tools=[
|
||||||
|
ToolInfoResponse(
|
||||||
|
name=t.name,
|
||||||
|
description=t.description,
|
||||||
|
server_name=t.server_name,
|
||||||
|
input_schema=t.input_schema,
|
||||||
|
)
|
||||||
|
for t in tools
|
||||||
|
],
|
||||||
|
total=len(tools),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/health",
|
||||||
|
response_model=HealthCheckResponse,
|
||||||
|
summary="Health Check",
|
||||||
|
description="Check health status of all MCP servers.",
|
||||||
|
)
|
||||||
|
async def health_check(
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> HealthCheckResponse:
|
||||||
|
"""Perform health check on all MCP servers."""
|
||||||
|
health_results = await mcp.health_check()
|
||||||
|
|
||||||
|
servers = {
|
||||||
|
name: ServerHealthStatus(
|
||||||
|
name=status.name,
|
||||||
|
healthy=status.healthy,
|
||||||
|
state=status.state,
|
||||||
|
url=status.url,
|
||||||
|
error=status.error,
|
||||||
|
tools_count=status.tools_count,
|
||||||
|
)
|
||||||
|
for name, status in health_results.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
healthy_count = sum(1 for s in servers.values() if s.healthy)
|
||||||
|
unhealthy_count = len(servers) - healthy_count
|
||||||
|
|
||||||
|
return HealthCheckResponse(
|
||||||
|
servers=servers,
|
||||||
|
healthy_count=healthy_count,
|
||||||
|
unhealthy_count=unhealthy_count,
|
||||||
|
total=len(servers),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/call",
|
||||||
|
response_model=ToolCallResponse,
|
||||||
|
summary="Execute Tool (Admin Only)",
|
||||||
|
description="Execute a tool on an MCP server. Requires superuser privileges.",
|
||||||
|
)
|
||||||
|
async def call_tool(
|
||||||
|
request: ToolCallRequest,
|
||||||
|
current_user: User = Depends(require_superuser),
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> ToolCallResponse:
|
||||||
|
"""
|
||||||
|
Execute a tool on an MCP server.
|
||||||
|
|
||||||
|
This endpoint is restricted to superusers for direct tool execution.
|
||||||
|
Normal tool execution should go through agent workflows.
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"Tool call by user %s: %s.%s",
|
||||||
|
current_user.id,
|
||||||
|
request.server,
|
||||||
|
request.tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await mcp.call_tool(
|
||||||
|
server=request.server,
|
||||||
|
tool=request.tool,
|
||||||
|
args=request.arguments,
|
||||||
|
timeout=request.timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ToolCallResponse(
|
||||||
|
success=result.success,
|
||||||
|
data=result.data,
|
||||||
|
error=result.error,
|
||||||
|
error_code=result.error_code,
|
||||||
|
tool_name=result.tool_name,
|
||||||
|
server_name=result.server_name,
|
||||||
|
execution_time_ms=result.execution_time_ms,
|
||||||
|
request_id=result.request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except MCPCircuitOpenError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail=f"Server temporarily unavailable: {e.server_name}",
|
||||||
|
) from e
|
||||||
|
except MCPToolNotFoundError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Tool not found: {e.tool_name}",
|
||||||
|
) from e
|
||||||
|
except MCPServerNotFoundError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Server not found: {e.server_name}",
|
||||||
|
) from e
|
||||||
|
except MCPTimeoutError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
except MCPConnectionError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
except MCPToolError as e:
|
||||||
|
# Tool errors are returned in the response, not as HTTP errors
|
||||||
|
return ToolCallResponse(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
error_code=e.error_code,
|
||||||
|
tool_name=e.tool_name,
|
||||||
|
server_name=e.server_name,
|
||||||
|
)
|
||||||
|
except MCPError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/circuit-breakers",
|
||||||
|
response_model=CircuitBreakerListResponse,
|
||||||
|
summary="List Circuit Breakers",
|
||||||
|
description="Get status of all circuit breakers.",
|
||||||
|
)
|
||||||
|
async def list_circuit_breakers(
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> CircuitBreakerListResponse:
|
||||||
|
"""Get status of all circuit breakers."""
|
||||||
|
status_dict = mcp.get_circuit_breaker_status()
|
||||||
|
|
||||||
|
return CircuitBreakerListResponse(
|
||||||
|
circuit_breakers=[
|
||||||
|
CircuitBreakerStatus(
|
||||||
|
server_name=name,
|
||||||
|
state=info.get("state", "unknown"),
|
||||||
|
failure_count=info.get("failure_count", 0),
|
||||||
|
)
|
||||||
|
for name, info in status_dict.items()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/circuit-breakers/{server_name}/reset",
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
summary="Reset Circuit Breaker (Admin Only)",
|
||||||
|
description="Manually reset a circuit breaker for a server.",
|
||||||
|
)
|
||||||
|
async def reset_circuit_breaker(
|
||||||
|
server_name: ServerNamePath,
|
||||||
|
current_user: User = Depends(require_superuser),
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> None:
|
||||||
|
"""Manually reset a circuit breaker."""
|
||||||
|
logger.info(
|
||||||
|
"Circuit breaker reset by user %s for server %s",
|
||||||
|
current_user.id,
|
||||||
|
server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
success = await mcp.reset_circuit_breaker(server_name)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"No circuit breaker found for server: {server_name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/servers/{server_name}/reconnect",
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
summary="Reconnect to Server (Admin Only)",
|
||||||
|
description="Force reconnection to an MCP server.",
|
||||||
|
)
|
||||||
|
async def reconnect_server(
|
||||||
|
server_name: ServerNamePath,
|
||||||
|
current_user: User = Depends(require_superuser),
|
||||||
|
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||||
|
) -> None:
|
||||||
|
"""Force reconnection to an MCP server."""
|
||||||
|
logger.info(
|
||||||
|
"Reconnect requested by user %s for server %s",
|
||||||
|
current_user.id,
|
||||||
|
server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await mcp.disconnect(server_name)
|
||||||
|
await mcp.connect(server_name)
|
||||||
|
except MCPServerNotFoundError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Server not found: {server_name}",
|
||||||
|
) from e
|
||||||
|
except MCPConnectionError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
|
detail=f"Failed to reconnect: {e}",
|
||||||
|
) from e
|
||||||
@@ -197,10 +197,10 @@ async def list_projects(
|
|||||||
status_filter: ProjectStatus | None = Query(
|
status_filter: ProjectStatus | None = Query(
|
||||||
None, alias="status", description="Filter by project status"
|
None, alias="status", description="Filter by project status"
|
||||||
),
|
),
|
||||||
search: str | None = Query(None, description="Search by name, slug, or description"),
|
search: str | None = Query(
|
||||||
all_projects: bool = Query(
|
None, description="Search by name, slug, or description"
|
||||||
False, description="Show all projects (superuser only)"
|
|
||||||
),
|
),
|
||||||
|
all_projects: bool = Query(False, description="Show all projects (superuser only)"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@@ -212,7 +212,9 @@ async def list_projects(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Determine owner filter based on user role and request
|
# Determine owner filter based on user role and request
|
||||||
owner_id = None if (current_user.is_superuser and all_projects) else current_user.id
|
owner_id = (
|
||||||
|
None if (current_user.is_superuser and all_projects) else current_user.id
|
||||||
|
)
|
||||||
|
|
||||||
projects_data, total = await project_crud.get_multi_with_counts(
|
projects_data, total = await project_crud.get_multi_with_counts(
|
||||||
db,
|
db,
|
||||||
@@ -379,13 +381,15 @@ async def update_project(
|
|||||||
_check_project_ownership(project, current_user)
|
_check_project_ownership(project, current_user)
|
||||||
|
|
||||||
# Update the project
|
# Update the project
|
||||||
updated_project = await project_crud.update(db, db_obj=project, obj_in=project_in)
|
updated_project = await project_crud.update(
|
||||||
logger.info(
|
db, db_obj=project, obj_in=project_in
|
||||||
f"User {current_user.email} updated project {updated_project.slug}"
|
|
||||||
)
|
)
|
||||||
|
logger.info(f"User {current_user.email} updated project {updated_project.slug}")
|
||||||
|
|
||||||
# Get updated project with counts
|
# Get updated project with counts
|
||||||
project_data = await project_crud.get_with_counts(db, project_id=updated_project.id)
|
project_data = await project_crud.get_with_counts(
|
||||||
|
db, project_id=updated_project.id
|
||||||
|
)
|
||||||
|
|
||||||
if not project_data:
|
if not project_data:
|
||||||
# This shouldn't happen, but handle gracefully
|
# This shouldn't happen, but handle gracefully
|
||||||
@@ -551,7 +555,9 @@ async def pause_project(
|
|||||||
logger.info(f"User {current_user.email} paused project {project.slug}")
|
logger.info(f"User {current_user.email} paused project {project.slug}")
|
||||||
|
|
||||||
# Get project with counts
|
# Get project with counts
|
||||||
project_data = await project_crud.get_with_counts(db, project_id=updated_project.id)
|
project_data = await project_crud.get_with_counts(
|
||||||
|
db, project_id=updated_project.id
|
||||||
|
)
|
||||||
|
|
||||||
if not project_data:
|
if not project_data:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
@@ -634,7 +640,9 @@ async def resume_project(
|
|||||||
logger.info(f"User {current_user.email} resumed project {project.slug}")
|
logger.info(f"User {current_user.email} resumed project {project.slug}")
|
||||||
|
|
||||||
# Get project with counts
|
# Get project with counts
|
||||||
project_data = await project_crud.get_with_counts(db, project_id=updated_project.id)
|
project_data = await project_crud.get_with_counts(
|
||||||
|
db, project_id=updated_project.id
|
||||||
|
)
|
||||||
|
|
||||||
if not project_data:
|
if not project_data:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
|
|||||||
@@ -320,7 +320,9 @@ async def list_sprints(
|
|||||||
return PaginatedResponse(data=sprint_responses, pagination=pagination_meta)
|
return PaginatedResponse(data=sprint_responses, pagination=pagination_meta)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error listing sprints for project {project_id}: {e!s}", exc_info=True)
|
logger.error(
|
||||||
|
f"Error listing sprints for project {project_id}: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -564,7 +566,9 @@ async def update_sprint(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Update the sprint
|
# Update the sprint
|
||||||
updated_sprint = await sprint_crud.update(db, db_obj=sprint, obj_in=sprint_update)
|
updated_sprint = await sprint_crud.update(
|
||||||
|
db, db_obj=sprint, obj_in=sprint_update
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.id} updated sprint {sprint_id} in project {project_id}"
|
f"User {current_user.id} updated sprint {sprint_id} in project {project_id}"
|
||||||
@@ -1123,7 +1127,9 @@ async def remove_issue_from_sprint(
|
|||||||
request: Request,
|
request: Request,
|
||||||
project_id: UUID,
|
project_id: UUID,
|
||||||
sprint_id: UUID,
|
sprint_id: UUID,
|
||||||
issue_id: UUID = Query(..., description="ID of the issue to remove from the sprint"),
|
issue_id: UUID = Query(
|
||||||
|
..., description="ID of the issue to remove from the sprint"
|
||||||
|
),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
|||||||
@@ -1,366 +0,0 @@
|
|||||||
{
|
|
||||||
"organizations": [
|
|
||||||
{
|
|
||||||
"name": "Acme Corp",
|
|
||||||
"slug": "acme-corp",
|
|
||||||
"description": "A leading provider of coyote-catching equipment."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Globex Corporation",
|
|
||||||
"slug": "globex",
|
|
||||||
"description": "We own the East Coast."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Soylent Corp",
|
|
||||||
"slug": "soylent",
|
|
||||||
"description": "Making food for the future."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Initech",
|
|
||||||
"slug": "initech",
|
|
||||||
"description": "Software for the soul."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Umbrella Corporation",
|
|
||||||
"slug": "umbrella",
|
|
||||||
"description": "Our business is life itself."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Massive Dynamic",
|
|
||||||
"slug": "massive-dynamic",
|
|
||||||
"description": "What don't we do?"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"users": [
|
|
||||||
{
|
|
||||||
"email": "demo@example.com",
|
|
||||||
"password": "DemoPass1234!",
|
|
||||||
"first_name": "Demo",
|
|
||||||
"last_name": "User",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "acme-corp",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "alice@acme.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Alice",
|
|
||||||
"last_name": "Smith",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "acme-corp",
|
|
||||||
"role": "admin",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "bob@acme.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Bob",
|
|
||||||
"last_name": "Jones",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "acme-corp",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "charlie@acme.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Charlie",
|
|
||||||
"last_name": "Brown",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "acme-corp",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "diana@acme.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Diana",
|
|
||||||
"last_name": "Prince",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "acme-corp",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "carol@globex.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Carol",
|
|
||||||
"last_name": "Williams",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "globex",
|
|
||||||
"role": "owner",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "dan@globex.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Dan",
|
|
||||||
"last_name": "Miller",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "globex",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "ellen@globex.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Ellen",
|
|
||||||
"last_name": "Ripley",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "globex",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "fred@globex.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Fred",
|
|
||||||
"last_name": "Flintstone",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "globex",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "dave@soylent.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Dave",
|
|
||||||
"last_name": "Brown",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "soylent",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "gina@soylent.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Gina",
|
|
||||||
"last_name": "Torres",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "soylent",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "harry@soylent.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Harry",
|
|
||||||
"last_name": "Potter",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "soylent",
|
|
||||||
"role": "admin",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "eve@initech.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Eve",
|
|
||||||
"last_name": "Davis",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "initech",
|
|
||||||
"role": "admin",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "iris@initech.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Iris",
|
|
||||||
"last_name": "West",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "initech",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "jack@initech.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Jack",
|
|
||||||
"last_name": "Sparrow",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "initech",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "frank@umbrella.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Frank",
|
|
||||||
"last_name": "Miller",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "umbrella",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "george@umbrella.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "George",
|
|
||||||
"last_name": "Costanza",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "umbrella",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "kate@umbrella.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Kate",
|
|
||||||
"last_name": "Bishop",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "umbrella",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "leo@massive.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Leo",
|
|
||||||
"last_name": "Messi",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "massive-dynamic",
|
|
||||||
"role": "owner",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "mary@massive.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Mary",
|
|
||||||
"last_name": "Jane",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "massive-dynamic",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "nathan@massive.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Nathan",
|
|
||||||
"last_name": "Drake",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "massive-dynamic",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "olivia@massive.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Olivia",
|
|
||||||
"last_name": "Dunham",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "massive-dynamic",
|
|
||||||
"role": "admin",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "peter@massive.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Peter",
|
|
||||||
"last_name": "Parker",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "massive-dynamic",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "quinn@massive.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Quinn",
|
|
||||||
"last_name": "Mallory",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": "massive-dynamic",
|
|
||||||
"role": "member",
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "grace@example.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Grace",
|
|
||||||
"last_name": "Hopper",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": null,
|
|
||||||
"role": null,
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "heidi@example.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Heidi",
|
|
||||||
"last_name": "Klum",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": null,
|
|
||||||
"role": null,
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "ivan@example.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Ivan",
|
|
||||||
"last_name": "Drago",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": null,
|
|
||||||
"role": null,
|
|
||||||
"is_active": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "rachel@example.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Rachel",
|
|
||||||
"last_name": "Green",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": null,
|
|
||||||
"role": null,
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "sam@example.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Sam",
|
|
||||||
"last_name": "Wilson",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": null,
|
|
||||||
"role": null,
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "tony@example.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Tony",
|
|
||||||
"last_name": "Stark",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": null,
|
|
||||||
"role": null,
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "una@example.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Una",
|
|
||||||
"last_name": "Chin-Riley",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": null,
|
|
||||||
"role": null,
|
|
||||||
"is_active": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "victor@example.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Victor",
|
|
||||||
"last_name": "Von Doom",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": null,
|
|
||||||
"role": null,
|
|
||||||
"is_active": true
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"email": "wanda@example.com",
|
|
||||||
"password": "Demo123!",
|
|
||||||
"first_name": "Wanda",
|
|
||||||
"last_name": "Maximoff",
|
|
||||||
"is_superuser": false,
|
|
||||||
"organization_slug": null,
|
|
||||||
"role": null,
|
|
||||||
"is_active": true
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -243,7 +243,9 @@ class RedisClient:
|
|||||||
try:
|
try:
|
||||||
client = await self._get_client()
|
client = await self._get_client()
|
||||||
result = await client.expire(key, ttl)
|
result = await client.expire(key, ttl)
|
||||||
logger.debug(f"Cache expire for key: {key} (TTL: {ttl}s, success: {result})")
|
logger.debug(
|
||||||
|
f"Cache expire for key: {key} (TTL: {ttl}s, success: {result})"
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
except (ConnectionError, TimeoutError) as e:
|
except (ConnectionError, TimeoutError) as e:
|
||||||
logger.error(f"Redis cache_expire failed for key '{key}': {e}")
|
logger.error(f"Redis cache_expire failed for key '{key}': {e}")
|
||||||
@@ -323,9 +325,7 @@ class RedisClient:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def subscribe(
|
async def subscribe(self, *channels: str) -> AsyncGenerator[PubSub, None]:
|
||||||
self, *channels: str
|
|
||||||
) -> AsyncGenerator[PubSub, None]:
|
|
||||||
"""
|
"""
|
||||||
Subscribe to one or more channels.
|
Subscribe to one or more channels.
|
||||||
|
|
||||||
@@ -353,9 +353,7 @@ class RedisClient:
|
|||||||
logger.debug(f"Unsubscribed from channels: {channels}")
|
logger.debug(f"Unsubscribed from channels: {channels}")
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def psubscribe(
|
async def psubscribe(self, *patterns: str) -> AsyncGenerator[PubSub, None]:
|
||||||
self, *patterns: str
|
|
||||||
) -> AsyncGenerator[PubSub, None]:
|
|
||||||
"""
|
"""
|
||||||
Subscribe to channels matching patterns.
|
Subscribe to channels matching patterns.
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ from app.schemas.syndarix import AgentInstanceCreate, AgentInstanceUpdate
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CRUDAgentInstance(CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstanceUpdate]):
|
class CRUDAgentInstance(
|
||||||
|
CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstanceUpdate]
|
||||||
|
):
|
||||||
"""Async CRUD operations for AgentInstance model."""
|
"""Async CRUD operations for AgentInstance model."""
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
@@ -91,8 +93,12 @@ class CRUDAgentInstance(CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstan
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"instance": instance,
|
"instance": instance,
|
||||||
"agent_type_name": instance.agent_type.name if instance.agent_type else None,
|
"agent_type_name": instance.agent_type.name
|
||||||
"agent_type_slug": instance.agent_type.slug if instance.agent_type else None,
|
if instance.agent_type
|
||||||
|
else None,
|
||||||
|
"agent_type_slug": instance.agent_type.slug
|
||||||
|
if instance.agent_type
|
||||||
|
else None,
|
||||||
"project_name": instance.project.name if instance.project else None,
|
"project_name": instance.project.name if instance.project else None,
|
||||||
"project_slug": instance.project.slug if instance.project else None,
|
"project_slug": instance.project.slug if instance.project else None,
|
||||||
"assigned_issues_count": assigned_issues_count,
|
"assigned_issues_count": assigned_issues_count,
|
||||||
@@ -115,9 +121,7 @@ class CRUDAgentInstance(CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstan
|
|||||||
) -> tuple[list[AgentInstance], int]:
|
) -> tuple[list[AgentInstance], int]:
|
||||||
"""Get agent instances for a specific project."""
|
"""Get agent instances for a specific project."""
|
||||||
try:
|
try:
|
||||||
query = select(AgentInstance).where(
|
query = select(AgentInstance).where(AgentInstance.project_id == project_id)
|
||||||
AgentInstance.project_id == project_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if status is not None:
|
if status is not None:
|
||||||
query = query.where(AgentInstance.status == status)
|
query = query.where(AgentInstance.status == status)
|
||||||
|
|||||||
@@ -22,17 +22,13 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
|||||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> AgentType | None:
|
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> AgentType | None:
|
||||||
"""Get agent type by slug."""
|
"""Get agent type by slug."""
|
||||||
try:
|
try:
|
||||||
result = await db.execute(
|
result = await db.execute(select(AgentType).where(AgentType.slug == slug))
|
||||||
select(AgentType).where(AgentType.slug == slug)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting agent type by slug {slug}: {e!s}")
|
logger.error(f"Error getting agent type by slug {slug}: {e!s}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def create(
|
async def create(self, db: AsyncSession, *, obj_in: AgentTypeCreate) -> AgentType:
|
||||||
self, db: AsyncSession, *, obj_in: AgentTypeCreate
|
|
||||||
) -> AgentType:
|
|
||||||
"""Create a new agent type with error handling."""
|
"""Create a new agent type with error handling."""
|
||||||
try:
|
try:
|
||||||
db_obj = AgentType(
|
db_obj = AgentType(
|
||||||
@@ -47,6 +43,13 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
|||||||
mcp_servers=obj_in.mcp_servers,
|
mcp_servers=obj_in.mcp_servers,
|
||||||
tool_permissions=obj_in.tool_permissions,
|
tool_permissions=obj_in.tool_permissions,
|
||||||
is_active=obj_in.is_active,
|
is_active=obj_in.is_active,
|
||||||
|
# Category and display fields
|
||||||
|
category=obj_in.category.value if obj_in.category else None,
|
||||||
|
icon=obj_in.icon,
|
||||||
|
color=obj_in.color,
|
||||||
|
sort_order=obj_in.sort_order,
|
||||||
|
typical_tasks=obj_in.typical_tasks,
|
||||||
|
collaboration_hints=obj_in.collaboration_hints,
|
||||||
)
|
)
|
||||||
db.add(db_obj)
|
db.add(db_obj)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
@@ -57,16 +60,12 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
|||||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
if "slug" in error_msg.lower():
|
if "slug" in error_msg.lower():
|
||||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||||
raise ValueError(
|
raise ValueError(f"Agent type with slug '{obj_in.slug}' already exists")
|
||||||
f"Agent type with slug '{obj_in.slug}' already exists"
|
|
||||||
)
|
|
||||||
logger.error(f"Integrity error creating agent type: {error_msg}")
|
logger.error(f"Integrity error creating agent type: {error_msg}")
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
raise ValueError(f"Database integrity error: {error_msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.error(f"Unexpected error creating agent type: {e!s}", exc_info=True)
|
||||||
f"Unexpected error creating agent type: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_multi_with_filters(
|
async def get_multi_with_filters(
|
||||||
@@ -76,6 +75,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
|||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
is_active: bool | None = None,
|
is_active: bool | None = None,
|
||||||
|
category: str | None = None,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
sort_by: str = "created_at",
|
sort_by: str = "created_at",
|
||||||
sort_order: str = "desc",
|
sort_order: str = "desc",
|
||||||
@@ -93,6 +93,9 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
|||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
query = query.where(AgentType.is_active == is_active)
|
query = query.where(AgentType.is_active == is_active)
|
||||||
|
|
||||||
|
if category:
|
||||||
|
query = query.where(AgentType.category == category)
|
||||||
|
|
||||||
if search:
|
if search:
|
||||||
search_filter = or_(
|
search_filter = or_(
|
||||||
AgentType.name.ilike(f"%{search}%"),
|
AgentType.name.ilike(f"%{search}%"),
|
||||||
@@ -170,6 +173,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
|||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
is_active: bool | None = None,
|
is_active: bool | None = None,
|
||||||
|
category: str | None = None,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
) -> tuple[list[dict[str, Any]], int]:
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
"""
|
"""
|
||||||
@@ -185,6 +189,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
|||||||
skip=skip,
|
skip=skip,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
|
category=category,
|
||||||
search=search,
|
search=search,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -215,9 +220,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
|||||||
|
|
||||||
return results, total
|
return results, total
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Error getting agent types with counts: {e!s}", exc_info=True)
|
||||||
f"Error getting agent types with counts: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_by_expertise(
|
async def get_by_expertise(
|
||||||
@@ -270,6 +273,44 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def get_grouped_by_category(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
is_active: bool = True,
|
||||||
|
) -> dict[str, list[AgentType]]:
|
||||||
|
"""
|
||||||
|
Get agent types grouped by category, sorted by sort_order within each group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
is_active: Filter by active status (default: True)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping category to list of agent types
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query = (
|
||||||
|
select(AgentType)
|
||||||
|
.where(AgentType.is_active == is_active)
|
||||||
|
.order_by(AgentType.category, AgentType.sort_order, AgentType.name)
|
||||||
|
)
|
||||||
|
result = await db.execute(query)
|
||||||
|
agent_types = list(result.scalars().all())
|
||||||
|
|
||||||
|
# Group by category
|
||||||
|
grouped: dict[str, list[AgentType]] = {}
|
||||||
|
for at in agent_types:
|
||||||
|
cat: str = str(at.category) if at.category else "uncategorized"
|
||||||
|
if cat not in grouped:
|
||||||
|
grouped[cat] = []
|
||||||
|
grouped[cat].append(at)
|
||||||
|
|
||||||
|
return grouped
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting grouped agent types: {e!s}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
# Create a singleton instance for use across the application
|
# Create a singleton instance for use across the application
|
||||||
agent_type = CRUDAgentType(AgentType)
|
agent_type = CRUDAgentType(AgentType)
|
||||||
|
|||||||
@@ -75,7 +75,9 @@ class CRUDIssue(CRUDBase[Issue, IssueCreate, IssueUpdate]):
|
|||||||
.options(
|
.options(
|
||||||
joinedload(Issue.project),
|
joinedload(Issue.project),
|
||||||
joinedload(Issue.sprint),
|
joinedload(Issue.sprint),
|
||||||
joinedload(Issue.assigned_agent).joinedload(AgentInstance.agent_type),
|
joinedload(Issue.assigned_agent).joinedload(
|
||||||
|
AgentInstance.agent_type
|
||||||
|
),
|
||||||
)
|
)
|
||||||
.where(Issue.id == issue_id)
|
.where(Issue.id == issue_id)
|
||||||
)
|
)
|
||||||
@@ -449,9 +451,7 @@ class CRUDIssue(CRUDBase[Issue, IssueCreate, IssueUpdate]):
|
|||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
update(Issue)
|
update(Issue).where(Issue.sprint_id == sprint_id).values(sprint_id=None)
|
||||||
.where(Issue.sprint_id == sprint_id)
|
|
||||||
.values(sprint_id=None)
|
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return result.rowcount
|
return result.rowcount
|
||||||
|
|||||||
@@ -2,11 +2,10 @@
|
|||||||
"""Async CRUD operations for Project model using SQLAlchemy 2.0 patterns."""
|
"""Async CRUD operations for Project model using SQLAlchemy 2.0 patterns."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
|
|
||||||
from sqlalchemy import func, or_, select, update
|
from sqlalchemy import func, or_, select, update
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
@@ -234,9 +233,7 @@ class CRUDProject(CRUDBase[Project, ProjectCreate, ProjectUpdate]):
|
|||||||
Sprint.status == SprintStatus.ACTIVE,
|
Sprint.status == SprintStatus.ACTIVE,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
active_sprints = {
|
active_sprints = {row.project_id: row.name for row in active_sprints_result}
|
||||||
row.project_id: row.name for row in active_sprints_result
|
|
||||||
}
|
|
||||||
|
|
||||||
# Combine results
|
# Combine results
|
||||||
results = [
|
results = [
|
||||||
@@ -251,9 +248,7 @@ class CRUDProject(CRUDBase[Project, ProjectCreate, ProjectUpdate]):
|
|||||||
|
|
||||||
return results, total
|
return results, total
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Error getting projects with counts: {e!s}", exc_info=True)
|
||||||
f"Error getting projects with counts: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_projects_by_owner(
|
async def get_projects_by_owner(
|
||||||
@@ -293,9 +288,7 @@ class CRUDProject(CRUDBase[Project, ProjectCreate, ProjectUpdate]):
|
|||||||
- Unassigns issues from terminated agents
|
- Unassigns issues from terminated agents
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
result = await db.execute(
|
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||||
select(Project).where(Project.id == project_id)
|
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
project = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not project:
|
if not project:
|
||||||
@@ -361,9 +354,7 @@ class CRUDProject(CRUDBase[Project, ProjectCreate, ProjectUpdate]):
|
|||||||
return project
|
return project
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.error(f"Error archiving project {project_id}: {e!s}", exc_info=True)
|
||||||
f"Error archiving project {project_id}: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -193,9 +193,7 @@ class CRUDSprint(CRUDBase[Sprint, SprintCreate, SprintUpdate]):
|
|||||||
try:
|
try:
|
||||||
# Lock the sprint row to prevent concurrent modifications
|
# Lock the sprint row to prevent concurrent modifications
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Sprint)
|
select(Sprint).where(Sprint.id == sprint_id).with_for_update()
|
||||||
.where(Sprint.id == sprint_id)
|
|
||||||
.with_for_update()
|
|
||||||
)
|
)
|
||||||
sprint = result.scalar_one_or_none()
|
sprint = result.scalar_one_or_none()
|
||||||
|
|
||||||
@@ -257,9 +255,7 @@ class CRUDSprint(CRUDBase[Sprint, SprintCreate, SprintUpdate]):
|
|||||||
try:
|
try:
|
||||||
# Lock the sprint row to prevent concurrent modifications
|
# Lock the sprint row to prevent concurrent modifications
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Sprint)
|
select(Sprint).where(Sprint.id == sprint_id).with_for_update()
|
||||||
.where(Sprint.id == sprint_id)
|
|
||||||
.with_for_update()
|
|
||||||
)
|
)
|
||||||
sprint = result.scalar_one_or_none()
|
sprint = result.scalar_one_or_none()
|
||||||
|
|
||||||
@@ -308,9 +304,7 @@ class CRUDSprint(CRUDBase[Sprint, SprintCreate, SprintUpdate]):
|
|||||||
try:
|
try:
|
||||||
# Lock the sprint row to prevent concurrent modifications
|
# Lock the sprint row to prevent concurrent modifications
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Sprint)
|
select(Sprint).where(Sprint.id == sprint_id).with_for_update()
|
||||||
.where(Sprint.id == sprint_id)
|
|
||||||
.with_for_update()
|
|
||||||
)
|
)
|
||||||
sprint = result.scalar_one_or_none()
|
sprint = result.scalar_one_or_none()
|
||||||
|
|
||||||
@@ -425,7 +419,8 @@ class CRUDSprint(CRUDBase[Sprint, SprintCreate, SprintUpdate]):
|
|||||||
{
|
{
|
||||||
"sprint": sprint,
|
"sprint": sprint,
|
||||||
**counts_map.get(
|
**counts_map.get(
|
||||||
sprint.id, {"issue_count": 0, "open_issues": 0, "completed_issues": 0}
|
sprint.id,
|
||||||
|
{"issue_count": 0, "open_issues": 0, "completed_issues": 0},
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
for sprint in sprints
|
for sprint in sprints
|
||||||
|
|||||||
@@ -3,27 +3,48 @@
|
|||||||
Async database initialization script.
|
Async database initialization script.
|
||||||
|
|
||||||
Creates the first superuser if configured and doesn't already exist.
|
Creates the first superuser if configured and doesn't already exist.
|
||||||
|
Seeds default agent types (production data) and demo data (when DEMO_MODE is enabled).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, date, datetime, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from sqlalchemy import select, text
|
from sqlalchemy import select, text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import SessionLocal, engine
|
from app.core.database import SessionLocal, engine
|
||||||
|
from app.crud.syndarix.agent_type import agent_type as agent_type_crud
|
||||||
from app.crud.user import user as user_crud
|
from app.crud.user import user as user_crud
|
||||||
from app.models.organization import Organization
|
from app.models.organization import Organization
|
||||||
|
from app.models.syndarix import AgentInstance, AgentType, Issue, Project, Sprint
|
||||||
|
from app.models.syndarix.enums import (
|
||||||
|
AgentStatus,
|
||||||
|
AutonomyLevel,
|
||||||
|
ClientMode,
|
||||||
|
IssuePriority,
|
||||||
|
IssueStatus,
|
||||||
|
IssueType,
|
||||||
|
ProjectComplexity,
|
||||||
|
ProjectStatus,
|
||||||
|
SprintStatus,
|
||||||
|
)
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import UserOrganization
|
from app.models.user_organization import UserOrganization
|
||||||
|
from app.schemas.syndarix import AgentTypeCreate
|
||||||
from app.schemas.users import UserCreate
|
from app.schemas.users import UserCreate
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Data file paths
|
||||||
|
DATA_DIR = Path(__file__).parent.parent / "data"
|
||||||
|
DEFAULT_AGENT_TYPES_PATH = DATA_DIR / "default_agent_types.json"
|
||||||
|
DEMO_DATA_PATH = DATA_DIR / "demo_data.json"
|
||||||
|
|
||||||
|
|
||||||
async def init_db() -> User | None:
|
async def init_db() -> User | None:
|
||||||
"""
|
"""
|
||||||
@@ -54,8 +75,7 @@ async def init_db() -> User | None:
|
|||||||
|
|
||||||
if existing_user:
|
if existing_user:
|
||||||
logger.info(f"Superuser already exists: {existing_user.email}")
|
logger.info(f"Superuser already exists: {existing_user.email}")
|
||||||
return existing_user
|
else:
|
||||||
|
|
||||||
# Create superuser if doesn't exist
|
# Create superuser if doesn't exist
|
||||||
user_in = UserCreate(
|
user_in = UserCreate(
|
||||||
email=superuser_email,
|
email=superuser_email,
|
||||||
@@ -65,17 +85,19 @@ async def init_db() -> User | None:
|
|||||||
is_superuser=True,
|
is_superuser=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
user = await user_crud.create(session, obj_in=user_in)
|
existing_user = await user_crud.create(session, obj_in=user_in)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(user)
|
await session.refresh(existing_user)
|
||||||
|
logger.info(f"Created first superuser: {existing_user.email}")
|
||||||
|
|
||||||
logger.info(f"Created first superuser: {user.email}")
|
# ALWAYS load default agent types (production data)
|
||||||
|
await load_default_agent_types(session)
|
||||||
|
|
||||||
# Create demo data if in demo mode
|
# Only load demo data if in demo mode
|
||||||
if settings.DEMO_MODE:
|
if settings.DEMO_MODE:
|
||||||
await load_demo_data(session)
|
await load_demo_data(session)
|
||||||
|
|
||||||
return user
|
return existing_user
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
@@ -88,26 +110,96 @@ def _load_json_file(path: Path):
|
|||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
async def load_demo_data(session):
|
async def load_default_agent_types(session: AsyncSession) -> None:
|
||||||
"""Load demo data from JSON file."""
|
"""
|
||||||
demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
|
Load default agent types from JSON file.
|
||||||
if not demo_data_path.exists():
|
|
||||||
logger.warning(f"Demo data file not found: {demo_data_path}")
|
These are production defaults - created only if they don't exist, never overwritten.
|
||||||
|
This allows users to customize agent types without worrying about server restarts.
|
||||||
|
"""
|
||||||
|
if not DEFAULT_AGENT_TYPES_PATH.exists():
|
||||||
|
logger.warning(
|
||||||
|
f"Default agent types file not found: {DEFAULT_AGENT_TYPES_PATH}"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use asyncio.to_thread to avoid blocking the event loop
|
data = await asyncio.to_thread(_load_json_file, DEFAULT_AGENT_TYPES_PATH)
|
||||||
data = await asyncio.to_thread(_load_json_file, demo_data_path)
|
|
||||||
|
|
||||||
# Create Organizations
|
for agent_type_data in data:
|
||||||
org_map = {}
|
slug = agent_type_data["slug"]
|
||||||
for org_data in data.get("organizations", []):
|
|
||||||
# Check if org exists
|
# Check if agent type already exists
|
||||||
result = await session.execute(
|
existing = await agent_type_crud.get_by_slug(session, slug=slug)
|
||||||
text("SELECT * FROM organizations WHERE slug = :slug"),
|
|
||||||
{"slug": org_data["slug"]},
|
if existing:
|
||||||
|
logger.debug(f"Agent type already exists: {agent_type_data['name']}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create the agent type
|
||||||
|
agent_type_in = AgentTypeCreate(
|
||||||
|
name=agent_type_data["name"],
|
||||||
|
slug=slug,
|
||||||
|
description=agent_type_data.get("description"),
|
||||||
|
expertise=agent_type_data.get("expertise", []),
|
||||||
|
personality_prompt=agent_type_data["personality_prompt"],
|
||||||
|
primary_model=agent_type_data["primary_model"],
|
||||||
|
fallback_models=agent_type_data.get("fallback_models", []),
|
||||||
|
model_params=agent_type_data.get("model_params", {}),
|
||||||
|
mcp_servers=agent_type_data.get("mcp_servers", []),
|
||||||
|
tool_permissions=agent_type_data.get("tool_permissions", {}),
|
||||||
|
is_active=agent_type_data.get("is_active", True),
|
||||||
|
# Category and display fields
|
||||||
|
category=agent_type_data.get("category"),
|
||||||
|
icon=agent_type_data.get("icon", "bot"),
|
||||||
|
color=agent_type_data.get("color", "#3B82F6"),
|
||||||
|
sort_order=agent_type_data.get("sort_order", 0),
|
||||||
|
typical_tasks=agent_type_data.get("typical_tasks", []),
|
||||||
|
collaboration_hints=agent_type_data.get("collaboration_hints", []),
|
||||||
)
|
)
|
||||||
existing_org = result.first()
|
|
||||||
|
await agent_type_crud.create(session, obj_in=agent_type_in)
|
||||||
|
logger.info(f"Created default agent type: {agent_type_data['name']}")
|
||||||
|
|
||||||
|
logger.info("Default agent types loaded successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading default agent types: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def load_demo_data(session: AsyncSession) -> None:
|
||||||
|
"""
|
||||||
|
Load demo data from JSON file.
|
||||||
|
|
||||||
|
Only runs when DEMO_MODE is enabled. Creates demo organizations, users,
|
||||||
|
projects, sprints, agent instances, and issues.
|
||||||
|
"""
|
||||||
|
if not DEMO_DATA_PATH.exists():
|
||||||
|
logger.warning(f"Demo data file not found: {DEMO_DATA_PATH}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await asyncio.to_thread(_load_json_file, DEMO_DATA_PATH)
|
||||||
|
|
||||||
|
# Build lookup maps for FK resolution
|
||||||
|
org_map: dict[str, Organization] = {}
|
||||||
|
user_map: dict[str, User] = {}
|
||||||
|
project_map: dict[str, Project] = {}
|
||||||
|
sprint_map: dict[str, Sprint] = {} # key: "project_slug:sprint_number"
|
||||||
|
agent_type_map: dict[str, AgentType] = {}
|
||||||
|
agent_instance_map: dict[
|
||||||
|
str, AgentInstance
|
||||||
|
] = {} # key: "project_slug:agent_name"
|
||||||
|
|
||||||
|
# ========================
|
||||||
|
# 1. Create Organizations
|
||||||
|
# ========================
|
||||||
|
for org_data in data.get("organizations", []):
|
||||||
|
org_result = await session.execute(
|
||||||
|
select(Organization).where(Organization.slug == org_data["slug"])
|
||||||
|
)
|
||||||
|
existing_org = org_result.scalar_one_or_none()
|
||||||
|
|
||||||
if not existing_org:
|
if not existing_org:
|
||||||
org = Organization(
|
org = Organization(
|
||||||
@@ -117,29 +209,20 @@ async def load_demo_data(session):
|
|||||||
is_active=True,
|
is_active=True,
|
||||||
)
|
)
|
||||||
session.add(org)
|
session.add(org)
|
||||||
await session.flush() # Flush to get ID
|
await session.flush()
|
||||||
org_map[org.slug] = org
|
org_map[str(org.slug)] = org
|
||||||
logger.info(f"Created demo organization: {org.name}")
|
logger.info(f"Created demo organization: {org.name}")
|
||||||
else:
|
else:
|
||||||
# We can't easily get the ORM object from raw SQL result for map without querying again or mapping
|
org_map[str(existing_org.slug)] = existing_org
|
||||||
# So let's just query it properly if we need it for relationships
|
|
||||||
# But for simplicity in this script, let's just assume we created it or it exists.
|
|
||||||
# To properly map for users, we need the ID.
|
|
||||||
# Let's use a simpler approach: just try to create, if slug conflict, skip.
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Re-query all orgs to build map for users
|
# ========================
|
||||||
result = await session.execute(select(Organization))
|
# 2. Create Users
|
||||||
orgs = result.scalars().all()
|
# ========================
|
||||||
org_map = {org.slug: org for org in orgs}
|
|
||||||
|
|
||||||
# Create Users
|
|
||||||
for user_data in data.get("users", []):
|
for user_data in data.get("users", []):
|
||||||
existing_user = await user_crud.get_by_email(
|
existing_user = await user_crud.get_by_email(
|
||||||
session, email=user_data["email"]
|
session, email=user_data["email"]
|
||||||
)
|
)
|
||||||
if not existing_user:
|
if not existing_user:
|
||||||
# Create user
|
|
||||||
user_in = UserCreate(
|
user_in = UserCreate(
|
||||||
email=user_data["email"],
|
email=user_data["email"],
|
||||||
password=user_data["password"],
|
password=user_data["password"],
|
||||||
@@ -151,17 +234,13 @@ async def load_demo_data(session):
|
|||||||
user = await user_crud.create(session, obj_in=user_in)
|
user = await user_crud.create(session, obj_in=user_in)
|
||||||
|
|
||||||
# Randomize created_at for demo data (last 30 days)
|
# Randomize created_at for demo data (last 30 days)
|
||||||
# This makes the charts look more realistic
|
|
||||||
days_ago = random.randint(0, 30) # noqa: S311
|
days_ago = random.randint(0, 30) # noqa: S311
|
||||||
random_time = datetime.now(UTC) - timedelta(days=days_ago)
|
random_time = datetime.now(UTC) - timedelta(days=days_ago)
|
||||||
# Add some random hours/minutes variation
|
|
||||||
random_time = random_time.replace(
|
random_time = random_time.replace(
|
||||||
hour=random.randint(0, 23), # noqa: S311
|
hour=random.randint(0, 23), # noqa: S311
|
||||||
minute=random.randint(0, 59), # noqa: S311
|
minute=random.randint(0, 59), # noqa: S311
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the timestamp and is_active directly in the database
|
|
||||||
# We do this to ensure the values are persisted correctly
|
|
||||||
await session.execute(
|
await session.execute(
|
||||||
text(
|
text(
|
||||||
"UPDATE users SET created_at = :created_at, is_active = :is_active WHERE id = :user_id"
|
"UPDATE users SET created_at = :created_at, is_active = :is_active WHERE id = :user_id"
|
||||||
@@ -174,7 +253,7 @@ async def load_demo_data(session):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Created demo user: {user.email} (created {days_ago} days ago, active={user_data.get('is_active', True)})"
|
f"Created demo user: {user.email} (created {days_ago} days ago)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add to organization if specified
|
# Add to organization if specified
|
||||||
@@ -182,19 +261,228 @@ async def load_demo_data(session):
|
|||||||
role = user_data.get("role")
|
role = user_data.get("role")
|
||||||
if org_slug and org_slug in org_map and role:
|
if org_slug and org_slug in org_map and role:
|
||||||
org = org_map[org_slug]
|
org = org_map[org_slug]
|
||||||
# Check if membership exists (it shouldn't for new user)
|
|
||||||
member = UserOrganization(
|
member = UserOrganization(
|
||||||
user_id=user.id, organization_id=org.id, role=role
|
user_id=user.id, organization_id=org.id, role=role
|
||||||
)
|
)
|
||||||
session.add(member)
|
session.add(member)
|
||||||
logger.info(f"Added {user.email} to {org.name} as {role}")
|
logger.info(f"Added {user.email} to {org.name} as {role}")
|
||||||
|
|
||||||
|
user_map[str(user.email)] = user
|
||||||
else:
|
else:
|
||||||
logger.info(f"Demo user already exists: {existing_user.email}")
|
user_map[str(existing_user.email)] = existing_user
|
||||||
|
logger.debug(f"Demo user already exists: {existing_user.email}")
|
||||||
|
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
# Add admin user to map with special "__admin__" key
|
||||||
|
# This allows demo data to reference the admin user as owner
|
||||||
|
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
||||||
|
admin_user = await user_crud.get_by_email(session, email=superuser_email)
|
||||||
|
if admin_user:
|
||||||
|
user_map["__admin__"] = admin_user
|
||||||
|
user_map[str(admin_user.email)] = admin_user
|
||||||
|
logger.debug(f"Added admin user to map: {admin_user.email}")
|
||||||
|
|
||||||
|
# ========================
|
||||||
|
# 3. Load Agent Types Map (for FK resolution)
|
||||||
|
# ========================
|
||||||
|
agent_types_result = await session.execute(select(AgentType))
|
||||||
|
for at in agent_types_result.scalars().all():
|
||||||
|
agent_type_map[str(at.slug)] = at
|
||||||
|
|
||||||
|
# ========================
|
||||||
|
# 4. Create Projects
|
||||||
|
# ========================
|
||||||
|
for project_data in data.get("projects", []):
|
||||||
|
project_result = await session.execute(
|
||||||
|
select(Project).where(Project.slug == project_data["slug"])
|
||||||
|
)
|
||||||
|
existing_project = project_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not existing_project:
|
||||||
|
# Resolve owner email to user ID
|
||||||
|
owner_id = None
|
||||||
|
owner_email = project_data.get("owner_email")
|
||||||
|
if owner_email and owner_email in user_map:
|
||||||
|
owner_id = user_map[owner_email].id
|
||||||
|
|
||||||
|
project = Project(
|
||||||
|
name=project_data["name"],
|
||||||
|
slug=project_data["slug"],
|
||||||
|
description=project_data.get("description"),
|
||||||
|
owner_id=owner_id,
|
||||||
|
autonomy_level=AutonomyLevel(
|
||||||
|
project_data.get("autonomy_level", "milestone")
|
||||||
|
),
|
||||||
|
status=ProjectStatus(project_data.get("status", "active")),
|
||||||
|
complexity=ProjectComplexity(
|
||||||
|
project_data.get("complexity", "medium")
|
||||||
|
),
|
||||||
|
client_mode=ClientMode(project_data.get("client_mode", "auto")),
|
||||||
|
settings=project_data.get("settings", {}),
|
||||||
|
)
|
||||||
|
session.add(project)
|
||||||
|
await session.flush()
|
||||||
|
project_map[str(project.slug)] = project
|
||||||
|
logger.info(f"Created demo project: {project.name}")
|
||||||
|
else:
|
||||||
|
project_map[str(existing_project.slug)] = existing_project
|
||||||
|
logger.debug(f"Demo project already exists: {existing_project.name}")
|
||||||
|
|
||||||
|
# ========================
|
||||||
|
# 5. Create Sprints
|
||||||
|
# ========================
|
||||||
|
for sprint_data in data.get("sprints", []):
|
||||||
|
project_slug = sprint_data["project_slug"]
|
||||||
|
sprint_number = sprint_data["number"]
|
||||||
|
sprint_key = f"{project_slug}:{sprint_number}"
|
||||||
|
|
||||||
|
if project_slug not in project_map:
|
||||||
|
logger.warning(f"Project not found for sprint: {project_slug}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
sprint_project = project_map[project_slug]
|
||||||
|
|
||||||
|
# Check if sprint exists
|
||||||
|
sprint_result = await session.execute(
|
||||||
|
select(Sprint).where(
|
||||||
|
Sprint.project_id == sprint_project.id,
|
||||||
|
Sprint.number == sprint_number,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing_sprint = sprint_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not existing_sprint:
|
||||||
|
sprint = Sprint(
|
||||||
|
project_id=sprint_project.id,
|
||||||
|
name=sprint_data["name"],
|
||||||
|
number=sprint_number,
|
||||||
|
goal=sprint_data.get("goal"),
|
||||||
|
start_date=date.fromisoformat(sprint_data["start_date"]),
|
||||||
|
end_date=date.fromisoformat(sprint_data["end_date"]),
|
||||||
|
status=SprintStatus(sprint_data.get("status", "planned")),
|
||||||
|
planned_points=sprint_data.get("planned_points"),
|
||||||
|
)
|
||||||
|
session.add(sprint)
|
||||||
|
await session.flush()
|
||||||
|
sprint_map[sprint_key] = sprint
|
||||||
|
logger.info(
|
||||||
|
f"Created demo sprint: {sprint.name} for {sprint_project.name}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sprint_map[sprint_key] = existing_sprint
|
||||||
|
logger.debug(f"Demo sprint already exists: {existing_sprint.name}")
|
||||||
|
|
||||||
|
# ========================
|
||||||
|
# 6. Create Agent Instances
|
||||||
|
# ========================
|
||||||
|
for agent_data in data.get("agent_instances", []):
|
||||||
|
project_slug = agent_data["project_slug"]
|
||||||
|
agent_type_slug = agent_data["agent_type_slug"]
|
||||||
|
agent_name = agent_data["name"]
|
||||||
|
agent_key = f"{project_slug}:{agent_name}"
|
||||||
|
|
||||||
|
if project_slug not in project_map:
|
||||||
|
logger.warning(f"Project not found for agent: {project_slug}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if agent_type_slug not in agent_type_map:
|
||||||
|
logger.warning(f"Agent type not found: {agent_type_slug}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
agent_project = project_map[project_slug]
|
||||||
|
agent_type = agent_type_map[agent_type_slug]
|
||||||
|
|
||||||
|
# Check if agent instance exists (by name within project)
|
||||||
|
agent_result = await session.execute(
|
||||||
|
select(AgentInstance).where(
|
||||||
|
AgentInstance.project_id == agent_project.id,
|
||||||
|
AgentInstance.name == agent_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing_agent = agent_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not existing_agent:
|
||||||
|
agent_instance = AgentInstance(
|
||||||
|
project_id=agent_project.id,
|
||||||
|
agent_type_id=agent_type.id,
|
||||||
|
name=agent_name,
|
||||||
|
status=AgentStatus(agent_data.get("status", "idle")),
|
||||||
|
current_task=agent_data.get("current_task"),
|
||||||
|
)
|
||||||
|
session.add(agent_instance)
|
||||||
|
await session.flush()
|
||||||
|
agent_instance_map[agent_key] = agent_instance
|
||||||
|
logger.info(
|
||||||
|
f"Created demo agent: {agent_name} ({agent_type.name}) "
|
||||||
|
f"for {agent_project.name}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
agent_instance_map[agent_key] = existing_agent
|
||||||
|
logger.debug(f"Demo agent already exists: {existing_agent.name}")
|
||||||
|
|
||||||
|
# ========================
|
||||||
|
# 7. Create Issues
|
||||||
|
# ========================
|
||||||
|
for issue_data in data.get("issues", []):
|
||||||
|
project_slug = issue_data["project_slug"]
|
||||||
|
|
||||||
|
if project_slug not in project_map:
|
||||||
|
logger.warning(f"Project not found for issue: {project_slug}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
issue_project = project_map[project_slug]
|
||||||
|
|
||||||
|
# Check if issue exists (by title within project - simple heuristic)
|
||||||
|
issue_result = await session.execute(
|
||||||
|
select(Issue).where(
|
||||||
|
Issue.project_id == issue_project.id,
|
||||||
|
Issue.title == issue_data["title"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing_issue = issue_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not existing_issue:
|
||||||
|
# Resolve sprint
|
||||||
|
sprint_id = None
|
||||||
|
sprint_number = issue_data.get("sprint_number")
|
||||||
|
if sprint_number:
|
||||||
|
sprint_key = f"{project_slug}:{sprint_number}"
|
||||||
|
if sprint_key in sprint_map:
|
||||||
|
sprint_id = sprint_map[sprint_key].id
|
||||||
|
|
||||||
|
# Resolve assigned agent
|
||||||
|
assigned_agent_id = None
|
||||||
|
assigned_agent_name = issue_data.get("assigned_agent_name")
|
||||||
|
if assigned_agent_name:
|
||||||
|
agent_key = f"{project_slug}:{assigned_agent_name}"
|
||||||
|
if agent_key in agent_instance_map:
|
||||||
|
assigned_agent_id = agent_instance_map[agent_key].id
|
||||||
|
|
||||||
|
issue = Issue(
|
||||||
|
project_id=issue_project.id,
|
||||||
|
sprint_id=sprint_id,
|
||||||
|
type=IssueType(issue_data.get("type", "task")),
|
||||||
|
title=issue_data["title"],
|
||||||
|
body=issue_data.get("body", ""),
|
||||||
|
status=IssueStatus(issue_data.get("status", "open")),
|
||||||
|
priority=IssuePriority(issue_data.get("priority", "medium")),
|
||||||
|
labels=issue_data.get("labels", []),
|
||||||
|
story_points=issue_data.get("story_points"),
|
||||||
|
assigned_agent_id=assigned_agent_id,
|
||||||
|
)
|
||||||
|
session.add(issue)
|
||||||
|
logger.info(f"Created demo issue: {issue.title[:50]}...")
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"Demo issue already exists: {existing_issue.title[:50]}..."
|
||||||
|
)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
logger.info("Demo data loaded successfully")
|
logger.info("Demo data loaded successfully")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
await session.rollback()
|
||||||
logger.error(f"Error loading demo data: {e}")
|
logger.error(f"Error loading demo data: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -210,12 +498,12 @@ async def main():
|
|||||||
try:
|
try:
|
||||||
user = await init_db()
|
user = await init_db()
|
||||||
if user:
|
if user:
|
||||||
print("✓ Database initialized successfully")
|
print("Database initialized successfully")
|
||||||
print(f"✓ Superuser: {user.email}")
|
print(f"Superuser: {user.email}")
|
||||||
else:
|
else:
|
||||||
print("✗ Failed to initialize database")
|
print("Failed to initialize database")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Error initializing database: {e}")
|
print(f"Error initializing database: {e}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
# Close the engine
|
# Close the engine
|
||||||
|
|||||||
@@ -8,6 +8,19 @@ from app.core.database import Base
|
|||||||
|
|
||||||
from .base import TimestampMixin, UUIDMixin
|
from .base import TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
# Memory system models
|
||||||
|
from .memory import (
|
||||||
|
ConsolidationStatus,
|
||||||
|
ConsolidationType,
|
||||||
|
Episode,
|
||||||
|
EpisodeOutcome,
|
||||||
|
Fact,
|
||||||
|
MemoryConsolidationLog,
|
||||||
|
Procedure,
|
||||||
|
ScopeType,
|
||||||
|
WorkingMemory,
|
||||||
|
)
|
||||||
|
|
||||||
# OAuth models (client mode - authenticate via Google/GitHub)
|
# OAuth models (client mode - authenticate via Google/GitHub)
|
||||||
from .oauth_account import OAuthAccount
|
from .oauth_account import OAuthAccount
|
||||||
|
|
||||||
@@ -37,7 +50,14 @@ __all__ = [
|
|||||||
"AgentInstance",
|
"AgentInstance",
|
||||||
"AgentType",
|
"AgentType",
|
||||||
"Base",
|
"Base",
|
||||||
|
# Memory models
|
||||||
|
"ConsolidationStatus",
|
||||||
|
"ConsolidationType",
|
||||||
|
"Episode",
|
||||||
|
"EpisodeOutcome",
|
||||||
|
"Fact",
|
||||||
"Issue",
|
"Issue",
|
||||||
|
"MemoryConsolidationLog",
|
||||||
"OAuthAccount",
|
"OAuthAccount",
|
||||||
"OAuthAuthorizationCode",
|
"OAuthAuthorizationCode",
|
||||||
"OAuthClient",
|
"OAuthClient",
|
||||||
@@ -46,11 +66,14 @@ __all__ = [
|
|||||||
"OAuthState",
|
"OAuthState",
|
||||||
"Organization",
|
"Organization",
|
||||||
"OrganizationRole",
|
"OrganizationRole",
|
||||||
|
"Procedure",
|
||||||
"Project",
|
"Project",
|
||||||
|
"ScopeType",
|
||||||
"Sprint",
|
"Sprint",
|
||||||
"TimestampMixin",
|
"TimestampMixin",
|
||||||
"UUIDMixin",
|
"UUIDMixin",
|
||||||
"User",
|
"User",
|
||||||
"UserOrganization",
|
"UserOrganization",
|
||||||
"UserSession",
|
"UserSession",
|
||||||
|
"WorkingMemory",
|
||||||
]
|
]
|
||||||
|
|||||||
32
backend/app/models/memory/__init__.py
Normal file
32
backend/app/models/memory/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# app/models/memory/__init__.py
|
||||||
|
"""
|
||||||
|
Memory System Database Models.
|
||||||
|
|
||||||
|
Provides SQLAlchemy models for the Agent Memory System:
|
||||||
|
- WorkingMemory: Key-value storage with TTL
|
||||||
|
- Episode: Experiential memories
|
||||||
|
- Fact: Semantic knowledge triples
|
||||||
|
- Procedure: Learned skills
|
||||||
|
- MemoryConsolidationLog: Consolidation job tracking
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .consolidation import MemoryConsolidationLog
|
||||||
|
from .enums import ConsolidationStatus, ConsolidationType, EpisodeOutcome, ScopeType
|
||||||
|
from .episode import Episode
|
||||||
|
from .fact import Fact
|
||||||
|
from .procedure import Procedure
|
||||||
|
from .working_memory import WorkingMemory
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Enums
|
||||||
|
"ConsolidationStatus",
|
||||||
|
"ConsolidationType",
|
||||||
|
# Models
|
||||||
|
"Episode",
|
||||||
|
"EpisodeOutcome",
|
||||||
|
"Fact",
|
||||||
|
"MemoryConsolidationLog",
|
||||||
|
"Procedure",
|
||||||
|
"ScopeType",
|
||||||
|
"WorkingMemory",
|
||||||
|
]
|
||||||
72
backend/app/models/memory/consolidation.py
Normal file
72
backend/app/models/memory/consolidation.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
# app/models/memory/consolidation.py
|
||||||
|
"""
|
||||||
|
Memory Consolidation Log database model.
|
||||||
|
|
||||||
|
Tracks memory consolidation jobs that transfer knowledge
|
||||||
|
between memory tiers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import Column, DateTime, Enum, Index, Integer, Text
|
||||||
|
|
||||||
|
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
from .enums import ConsolidationStatus, ConsolidationType
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryConsolidationLog(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Memory consolidation job log.
|
||||||
|
|
||||||
|
Tracks consolidation operations:
|
||||||
|
- Working -> Episodic (session end)
|
||||||
|
- Episodic -> Semantic (fact extraction)
|
||||||
|
- Episodic -> Procedural (procedure learning)
|
||||||
|
- Pruning (removing low-value memories)
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_consolidation_log"
|
||||||
|
|
||||||
|
# Consolidation type
|
||||||
|
consolidation_type: Column[ConsolidationType] = Column(
|
||||||
|
Enum(ConsolidationType),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Counts
|
||||||
|
source_count = Column(Integer, nullable=False, default=0)
|
||||||
|
result_count = Column(Integer, nullable=False, default=0)
|
||||||
|
|
||||||
|
# Timing
|
||||||
|
started_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
# Status
|
||||||
|
status: Column[ConsolidationStatus] = Column(
|
||||||
|
Enum(ConsolidationStatus),
|
||||||
|
nullable=False,
|
||||||
|
default=ConsolidationStatus.PENDING,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Error details if failed
|
||||||
|
error = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
# Query patterns
|
||||||
|
Index("ix_consolidation_type_status", "consolidation_type", "status"),
|
||||||
|
Index("ix_consolidation_started", "started_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def duration_seconds(self) -> float | None:
|
||||||
|
"""Calculate duration of the consolidation job."""
|
||||||
|
if self.completed_at is None or self.started_at is None:
|
||||||
|
return None
|
||||||
|
return (self.completed_at - self.started_at).total_seconds()
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"<MemoryConsolidationLog {self.id} "
|
||||||
|
f"type={self.consolidation_type.value} status={self.status.value}>"
|
||||||
|
)
|
||||||
73
backend/app/models/memory/enums.py
Normal file
73
backend/app/models/memory/enums.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# app/models/memory/enums.py
|
||||||
|
"""
|
||||||
|
Enums for Memory System database models.
|
||||||
|
|
||||||
|
These enums define the database-level constraints for memory types
|
||||||
|
and scoping levels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum as PyEnum
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeType(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Memory scope levels matching the memory service types.
|
||||||
|
|
||||||
|
GLOBAL: System-wide memories accessible by all
|
||||||
|
PROJECT: Project-scoped memories
|
||||||
|
AGENT_TYPE: Type-specific memories (shared by instances of same type)
|
||||||
|
AGENT_INSTANCE: Instance-specific memories
|
||||||
|
SESSION: Session-scoped ephemeral memories
|
||||||
|
"""
|
||||||
|
|
||||||
|
GLOBAL = "global"
|
||||||
|
PROJECT = "project"
|
||||||
|
AGENT_TYPE = "agent_type"
|
||||||
|
AGENT_INSTANCE = "agent_instance"
|
||||||
|
SESSION = "session"
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeOutcome(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Outcome of an episode (task execution).
|
||||||
|
|
||||||
|
SUCCESS: Task completed successfully
|
||||||
|
FAILURE: Task failed
|
||||||
|
PARTIAL: Task partially completed
|
||||||
|
"""
|
||||||
|
|
||||||
|
SUCCESS = "success"
|
||||||
|
FAILURE = "failure"
|
||||||
|
PARTIAL = "partial"
|
||||||
|
|
||||||
|
|
||||||
|
class ConsolidationType(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Types of memory consolidation operations.
|
||||||
|
|
||||||
|
WORKING_TO_EPISODIC: Transfer session state to episodic
|
||||||
|
EPISODIC_TO_SEMANTIC: Extract facts from episodes
|
||||||
|
EPISODIC_TO_PROCEDURAL: Extract procedures from episodes
|
||||||
|
PRUNING: Remove low-value memories
|
||||||
|
"""
|
||||||
|
|
||||||
|
WORKING_TO_EPISODIC = "working_to_episodic"
|
||||||
|
EPISODIC_TO_SEMANTIC = "episodic_to_semantic"
|
||||||
|
EPISODIC_TO_PROCEDURAL = "episodic_to_procedural"
|
||||||
|
PRUNING = "pruning"
|
||||||
|
|
||||||
|
|
||||||
|
class ConsolidationStatus(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Status of a consolidation job.
|
||||||
|
|
||||||
|
PENDING: Job is queued
|
||||||
|
RUNNING: Job is currently executing
|
||||||
|
COMPLETED: Job finished successfully
|
||||||
|
FAILED: Job failed with errors
|
||||||
|
"""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
139
backend/app/models/memory/episode.py
Normal file
139
backend/app/models/memory/episode.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
# app/models/memory/episode.py
|
||||||
|
"""
|
||||||
|
Episode database model.
|
||||||
|
|
||||||
|
Stores experiential memories - records of past task executions
|
||||||
|
with context, actions, outcomes, and lessons learned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
|
CheckConstraint,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
Float,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import (
|
||||||
|
JSONB,
|
||||||
|
UUID as PGUUID,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
from .enums import EpisodeOutcome
|
||||||
|
|
||||||
|
# Import pgvector type - will be available after migration enables extension
|
||||||
|
try:
|
||||||
|
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
|
||||||
|
except ImportError:
|
||||||
|
# Fallback for environments without pgvector
|
||||||
|
Vector = None
|
||||||
|
|
||||||
|
|
||||||
|
class Episode(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Episodic memory model.
|
||||||
|
|
||||||
|
Records experiential memories from agent task execution:
|
||||||
|
- What task was performed
|
||||||
|
- What actions were taken
|
||||||
|
- What was the outcome
|
||||||
|
- What lessons were learned
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "episodes"
|
||||||
|
|
||||||
|
# Foreign keys
|
||||||
|
project_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_instance_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("agent_instances.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_type_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("agent_types.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Session reference
|
||||||
|
session_id = Column(String(255), nullable=False, index=True)
|
||||||
|
|
||||||
|
# Task information
|
||||||
|
task_type = Column(String(100), nullable=False, index=True)
|
||||||
|
task_description = Column(Text, nullable=False)
|
||||||
|
|
||||||
|
# Actions taken (list of action dictionaries)
|
||||||
|
actions = Column(JSONB, default=list, nullable=False)
|
||||||
|
|
||||||
|
# Context summary
|
||||||
|
context_summary = Column(Text, nullable=False)
|
||||||
|
|
||||||
|
# Outcome
|
||||||
|
outcome: Column[EpisodeOutcome] = Column(
|
||||||
|
Enum(EpisodeOutcome),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
outcome_details = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
# Metrics
|
||||||
|
duration_seconds = Column(Float, nullable=False, default=0.0)
|
||||||
|
tokens_used = Column(BigInteger, nullable=False, default=0)
|
||||||
|
|
||||||
|
# Learning
|
||||||
|
lessons_learned = Column(JSONB, default=list, nullable=False)
|
||||||
|
importance_score = Column(Float, nullable=False, default=0.5, index=True)
|
||||||
|
|
||||||
|
# Vector embedding for semantic search
|
||||||
|
# Using 1536 dimensions for OpenAI text-embedding-3-small
|
||||||
|
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
|
||||||
|
|
||||||
|
# When the episode occurred
|
||||||
|
occurred_at = Column(DateTime(timezone=True), nullable=False, index=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
project = relationship("Project", foreign_keys=[project_id])
|
||||||
|
agent_instance = relationship("AgentInstance", foreign_keys=[agent_instance_id])
|
||||||
|
agent_type = relationship("AgentType", foreign_keys=[agent_type_id])
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
# Primary query patterns
|
||||||
|
Index("ix_episodes_project_task", "project_id", "task_type"),
|
||||||
|
Index("ix_episodes_project_outcome", "project_id", "outcome"),
|
||||||
|
Index("ix_episodes_agent_task", "agent_instance_id", "task_type"),
|
||||||
|
Index("ix_episodes_project_time", "project_id", "occurred_at"),
|
||||||
|
# For importance-based pruning
|
||||||
|
Index("ix_episodes_importance_time", "importance_score", "occurred_at"),
|
||||||
|
# Data integrity constraints
|
||||||
|
CheckConstraint(
|
||||||
|
"importance_score >= 0.0 AND importance_score <= 1.0",
|
||||||
|
name="ck_episodes_importance_range",
|
||||||
|
),
|
||||||
|
CheckConstraint(
|
||||||
|
"duration_seconds >= 0.0",
|
||||||
|
name="ck_episodes_duration_positive",
|
||||||
|
),
|
||||||
|
CheckConstraint(
|
||||||
|
"tokens_used >= 0",
|
||||||
|
name="ck_episodes_tokens_positive",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<Episode {self.id} task={self.task_type} outcome={self.outcome.value}>"
|
||||||
120
backend/app/models/memory/fact.py
Normal file
120
backend/app/models/memory/fact.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
# app/models/memory/fact.py
|
||||||
|
"""
|
||||||
|
Fact database model.
|
||||||
|
|
||||||
|
Stores semantic memories - learned facts in subject-predicate-object
|
||||||
|
triple format with confidence scores and source tracking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
CheckConstraint,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
Float,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
text,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import (
|
||||||
|
JSONB,
|
||||||
|
UUID as PGUUID,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
# Import pgvector type
|
||||||
|
try:
|
||||||
|
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
|
||||||
|
except ImportError:
|
||||||
|
Vector = None
|
||||||
|
|
||||||
|
|
||||||
|
class Fact(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Semantic memory model.
|
||||||
|
|
||||||
|
Stores learned facts as subject-predicate-object triples:
|
||||||
|
- "FastAPI" - "uses" - "Starlette framework"
|
||||||
|
- "Project Alpha" - "requires" - "OAuth authentication"
|
||||||
|
|
||||||
|
Facts have confidence scores that decay over time and can be
|
||||||
|
reinforced when the same fact is learned again.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "facts"
|
||||||
|
|
||||||
|
# Scoping: project_id is NULL for global facts
|
||||||
|
project_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Triple format
|
||||||
|
subject = Column(String(500), nullable=False, index=True)
|
||||||
|
predicate = Column(String(255), nullable=False, index=True)
|
||||||
|
object = Column(Text, nullable=False)
|
||||||
|
|
||||||
|
# Confidence score (0.0 to 1.0)
|
||||||
|
confidence = Column(Float, nullable=False, default=0.8, index=True)
|
||||||
|
|
||||||
|
# Source tracking: which episodes contributed to this fact (stored as JSONB array of UUID strings)
|
||||||
|
source_episode_ids: Column[list] = Column(JSONB, default=list, nullable=False)
|
||||||
|
|
||||||
|
# Learning history
|
||||||
|
first_learned = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
last_reinforced = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
reinforcement_count = Column(Integer, nullable=False, default=1)
|
||||||
|
|
||||||
|
# Vector embedding for semantic search
|
||||||
|
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
project = relationship("Project", foreign_keys=[project_id])
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
# Unique constraint on triple within project scope
|
||||||
|
Index(
|
||||||
|
"ix_facts_unique_triple",
|
||||||
|
"project_id",
|
||||||
|
"subject",
|
||||||
|
"predicate",
|
||||||
|
"object",
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=text("project_id IS NOT NULL"),
|
||||||
|
),
|
||||||
|
# Unique constraint on triple for global facts (project_id IS NULL)
|
||||||
|
Index(
|
||||||
|
"ix_facts_unique_triple_global",
|
||||||
|
"subject",
|
||||||
|
"predicate",
|
||||||
|
"object",
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=text("project_id IS NULL"),
|
||||||
|
),
|
||||||
|
# Query patterns
|
||||||
|
Index("ix_facts_subject_predicate", "subject", "predicate"),
|
||||||
|
Index("ix_facts_project_subject", "project_id", "subject"),
|
||||||
|
Index("ix_facts_confidence_time", "confidence", "last_reinforced"),
|
||||||
|
# Note: subject already has index=True on Column definition, no need for explicit index
|
||||||
|
# Data integrity constraints
|
||||||
|
CheckConstraint(
|
||||||
|
"confidence >= 0.0 AND confidence <= 1.0",
|
||||||
|
name="ck_facts_confidence_range",
|
||||||
|
),
|
||||||
|
CheckConstraint(
|
||||||
|
"reinforcement_count >= 1",
|
||||||
|
name="ck_facts_reinforcement_positive",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"<Fact {self.id} '{self.subject}' - '{self.predicate}' - "
|
||||||
|
f"'{self.object[:50]}...' conf={self.confidence:.2f}>"
|
||||||
|
)
|
||||||
129
backend/app/models/memory/procedure.py
Normal file
129
backend/app/models/memory/procedure.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
# app/models/memory/procedure.py
|
||||||
|
"""
|
||||||
|
Procedure database model.
|
||||||
|
|
||||||
|
Stores procedural memories - learned skills and procedures
|
||||||
|
derived from successful task execution patterns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
CheckConstraint,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import (
|
||||||
|
JSONB,
|
||||||
|
UUID as PGUUID,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
# Import pgvector type
|
||||||
|
try:
|
||||||
|
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
|
||||||
|
except ImportError:
|
||||||
|
Vector = None
|
||||||
|
|
||||||
|
|
||||||
|
class Procedure(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Procedural memory model.
|
||||||
|
|
||||||
|
Stores learned procedures (skills) extracted from successful
|
||||||
|
task execution patterns:
|
||||||
|
- Name and trigger pattern for matching
|
||||||
|
- Step-by-step actions
|
||||||
|
- Success/failure tracking
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "procedures"
|
||||||
|
|
||||||
|
# Scoping
|
||||||
|
project_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_type_id = Column(
|
||||||
|
PGUUID(as_uuid=True),
|
||||||
|
ForeignKey("agent_types.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Procedure identification
|
||||||
|
name = Column(String(255), nullable=False, index=True)
|
||||||
|
trigger_pattern = Column(Text, nullable=False)
|
||||||
|
|
||||||
|
# Steps as JSON array of step objects
|
||||||
|
# Each step: {order, action, parameters, expected_outcome, fallback_action}
|
||||||
|
steps = Column(JSONB, default=list, nullable=False)
|
||||||
|
|
||||||
|
# Success tracking
|
||||||
|
success_count = Column(Integer, nullable=False, default=0)
|
||||||
|
failure_count = Column(Integer, nullable=False, default=0)
|
||||||
|
|
||||||
|
# Usage tracking
|
||||||
|
last_used = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||||
|
|
||||||
|
# Vector embedding for semantic matching
|
||||||
|
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
project = relationship("Project", foreign_keys=[project_id])
|
||||||
|
agent_type = relationship("AgentType", foreign_keys=[agent_type_id])
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
# Unique procedure name within scope
|
||||||
|
Index(
|
||||||
|
"ix_procedures_unique_name",
|
||||||
|
"project_id",
|
||||||
|
"agent_type_id",
|
||||||
|
"name",
|
||||||
|
unique=True,
|
||||||
|
),
|
||||||
|
# Query patterns
|
||||||
|
Index("ix_procedures_project_name", "project_id", "name"),
|
||||||
|
# Note: agent_type_id already has index=True on Column definition
|
||||||
|
# For finding best procedures
|
||||||
|
Index("ix_procedures_success_rate", "success_count", "failure_count"),
|
||||||
|
# Data integrity constraints
|
||||||
|
CheckConstraint(
|
||||||
|
"success_count >= 0",
|
||||||
|
name="ck_procedures_success_positive",
|
||||||
|
),
|
||||||
|
CheckConstraint(
|
||||||
|
"failure_count >= 0",
|
||||||
|
name="ck_procedures_failure_positive",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def success_rate(self) -> float:
|
||||||
|
"""Calculate the success rate of this procedure."""
|
||||||
|
# Snapshot values to avoid race conditions in concurrent access
|
||||||
|
success = self.success_count
|
||||||
|
failure = self.failure_count
|
||||||
|
total = success + failure
|
||||||
|
if total == 0:
|
||||||
|
return 0.0
|
||||||
|
return success / total
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_uses(self) -> int:
|
||||||
|
"""Get total number of times this procedure was used."""
|
||||||
|
# Snapshot values for consistency
|
||||||
|
return self.success_count + self.failure_count
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"<Procedure {self.name} ({self.id}) success_rate={self.success_rate:.2%}>"
|
||||||
|
)
|
||||||
58
backend/app/models/memory/working_memory.py
Normal file
58
backend/app/models/memory/working_memory.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# app/models/memory/working_memory.py
|
||||||
|
"""
|
||||||
|
Working Memory database model.
|
||||||
|
|
||||||
|
Stores ephemeral key-value data for active sessions with TTL support.
|
||||||
|
Used as database backup when Redis is unavailable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import Column, DateTime, Enum, Index, String
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
|
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
from .enums import ScopeType
|
||||||
|
|
||||||
|
|
||||||
|
class WorkingMemory(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Working memory storage table.
|
||||||
|
|
||||||
|
Provides database-backed working memory as fallback when
|
||||||
|
Redis is unavailable. Supports TTL-based expiration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "working_memory"
|
||||||
|
|
||||||
|
# Scoping
|
||||||
|
scope_type: Column[ScopeType] = Column(
|
||||||
|
Enum(ScopeType),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
scope_id = Column(String(255), nullable=False, index=True)
|
||||||
|
|
||||||
|
# Key-value storage
|
||||||
|
key = Column(String(255), nullable=False)
|
||||||
|
value = Column(JSONB, nullable=False)
|
||||||
|
|
||||||
|
# TTL support
|
||||||
|
expires_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
# Primary lookup: scope + key
|
||||||
|
Index(
|
||||||
|
"ix_working_memory_scope_key",
|
||||||
|
"scope_type",
|
||||||
|
"scope_id",
|
||||||
|
"key",
|
||||||
|
unique=True,
|
||||||
|
),
|
||||||
|
# For cleanup of expired entries
|
||||||
|
Index("ix_working_memory_expires", "expires_at"),
|
||||||
|
# For listing all keys in a scope
|
||||||
|
Index("ix_working_memory_scope_list", "scope_type", "scope_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<WorkingMemory {self.scope_type.value}:{self.scope_id}:{self.key}>"
|
||||||
@@ -62,7 +62,11 @@ class AgentInstance(Base, UUIDMixin, TimestampMixin):
|
|||||||
|
|
||||||
# Status tracking
|
# Status tracking
|
||||||
status: Column[AgentStatus] = Column(
|
status: Column[AgentStatus] = Column(
|
||||||
Enum(AgentStatus),
|
Enum(
|
||||||
|
AgentStatus,
|
||||||
|
name="agent_status",
|
||||||
|
values_callable=lambda x: [e.value for e in x],
|
||||||
|
),
|
||||||
default=AgentStatus.IDLE,
|
default=AgentStatus.IDLE,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True,
|
index=True,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ An AgentType is a template that defines the capabilities, personality,
|
|||||||
and model configuration for agent instances.
|
and model configuration for agent instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from sqlalchemy import Boolean, Column, Index, String, Text
|
from sqlalchemy import Boolean, Column, Index, Integer, String, Text
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
@@ -56,6 +56,24 @@ class AgentType(Base, UUIDMixin, TimestampMixin):
|
|||||||
# Whether this agent type is available for new instances
|
# Whether this agent type is available for new instances
|
||||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||||
|
|
||||||
|
# Category for grouping agents (development, design, quality, etc.)
|
||||||
|
category = Column(String(50), nullable=True, index=True)
|
||||||
|
|
||||||
|
# Lucide icon identifier for UI display (e.g., "code", "palette", "shield")
|
||||||
|
icon = Column(String(50), nullable=True, default="bot")
|
||||||
|
|
||||||
|
# Hex color code for visual distinction (e.g., "#3B82F6")
|
||||||
|
color = Column(String(7), nullable=True, default="#3B82F6")
|
||||||
|
|
||||||
|
# Display ordering within category (lower = first)
|
||||||
|
sort_order = Column(Integer, nullable=False, default=0, index=True)
|
||||||
|
|
||||||
|
# List of typical tasks this agent excels at
|
||||||
|
typical_tasks = Column(JSONB, default=list, nullable=False)
|
||||||
|
|
||||||
|
# List of agent slugs that collaborate well with this type
|
||||||
|
collaboration_hints = Column(JSONB, default=list, nullable=False)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
instances = relationship(
|
instances = relationship(
|
||||||
"AgentInstance",
|
"AgentInstance",
|
||||||
@@ -66,6 +84,7 @@ class AgentType(Base, UUIDMixin, TimestampMixin):
|
|||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_agent_types_slug_active", "slug", "is_active"),
|
Index("ix_agent_types_slug_active", "slug", "is_active"),
|
||||||
Index("ix_agent_types_name_active", "name", "is_active"),
|
Index("ix_agent_types_name_active", "name", "is_active"),
|
||||||
|
Index("ix_agent_types_category_sort", "category", "sort_order"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
|||||||
@@ -167,3 +167,29 @@ class SprintStatus(str, PyEnum):
|
|||||||
IN_REVIEW = "in_review"
|
IN_REVIEW = "in_review"
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
CANCELLED = "cancelled"
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTypeCategory(str, PyEnum):
|
||||||
|
"""
|
||||||
|
Category classification for agent types.
|
||||||
|
|
||||||
|
Used for grouping and filtering agents in the UI.
|
||||||
|
|
||||||
|
DEVELOPMENT: Product, project, and engineering roles
|
||||||
|
DESIGN: UI/UX and design research roles
|
||||||
|
QUALITY: QA and security engineering
|
||||||
|
OPERATIONS: DevOps and MLOps
|
||||||
|
AI_ML: Machine learning and AI specialists
|
||||||
|
DATA: Data science and engineering
|
||||||
|
LEADERSHIP: Technical leadership roles
|
||||||
|
DOMAIN_EXPERT: Industry and domain specialists
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEVELOPMENT = "development"
|
||||||
|
DESIGN = "design"
|
||||||
|
QUALITY = "quality"
|
||||||
|
OPERATIONS = "operations"
|
||||||
|
AI_ML = "ai_ml"
|
||||||
|
DATA = "data"
|
||||||
|
LEADERSHIP = "leadership"
|
||||||
|
DOMAIN_EXPERT = "domain_expert"
|
||||||
|
|||||||
@@ -59,7 +59,9 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
|||||||
|
|
||||||
# Issue type (Epic, Story, Task, Bug)
|
# Issue type (Epic, Story, Task, Bug)
|
||||||
type: Column[IssueType] = Column(
|
type: Column[IssueType] = Column(
|
||||||
Enum(IssueType),
|
Enum(
|
||||||
|
IssueType, name="issue_type", values_callable=lambda x: [e.value for e in x]
|
||||||
|
),
|
||||||
default=IssueType.TASK,
|
default=IssueType.TASK,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True,
|
index=True,
|
||||||
@@ -78,14 +80,22 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
|||||||
|
|
||||||
# Status and priority
|
# Status and priority
|
||||||
status: Column[IssueStatus] = Column(
|
status: Column[IssueStatus] = Column(
|
||||||
Enum(IssueStatus),
|
Enum(
|
||||||
|
IssueStatus,
|
||||||
|
name="issue_status",
|
||||||
|
values_callable=lambda x: [e.value for e in x],
|
||||||
|
),
|
||||||
default=IssueStatus.OPEN,
|
default=IssueStatus.OPEN,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
priority: Column[IssuePriority] = Column(
|
priority: Column[IssuePriority] = Column(
|
||||||
Enum(IssuePriority),
|
Enum(
|
||||||
|
IssuePriority,
|
||||||
|
name="issue_priority",
|
||||||
|
values_callable=lambda x: [e.value for e in x],
|
||||||
|
),
|
||||||
default=IssuePriority.MEDIUM,
|
default=IssuePriority.MEDIUM,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True,
|
index=True,
|
||||||
@@ -132,7 +142,11 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
|||||||
|
|
||||||
# Sync status with external tracker
|
# Sync status with external tracker
|
||||||
sync_status: Column[SyncStatus] = Column(
|
sync_status: Column[SyncStatus] = Column(
|
||||||
Enum(SyncStatus),
|
Enum(
|
||||||
|
SyncStatus,
|
||||||
|
name="sync_status",
|
||||||
|
values_callable=lambda x: [e.value for e in x],
|
||||||
|
),
|
||||||
default=SyncStatus.SYNCED,
|
default=SyncStatus.SYNCED,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
# Note: Index defined in __table_args__ as ix_issues_sync_status
|
# Note: Index defined in __table_args__ as ix_issues_sync_status
|
||||||
@@ -158,7 +172,11 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
|||||||
Index("ix_issues_project_status", "project_id", "status"),
|
Index("ix_issues_project_status", "project_id", "status"),
|
||||||
Index("ix_issues_project_priority", "project_id", "priority"),
|
Index("ix_issues_project_priority", "project_id", "priority"),
|
||||||
Index("ix_issues_project_sprint", "project_id", "sprint_id"),
|
Index("ix_issues_project_sprint", "project_id", "sprint_id"),
|
||||||
Index("ix_issues_external_tracker_id", "external_tracker_type", "external_issue_id"),
|
Index(
|
||||||
|
"ix_issues_external_tracker_id",
|
||||||
|
"external_tracker_type",
|
||||||
|
"external_issue_id",
|
||||||
|
),
|
||||||
Index("ix_issues_sync_status", "sync_status"),
|
Index("ix_issues_sync_status", "sync_status"),
|
||||||
Index("ix_issues_project_agent", "project_id", "assigned_agent_id"),
|
Index("ix_issues_project_agent", "project_id", "assigned_agent_id"),
|
||||||
Index("ix_issues_project_type", "project_id", "type"),
|
Index("ix_issues_project_type", "project_id", "type"),
|
||||||
|
|||||||
@@ -35,28 +35,44 @@ class Project(Base, UUIDMixin, TimestampMixin):
|
|||||||
description = Column(Text, nullable=True)
|
description = Column(Text, nullable=True)
|
||||||
|
|
||||||
autonomy_level: Column[AutonomyLevel] = Column(
|
autonomy_level: Column[AutonomyLevel] = Column(
|
||||||
Enum(AutonomyLevel),
|
Enum(
|
||||||
|
AutonomyLevel,
|
||||||
|
name="autonomy_level",
|
||||||
|
values_callable=lambda x: [e.value for e in x],
|
||||||
|
),
|
||||||
default=AutonomyLevel.MILESTONE,
|
default=AutonomyLevel.MILESTONE,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
status: Column[ProjectStatus] = Column(
|
status: Column[ProjectStatus] = Column(
|
||||||
Enum(ProjectStatus),
|
Enum(
|
||||||
|
ProjectStatus,
|
||||||
|
name="project_status",
|
||||||
|
values_callable=lambda x: [e.value for e in x],
|
||||||
|
),
|
||||||
default=ProjectStatus.ACTIVE,
|
default=ProjectStatus.ACTIVE,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
complexity: Column[ProjectComplexity] = Column(
|
complexity: Column[ProjectComplexity] = Column(
|
||||||
Enum(ProjectComplexity),
|
Enum(
|
||||||
|
ProjectComplexity,
|
||||||
|
name="project_complexity",
|
||||||
|
values_callable=lambda x: [e.value for e in x],
|
||||||
|
),
|
||||||
default=ProjectComplexity.MEDIUM,
|
default=ProjectComplexity.MEDIUM,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
client_mode: Column[ClientMode] = Column(
|
client_mode: Column[ClientMode] = Column(
|
||||||
Enum(ClientMode),
|
Enum(
|
||||||
|
ClientMode,
|
||||||
|
name="client_mode",
|
||||||
|
values_callable=lambda x: [e.value for e in x],
|
||||||
|
),
|
||||||
default=ClientMode.AUTO,
|
default=ClientMode.AUTO,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True,
|
index=True,
|
||||||
|
|||||||
@@ -5,7 +5,17 @@ Sprint model for Syndarix AI consulting platform.
|
|||||||
A Sprint represents a time-boxed iteration for organizing and delivering work.
|
A Sprint represents a time-boxed iteration for organizing and delivering work.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from sqlalchemy import Column, Date, Enum, ForeignKey, Index, Integer, String, Text, UniqueConstraint
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Date,
|
||||||
|
Enum,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
|
)
|
||||||
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
@@ -47,7 +57,11 @@ class Sprint(Base, UUIDMixin, TimestampMixin):
|
|||||||
|
|
||||||
# Status
|
# Status
|
||||||
status: Column[SprintStatus] = Column(
|
status: Column[SprintStatus] = Column(
|
||||||
Enum(SprintStatus),
|
Enum(
|
||||||
|
SprintStatus,
|
||||||
|
name="sprint_status",
|
||||||
|
values_callable=lambda x: [e.value for e in x],
|
||||||
|
),
|
||||||
default=SprintStatus.PLANNED,
|
default=SprintStatus.PLANNED,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True,
|
index=True,
|
||||||
|
|||||||
@@ -205,9 +205,7 @@ class SprintCompletedPayload(BaseModel):
|
|||||||
sprint_id: UUID = Field(..., description="Sprint ID")
|
sprint_id: UUID = Field(..., description="Sprint ID")
|
||||||
sprint_name: str = Field(..., description="Sprint name")
|
sprint_name: str = Field(..., description="Sprint name")
|
||||||
completed_issues: int = Field(default=0, description="Number of completed issues")
|
completed_issues: int = Field(default=0, description="Number of completed issues")
|
||||||
incomplete_issues: int = Field(
|
incomplete_issues: int = Field(default=0, description="Number of incomplete issues")
|
||||||
default=0, description="Number of incomplete issues"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ApprovalRequestedPayload(BaseModel):
|
class ApprovalRequestedPayload(BaseModel):
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from uuid import UUID
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
|
from app.models.syndarix.enums import AgentTypeCategory
|
||||||
|
|
||||||
|
|
||||||
class AgentTypeBase(BaseModel):
|
class AgentTypeBase(BaseModel):
|
||||||
"""Base agent type schema with common fields."""
|
"""Base agent type schema with common fields."""
|
||||||
@@ -26,6 +28,14 @@ class AgentTypeBase(BaseModel):
|
|||||||
tool_permissions: dict[str, Any] = Field(default_factory=dict)
|
tool_permissions: dict[str, Any] = Field(default_factory=dict)
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
|
|
||||||
|
# Category and display fields
|
||||||
|
category: AgentTypeCategory | None = None
|
||||||
|
icon: str | None = Field(None, max_length=50)
|
||||||
|
color: str | None = Field(None, pattern=r"^#[0-9A-Fa-f]{6}$")
|
||||||
|
sort_order: int = Field(default=0, ge=0, le=1000)
|
||||||
|
typical_tasks: list[str] = Field(default_factory=list)
|
||||||
|
collaboration_hints: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
@field_validator("slug")
|
@field_validator("slug")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_slug(cls, v: str | None) -> str | None:
|
def validate_slug(cls, v: str | None) -> str | None:
|
||||||
@@ -62,6 +72,18 @@ class AgentTypeBase(BaseModel):
|
|||||||
"""Validate MCP server list."""
|
"""Validate MCP server list."""
|
||||||
return [s.strip() for s in v if s.strip()]
|
return [s.strip() for s in v if s.strip()]
|
||||||
|
|
||||||
|
@field_validator("typical_tasks")
|
||||||
|
@classmethod
|
||||||
|
def validate_typical_tasks(cls, v: list[str]) -> list[str]:
|
||||||
|
"""Validate and normalize typical tasks list."""
|
||||||
|
return [t.strip() for t in v if t.strip()]
|
||||||
|
|
||||||
|
@field_validator("collaboration_hints")
|
||||||
|
@classmethod
|
||||||
|
def validate_collaboration_hints(cls, v: list[str]) -> list[str]:
|
||||||
|
"""Validate and normalize collaboration hints (agent slugs)."""
|
||||||
|
return [h.strip().lower() for h in v if h.strip()]
|
||||||
|
|
||||||
|
|
||||||
class AgentTypeCreate(AgentTypeBase):
|
class AgentTypeCreate(AgentTypeBase):
|
||||||
"""Schema for creating a new agent type."""
|
"""Schema for creating a new agent type."""
|
||||||
@@ -87,6 +109,14 @@ class AgentTypeUpdate(BaseModel):
|
|||||||
tool_permissions: dict[str, Any] | None = None
|
tool_permissions: dict[str, Any] | None = None
|
||||||
is_active: bool | None = None
|
is_active: bool | None = None
|
||||||
|
|
||||||
|
# Category and display fields (all optional for updates)
|
||||||
|
category: AgentTypeCategory | None = None
|
||||||
|
icon: str | None = Field(None, max_length=50)
|
||||||
|
color: str | None = Field(None, pattern=r"^#[0-9A-Fa-f]{6}$")
|
||||||
|
sort_order: int | None = Field(None, ge=0, le=1000)
|
||||||
|
typical_tasks: list[str] | None = None
|
||||||
|
collaboration_hints: list[str] | None = None
|
||||||
|
|
||||||
@field_validator("slug")
|
@field_validator("slug")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_slug(cls, v: str | None) -> str | None:
|
def validate_slug(cls, v: str | None) -> str | None:
|
||||||
@@ -119,6 +149,22 @@ class AgentTypeUpdate(BaseModel):
|
|||||||
return v
|
return v
|
||||||
return [e.strip().lower() for e in v if e.strip()]
|
return [e.strip().lower() for e in v if e.strip()]
|
||||||
|
|
||||||
|
@field_validator("typical_tasks")
|
||||||
|
@classmethod
|
||||||
|
def validate_typical_tasks(cls, v: list[str] | None) -> list[str] | None:
|
||||||
|
"""Validate and normalize typical tasks list."""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
return [t.strip() for t in v if t.strip()]
|
||||||
|
|
||||||
|
@field_validator("collaboration_hints")
|
||||||
|
@classmethod
|
||||||
|
def validate_collaboration_hints(cls, v: list[str] | None) -> list[str] | None:
|
||||||
|
"""Validate and normalize collaboration hints (agent slugs)."""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
return [h.strip().lower() for h in v if h.strip()]
|
||||||
|
|
||||||
|
|
||||||
class AgentTypeInDB(AgentTypeBase):
|
class AgentTypeInDB(AgentTypeBase):
|
||||||
"""Schema for agent type in database."""
|
"""Schema for agent type in database."""
|
||||||
|
|||||||
@@ -99,9 +99,7 @@ class IssueAssign(BaseModel):
|
|||||||
def validate_assignment(self) -> "IssueAssign":
|
def validate_assignment(self) -> "IssueAssign":
|
||||||
"""Ensure only one type of assignee is set."""
|
"""Ensure only one type of assignee is set."""
|
||||||
if self.assigned_agent_id and self.human_assignee:
|
if self.assigned_agent_id and self.human_assignee:
|
||||||
raise ValueError(
|
raise ValueError("Cannot assign to both an agent and a human. Choose one.")
|
||||||
"Cannot assign to both an agent and a human. Choose one."
|
|
||||||
)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
182
backend/app/services/context/__init__.py
Normal file
182
backend/app/services/context/__init__.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
"""
|
||||||
|
Context Management Engine
|
||||||
|
|
||||||
|
Sophisticated context assembly and optimization for LLM requests.
|
||||||
|
Provides intelligent context selection, token budget management,
|
||||||
|
and model-specific formatting.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from app.services.context import (
|
||||||
|
ContextSettings,
|
||||||
|
get_context_settings,
|
||||||
|
SystemContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
ConversationContext,
|
||||||
|
TaskContext,
|
||||||
|
ToolContext,
|
||||||
|
TokenBudget,
|
||||||
|
BudgetAllocator,
|
||||||
|
TokenCalculator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get settings
|
||||||
|
settings = get_context_settings()
|
||||||
|
|
||||||
|
# Create budget for a model
|
||||||
|
allocator = BudgetAllocator(settings)
|
||||||
|
budget = allocator.create_budget_for_model("claude-3-sonnet")
|
||||||
|
|
||||||
|
# Create context instances
|
||||||
|
system_ctx = SystemContext.create_persona(
|
||||||
|
name="Code Assistant",
|
||||||
|
description="You are a helpful code assistant.",
|
||||||
|
capabilities=["Write code", "Debug issues"],
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Budget Management
|
||||||
|
# Adapters
|
||||||
|
from .adapters import (
|
||||||
|
ClaudeAdapter,
|
||||||
|
DefaultAdapter,
|
||||||
|
ModelAdapter,
|
||||||
|
OpenAIAdapter,
|
||||||
|
get_adapter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assembly
|
||||||
|
from .assembly import (
|
||||||
|
ContextPipeline,
|
||||||
|
PipelineMetrics,
|
||||||
|
)
|
||||||
|
from .budget import (
|
||||||
|
BudgetAllocator,
|
||||||
|
TokenBudget,
|
||||||
|
TokenCalculator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache
|
||||||
|
from .cache import ContextCache
|
||||||
|
|
||||||
|
# Compression
|
||||||
|
from .compression import (
|
||||||
|
ContextCompressor,
|
||||||
|
TruncationResult,
|
||||||
|
TruncationStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
from .config import (
|
||||||
|
ContextSettings,
|
||||||
|
get_context_settings,
|
||||||
|
get_default_settings,
|
||||||
|
reset_context_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Engine
|
||||||
|
from .engine import ContextEngine, create_context_engine
|
||||||
|
|
||||||
|
# Exceptions
|
||||||
|
from .exceptions import (
|
||||||
|
AssemblyTimeoutError,
|
||||||
|
BudgetExceededError,
|
||||||
|
CacheError,
|
||||||
|
CompressionError,
|
||||||
|
ContextError,
|
||||||
|
ContextNotFoundError,
|
||||||
|
FormattingError,
|
||||||
|
InvalidContextError,
|
||||||
|
ScoringError,
|
||||||
|
TokenCountError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prioritization
|
||||||
|
from .prioritization import (
|
||||||
|
ContextRanker,
|
||||||
|
RankingResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scoring
|
||||||
|
from .scoring import (
|
||||||
|
BaseScorer,
|
||||||
|
CompositeScorer,
|
||||||
|
PriorityScorer,
|
||||||
|
RecencyScorer,
|
||||||
|
RelevanceScorer,
|
||||||
|
ScoredContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Types
|
||||||
|
from .types import (
|
||||||
|
AssembledContext,
|
||||||
|
BaseContext,
|
||||||
|
ContextPriority,
|
||||||
|
ContextType,
|
||||||
|
ConversationContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
MemoryContext,
|
||||||
|
MemorySubtype,
|
||||||
|
MessageRole,
|
||||||
|
SystemContext,
|
||||||
|
TaskComplexity,
|
||||||
|
TaskContext,
|
||||||
|
TaskStatus,
|
||||||
|
ToolContext,
|
||||||
|
ToolResultStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AssembledContext",
|
||||||
|
"AssemblyTimeoutError",
|
||||||
|
"BaseContext",
|
||||||
|
"BaseScorer",
|
||||||
|
"BudgetAllocator",
|
||||||
|
"BudgetExceededError",
|
||||||
|
"CacheError",
|
||||||
|
"ClaudeAdapter",
|
||||||
|
"CompositeScorer",
|
||||||
|
"CompressionError",
|
||||||
|
"ContextCache",
|
||||||
|
"ContextCompressor",
|
||||||
|
"ContextEngine",
|
||||||
|
"ContextError",
|
||||||
|
"ContextNotFoundError",
|
||||||
|
"ContextPipeline",
|
||||||
|
"ContextPriority",
|
||||||
|
"ContextRanker",
|
||||||
|
"ContextSettings",
|
||||||
|
"ContextType",
|
||||||
|
"ConversationContext",
|
||||||
|
"DefaultAdapter",
|
||||||
|
"FormattingError",
|
||||||
|
"InvalidContextError",
|
||||||
|
"KnowledgeContext",
|
||||||
|
"MemoryContext",
|
||||||
|
"MemorySubtype",
|
||||||
|
"MessageRole",
|
||||||
|
"ModelAdapter",
|
||||||
|
"OpenAIAdapter",
|
||||||
|
"PipelineMetrics",
|
||||||
|
"PriorityScorer",
|
||||||
|
"RankingResult",
|
||||||
|
"RecencyScorer",
|
||||||
|
"RelevanceScorer",
|
||||||
|
"ScoredContext",
|
||||||
|
"ScoringError",
|
||||||
|
"SystemContext",
|
||||||
|
"TaskComplexity",
|
||||||
|
"TaskContext",
|
||||||
|
"TaskStatus",
|
||||||
|
"TokenBudget",
|
||||||
|
"TokenCalculator",
|
||||||
|
"TokenCountError",
|
||||||
|
"ToolContext",
|
||||||
|
"ToolResultStatus",
|
||||||
|
"TruncationResult",
|
||||||
|
"TruncationStrategy",
|
||||||
|
"create_context_engine",
|
||||||
|
"get_adapter",
|
||||||
|
"get_context_settings",
|
||||||
|
"get_default_settings",
|
||||||
|
"reset_context_settings",
|
||||||
|
]
|
||||||
35
backend/app/services/context/adapters/__init__.py
Normal file
35
backend/app/services/context/adapters/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
Model Adapters Module.
|
||||||
|
|
||||||
|
Provides model-specific context formatting adapters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import DefaultAdapter, ModelAdapter
|
||||||
|
from .claude import ClaudeAdapter
|
||||||
|
from .openai import OpenAIAdapter
|
||||||
|
|
||||||
|
|
||||||
|
def get_adapter(model: str) -> ModelAdapter:
|
||||||
|
"""
|
||||||
|
Get the appropriate adapter for a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Adapter instance for the model
|
||||||
|
"""
|
||||||
|
if ClaudeAdapter.matches_model(model):
|
||||||
|
return ClaudeAdapter()
|
||||||
|
elif OpenAIAdapter.matches_model(model):
|
||||||
|
return OpenAIAdapter()
|
||||||
|
return DefaultAdapter()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ClaudeAdapter",
|
||||||
|
"DefaultAdapter",
|
||||||
|
"ModelAdapter",
|
||||||
|
"OpenAIAdapter",
|
||||||
|
"get_adapter",
|
||||||
|
]
|
||||||
178
backend/app/services/context/adapters/base.py
Normal file
178
backend/app/services/context/adapters/base.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
"""
|
||||||
|
Base Model Adapter.
|
||||||
|
|
||||||
|
Abstract base class for model-specific context formatting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
class ModelAdapter(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base adapter for model-specific context formatting.
|
||||||
|
|
||||||
|
Each adapter knows how to format contexts for optimal
|
||||||
|
understanding by a specific LLM family (Claude, OpenAI, etc.).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Model name patterns this adapter handles
|
||||||
|
MODEL_PATTERNS: ClassVar[list[str]] = []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def matches_model(cls, model: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if this adapter handles the given model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if this adapter handles the model
|
||||||
|
"""
|
||||||
|
model_lower = model.lower()
|
||||||
|
return any(pattern in model_lower for pattern in cls.MODEL_PATTERNS)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts for the target model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts to format
|
||||||
|
**kwargs: Additional formatting options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_type(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
context_type: ContextType,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts of a specific type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts of the same type
|
||||||
|
context_type: The type of contexts
|
||||||
|
**kwargs: Additional formatting options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string for this context type
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_type_order(self) -> list[ContextType]:
|
||||||
|
"""
|
||||||
|
Get the preferred order of context types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of context types in preferred order
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
ContextType.SYSTEM,
|
||||||
|
ContextType.TASK,
|
||||||
|
ContextType.KNOWLEDGE,
|
||||||
|
ContextType.CONVERSATION,
|
||||||
|
ContextType.TOOL,
|
||||||
|
]
|
||||||
|
|
||||||
|
def group_by_type(
|
||||||
|
self, contexts: list[BaseContext]
|
||||||
|
) -> dict[ContextType, list[BaseContext]]:
|
||||||
|
"""
|
||||||
|
Group contexts by their type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts to group
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping context type to list of contexts
|
||||||
|
"""
|
||||||
|
by_type: dict[ContextType, list[BaseContext]] = {}
|
||||||
|
for context in contexts:
|
||||||
|
ct = context.get_type()
|
||||||
|
if ct not in by_type:
|
||||||
|
by_type[ct] = []
|
||||||
|
by_type[ct].append(context)
|
||||||
|
return by_type
|
||||||
|
|
||||||
|
def get_separator(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the separator between context sections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Separator string
|
||||||
|
"""
|
||||||
|
return "\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultAdapter(ModelAdapter):
|
||||||
|
"""
|
||||||
|
Default adapter for unknown models.
|
||||||
|
|
||||||
|
Uses simple plain-text formatting with minimal structure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODEL_PATTERNS: ClassVar[list[str]] = [] # Fallback adapter
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def matches_model(cls, model: str) -> bool:
|
||||||
|
"""Always returns True as fallback."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def format(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Format contexts as plain text."""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
by_type = self.group_by_type(contexts)
|
||||||
|
parts: list[str] = []
|
||||||
|
|
||||||
|
for ct in self.get_type_order():
|
||||||
|
if ct in by_type:
|
||||||
|
formatted = self.format_type(by_type[ct], ct, **kwargs)
|
||||||
|
if formatted:
|
||||||
|
parts.append(formatted)
|
||||||
|
|
||||||
|
return self.get_separator().join(parts)
|
||||||
|
|
||||||
|
def format_type(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
context_type: ContextType,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Format contexts of a type as plain text."""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
content = "\n\n".join(c.content for c in contexts)
|
||||||
|
|
||||||
|
if context_type == ContextType.SYSTEM:
|
||||||
|
return content
|
||||||
|
elif context_type == ContextType.TASK:
|
||||||
|
return f"Task:\n{content}"
|
||||||
|
elif context_type == ContextType.KNOWLEDGE:
|
||||||
|
return f"Reference Information:\n{content}"
|
||||||
|
elif context_type == ContextType.CONVERSATION:
|
||||||
|
return f"Previous Conversation:\n{content}"
|
||||||
|
elif context_type == ContextType.TOOL:
|
||||||
|
return f"Tool Results:\n{content}"
|
||||||
|
|
||||||
|
return content
|
||||||
212
backend/app/services/context/adapters/claude.py
Normal file
212
backend/app/services/context/adapters/claude.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""
|
||||||
|
Claude Model Adapter.
|
||||||
|
|
||||||
|
Provides Claude-specific context formatting using XML tags
|
||||||
|
which Claude models understand natively.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
from .base import ModelAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeAdapter(ModelAdapter):
|
||||||
|
"""
|
||||||
|
Claude-specific context formatting adapter.
|
||||||
|
|
||||||
|
Claude models have native understanding of XML structure,
|
||||||
|
so we use XML tags for clear delineation of context types.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- XML tags for each context type
|
||||||
|
- Document structure for knowledge contexts
|
||||||
|
- Role-based message formatting for conversations
|
||||||
|
- Tool result wrapping with tool names
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODEL_PATTERNS: ClassVar[list[str]] = ["claude", "anthropic"]
|
||||||
|
|
||||||
|
def format(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts for Claude models.
|
||||||
|
|
||||||
|
Uses XML tags for structured content that Claude
|
||||||
|
understands natively.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts to format
|
||||||
|
**kwargs: Additional formatting options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
XML-structured context string
|
||||||
|
"""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
by_type = self.group_by_type(contexts)
|
||||||
|
parts: list[str] = []
|
||||||
|
|
||||||
|
for ct in self.get_type_order():
|
||||||
|
if ct in by_type:
|
||||||
|
formatted = self.format_type(by_type[ct], ct, **kwargs)
|
||||||
|
if formatted:
|
||||||
|
parts.append(formatted)
|
||||||
|
|
||||||
|
return self.get_separator().join(parts)
|
||||||
|
|
||||||
|
def format_type(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
context_type: ContextType,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts of a specific type for Claude.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts of the same type
|
||||||
|
context_type: The type of contexts
|
||||||
|
**kwargs: Additional formatting options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
XML-formatted string for this context type
|
||||||
|
"""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if context_type == ContextType.SYSTEM:
|
||||||
|
return self._format_system(contexts)
|
||||||
|
elif context_type == ContextType.TASK:
|
||||||
|
return self._format_task(contexts)
|
||||||
|
elif context_type == ContextType.KNOWLEDGE:
|
||||||
|
return self._format_knowledge(contexts)
|
||||||
|
elif context_type == ContextType.CONVERSATION:
|
||||||
|
return self._format_conversation(contexts)
|
||||||
|
elif context_type == ContextType.TOOL:
|
||||||
|
return self._format_tool(contexts)
|
||||||
|
|
||||||
|
# Fallback for any unhandled context types - still escape content
|
||||||
|
# to prevent XML injection if new types are added without updating adapter
|
||||||
|
return "\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||||
|
|
||||||
|
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""Format system contexts."""
|
||||||
|
# System prompts are typically admin-controlled, but escape for safety
|
||||||
|
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||||
|
return f"<system_instructions>\n{content}\n</system_instructions>"
|
||||||
|
|
||||||
|
def _format_task(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""Format task contexts."""
|
||||||
|
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||||
|
return f"<current_task>\n{content}\n</current_task>"
|
||||||
|
|
||||||
|
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""
|
||||||
|
Format knowledge contexts as structured documents.
|
||||||
|
|
||||||
|
Each knowledge context becomes a document with source attribution.
|
||||||
|
All content is XML-escaped to prevent injection attacks.
|
||||||
|
"""
|
||||||
|
parts = ["<reference_documents>"]
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
source = self._escape_xml(ctx.source)
|
||||||
|
# Escape content to prevent XML injection
|
||||||
|
content = self._escape_xml_content(ctx.content)
|
||||||
|
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||||
|
|
||||||
|
if score:
|
||||||
|
# Escape score to prevent XML injection via metadata
|
||||||
|
escaped_score = self._escape_xml(str(score))
|
||||||
|
parts.append(
|
||||||
|
f'<document source="{source}" relevance="{escaped_score}">'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parts.append(f'<document source="{source}">')
|
||||||
|
|
||||||
|
parts.append(content)
|
||||||
|
parts.append("</document>")
|
||||||
|
|
||||||
|
parts.append("</reference_documents>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def _format_conversation(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""
|
||||||
|
Format conversation contexts as message history.
|
||||||
|
|
||||||
|
Uses role-based message tags for clear turn delineation.
|
||||||
|
All content is XML-escaped to prevent prompt injection.
|
||||||
|
"""
|
||||||
|
parts = ["<conversation_history>"]
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
role = self._escape_xml(ctx.metadata.get("role", "user"))
|
||||||
|
# Escape content to prevent prompt injection via fake XML tags
|
||||||
|
content = self._escape_xml_content(ctx.content)
|
||||||
|
parts.append(f'<message role="{role}">')
|
||||||
|
parts.append(content)
|
||||||
|
parts.append("</message>")
|
||||||
|
|
||||||
|
parts.append("</conversation_history>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def _format_tool(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""
|
||||||
|
Format tool contexts as tool results.
|
||||||
|
|
||||||
|
Each tool result is wrapped with the tool name.
|
||||||
|
All content is XML-escaped to prevent injection.
|
||||||
|
"""
|
||||||
|
parts = ["<tool_results>"]
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
tool_name = self._escape_xml(ctx.metadata.get("tool_name", "unknown"))
|
||||||
|
status = ctx.metadata.get("status", "")
|
||||||
|
|
||||||
|
if status:
|
||||||
|
parts.append(
|
||||||
|
f'<tool_result name="{tool_name}" status="{self._escape_xml(status)}">'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parts.append(f'<tool_result name="{tool_name}">')
|
||||||
|
|
||||||
|
# Escape content to prevent injection
|
||||||
|
parts.append(self._escape_xml_content(ctx.content))
|
||||||
|
parts.append("</tool_result>")
|
||||||
|
|
||||||
|
parts.append("</tool_results>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _escape_xml(text: str) -> str:
|
||||||
|
"""Escape XML special characters in attribute values."""
|
||||||
|
return (
|
||||||
|
text.replace("&", "&")
|
||||||
|
.replace("<", "<")
|
||||||
|
.replace(">", ">")
|
||||||
|
.replace('"', """)
|
||||||
|
.replace("'", "'")
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _escape_xml_content(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Escape XML special characters in element content.
|
||||||
|
|
||||||
|
This prevents XML injection attacks where malicious content
|
||||||
|
could break out of XML tags or inject fake tags for prompt injection.
|
||||||
|
|
||||||
|
Only escapes &, <, > since quotes don't need escaping in content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Content text to escape
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
XML-safe content string
|
||||||
|
"""
|
||||||
|
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||||
160
backend/app/services/context/adapters/openai.py
Normal file
160
backend/app/services/context/adapters/openai.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
"""
|
||||||
|
OpenAI Model Adapter.
|
||||||
|
|
||||||
|
Provides OpenAI-specific context formatting using markdown
|
||||||
|
which GPT models understand well.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
from .base import ModelAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIAdapter(ModelAdapter):
|
||||||
|
"""
|
||||||
|
OpenAI-specific context formatting adapter.
|
||||||
|
|
||||||
|
GPT models work well with markdown formatting,
|
||||||
|
so we use headers and structured markdown for clarity.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Markdown headers for each context type
|
||||||
|
- Bulleted lists for document sources
|
||||||
|
- Bold role labels for conversations
|
||||||
|
- Code blocks for tool outputs
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODEL_PATTERNS: ClassVar[list[str]] = ["gpt", "openai", "o1", "o3"]
|
||||||
|
|
||||||
|
def format(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts for OpenAI models.
|
||||||
|
|
||||||
|
Uses markdown formatting for structured content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts to format
|
||||||
|
**kwargs: Additional formatting options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Markdown-structured context string
|
||||||
|
"""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
by_type = self.group_by_type(contexts)
|
||||||
|
parts: list[str] = []
|
||||||
|
|
||||||
|
for ct in self.get_type_order():
|
||||||
|
if ct in by_type:
|
||||||
|
formatted = self.format_type(by_type[ct], ct, **kwargs)
|
||||||
|
if formatted:
|
||||||
|
parts.append(formatted)
|
||||||
|
|
||||||
|
return self.get_separator().join(parts)
|
||||||
|
|
||||||
|
def format_type(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
context_type: ContextType,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts of a specific type for OpenAI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts of the same type
|
||||||
|
context_type: The type of contexts
|
||||||
|
**kwargs: Additional formatting options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Markdown-formatted string for this context type
|
||||||
|
"""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if context_type == ContextType.SYSTEM:
|
||||||
|
return self._format_system(contexts)
|
||||||
|
elif context_type == ContextType.TASK:
|
||||||
|
return self._format_task(contexts)
|
||||||
|
elif context_type == ContextType.KNOWLEDGE:
|
||||||
|
return self._format_knowledge(contexts)
|
||||||
|
elif context_type == ContextType.CONVERSATION:
|
||||||
|
return self._format_conversation(contexts)
|
||||||
|
elif context_type == ContextType.TOOL:
|
||||||
|
return self._format_tool(contexts)
|
||||||
|
|
||||||
|
return "\n".join(c.content for c in contexts)
|
||||||
|
|
||||||
|
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""Format system contexts."""
|
||||||
|
content = "\n\n".join(c.content for c in contexts)
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _format_task(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""Format task contexts."""
|
||||||
|
content = "\n\n".join(c.content for c in contexts)
|
||||||
|
return f"## Current Task\n\n{content}"
|
||||||
|
|
||||||
|
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""
|
||||||
|
Format knowledge contexts as structured documents.
|
||||||
|
|
||||||
|
Each knowledge context becomes a section with source attribution.
|
||||||
|
"""
|
||||||
|
parts = ["## Reference Documents\n"]
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
source = ctx.source
|
||||||
|
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||||
|
|
||||||
|
if score:
|
||||||
|
parts.append(f"### Source: {source} (relevance: {score})\n")
|
||||||
|
else:
|
||||||
|
parts.append(f"### Source: {source}\n")
|
||||||
|
|
||||||
|
parts.append(ctx.content)
|
||||||
|
parts.append("")
|
||||||
|
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def _format_conversation(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""
|
||||||
|
Format conversation contexts as message history.
|
||||||
|
|
||||||
|
Uses bold role labels for clear turn delineation.
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
role = ctx.metadata.get("role", "user").upper()
|
||||||
|
parts.append(f"**{role}**: {ctx.content}")
|
||||||
|
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
def _format_tool(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""
|
||||||
|
Format tool contexts as tool results.
|
||||||
|
|
||||||
|
Each tool result is in a code block with the tool name.
|
||||||
|
"""
|
||||||
|
parts = ["## Recent Tool Results\n"]
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||||
|
status = ctx.metadata.get("status", "")
|
||||||
|
|
||||||
|
if status:
|
||||||
|
parts.append(f"### Tool: {tool_name} ({status})\n")
|
||||||
|
else:
|
||||||
|
parts.append(f"### Tool: {tool_name}\n")
|
||||||
|
|
||||||
|
parts.append(f"```\n{ctx.content}\n```")
|
||||||
|
parts.append("")
|
||||||
|
|
||||||
|
return "\n".join(parts)
|
||||||
12
backend/app/services/context/assembly/__init__.py
Normal file
12
backend/app/services/context/assembly/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Context Assembly Module.
|
||||||
|
|
||||||
|
Provides the assembly pipeline and formatting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .pipeline import ContextPipeline, PipelineMetrics
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContextPipeline",
|
||||||
|
"PipelineMetrics",
|
||||||
|
]
|
||||||
362
backend/app/services/context/assembly/pipeline.py
Normal file
362
backend/app/services/context/assembly/pipeline.py
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
"""
|
||||||
|
Context Assembly Pipeline.
|
||||||
|
|
||||||
|
Orchestrates the full context assembly workflow:
|
||||||
|
Gather → Count → Score → Rank → Compress → Format
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..adapters import get_adapter
|
||||||
|
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||||
|
from ..compression.truncation import ContextCompressor
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..exceptions import AssemblyTimeoutError
|
||||||
|
from ..prioritization import ContextRanker
|
||||||
|
from ..scoring import CompositeScorer
|
||||||
|
from ..types import AssembledContext, BaseContext, ContextType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.mcp.client_manager import MCPClientManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineMetrics:
|
||||||
|
"""Metrics from pipeline execution."""
|
||||||
|
|
||||||
|
start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
end_time: datetime | None = None
|
||||||
|
total_contexts: int = 0
|
||||||
|
selected_contexts: int = 0
|
||||||
|
excluded_contexts: int = 0
|
||||||
|
compressed_contexts: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
assembly_time_ms: float = 0.0
|
||||||
|
scoring_time_ms: float = 0.0
|
||||||
|
compression_time_ms: float = 0.0
|
||||||
|
formatting_time_ms: float = 0.0
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"start_time": self.start_time.isoformat(),
|
||||||
|
"end_time": self.end_time.isoformat() if self.end_time else None,
|
||||||
|
"total_contexts": self.total_contexts,
|
||||||
|
"selected_contexts": self.selected_contexts,
|
||||||
|
"excluded_contexts": self.excluded_contexts,
|
||||||
|
"compressed_contexts": self.compressed_contexts,
|
||||||
|
"total_tokens": self.total_tokens,
|
||||||
|
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||||
|
"scoring_time_ms": round(self.scoring_time_ms, 2),
|
||||||
|
"compression_time_ms": round(self.compression_time_ms, 2),
|
||||||
|
"formatting_time_ms": round(self.formatting_time_ms, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ContextPipeline:
|
||||||
|
"""
|
||||||
|
Context assembly pipeline.
|
||||||
|
|
||||||
|
Orchestrates the full workflow of context assembly:
|
||||||
|
1. Validate and count tokens for all contexts
|
||||||
|
2. Score contexts based on relevance, recency, and priority
|
||||||
|
3. Rank and select contexts within budget
|
||||||
|
4. Compress if needed to fit remaining budget
|
||||||
|
5. Format for the target model
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
calculator: TokenCalculator | None = None,
|
||||||
|
scorer: CompositeScorer | None = None,
|
||||||
|
ranker: ContextRanker | None = None,
|
||||||
|
compressor: ContextCompressor | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the context pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager for LLM Gateway integration
|
||||||
|
settings: Context settings
|
||||||
|
calculator: Token calculator
|
||||||
|
scorer: Context scorer
|
||||||
|
ranker: Context ranker
|
||||||
|
compressor: Context compressor
|
||||||
|
"""
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager)
|
||||||
|
self._scorer = scorer or CompositeScorer(
|
||||||
|
mcp_manager=mcp_manager, settings=self._settings
|
||||||
|
)
|
||||||
|
self._ranker = ranker or ContextRanker(
|
||||||
|
scorer=self._scorer, calculator=self._calculator
|
||||||
|
)
|
||||||
|
self._compressor = compressor or ContextCompressor(calculator=self._calculator)
|
||||||
|
self._allocator = BudgetAllocator(self._settings)
|
||||||
|
|
||||||
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
|
"""Set MCP manager for all components."""
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
self._calculator.set_mcp_manager(mcp_manager)
|
||||||
|
self._scorer.set_mcp_manager(mcp_manager)
|
||||||
|
|
||||||
|
async def assemble(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
model: str,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
custom_budget: TokenBudget | None = None,
|
||||||
|
compress: bool = True,
|
||||||
|
format_output: bool = True,
|
||||||
|
timeout_ms: int | None = None,
|
||||||
|
) -> AssembledContext:
|
||||||
|
"""
|
||||||
|
Assemble context for an LLM request.
|
||||||
|
|
||||||
|
This is the main entry point for context assembly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts to assemble
|
||||||
|
query: Query to optimize for
|
||||||
|
model: Target model name
|
||||||
|
max_tokens: Maximum total tokens (uses model default if None)
|
||||||
|
custom_budget: Optional pre-configured budget
|
||||||
|
compress: Whether to compress oversized contexts
|
||||||
|
format_output: Whether to format the final output
|
||||||
|
timeout_ms: Maximum assembly time in milliseconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AssembledContext with optimized content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssemblyTimeoutError: If assembly exceeds timeout
|
||||||
|
"""
|
||||||
|
timeout = timeout_ms or self._settings.max_assembly_time_ms
|
||||||
|
start = time.perf_counter()
|
||||||
|
metrics = PipelineMetrics(total_contexts=len(contexts))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create or use budget
|
||||||
|
if custom_budget:
|
||||||
|
budget = custom_budget
|
||||||
|
elif max_tokens:
|
||||||
|
budget = self._allocator.create_budget(max_tokens)
|
||||||
|
else:
|
||||||
|
budget = self._allocator.create_budget_for_model(model)
|
||||||
|
|
||||||
|
# 1. Count tokens for all contexts (with timeout enforcement)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._ensure_token_counts(contexts, model),
|
||||||
|
timeout=self._remaining_timeout(start, timeout),
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
raise AssemblyTimeoutError(
|
||||||
|
message="Context assembly timed out during token counting",
|
||||||
|
elapsed_ms=elapsed_ms,
|
||||||
|
timeout_ms=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check timeout (handles edge case where operation finished just at limit)
|
||||||
|
self._check_timeout(start, timeout, "token counting")
|
||||||
|
|
||||||
|
# 2. Score and rank contexts (with timeout enforcement)
|
||||||
|
scoring_start = time.perf_counter()
|
||||||
|
try:
|
||||||
|
ranking_result = await asyncio.wait_for(
|
||||||
|
self._ranker.rank(
|
||||||
|
contexts=contexts,
|
||||||
|
query=query,
|
||||||
|
budget=budget,
|
||||||
|
model=model,
|
||||||
|
),
|
||||||
|
timeout=self._remaining_timeout(start, timeout),
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
raise AssemblyTimeoutError(
|
||||||
|
message="Context assembly timed out during scoring/ranking",
|
||||||
|
elapsed_ms=elapsed_ms,
|
||||||
|
timeout_ms=timeout,
|
||||||
|
)
|
||||||
|
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
|
||||||
|
|
||||||
|
selected_contexts = ranking_result.selected_contexts
|
||||||
|
metrics.selected_contexts = len(selected_contexts)
|
||||||
|
metrics.excluded_contexts = len(ranking_result.excluded)
|
||||||
|
|
||||||
|
# Check timeout
|
||||||
|
self._check_timeout(start, timeout, "scoring")
|
||||||
|
|
||||||
|
# 3. Compress if needed and enabled (with timeout enforcement)
|
||||||
|
if compress and self._needs_compression(selected_contexts, budget):
|
||||||
|
compression_start = time.perf_counter()
|
||||||
|
try:
|
||||||
|
selected_contexts = await asyncio.wait_for(
|
||||||
|
self._compressor.compress_contexts(
|
||||||
|
selected_contexts, budget, model
|
||||||
|
),
|
||||||
|
timeout=self._remaining_timeout(start, timeout),
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
raise AssemblyTimeoutError(
|
||||||
|
message="Context assembly timed out during compression",
|
||||||
|
elapsed_ms=elapsed_ms,
|
||||||
|
timeout_ms=timeout,
|
||||||
|
)
|
||||||
|
metrics.compression_time_ms = (
|
||||||
|
time.perf_counter() - compression_start
|
||||||
|
) * 1000
|
||||||
|
metrics.compressed_contexts = sum(
|
||||||
|
1 for c in selected_contexts if c.metadata.get("truncated", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check timeout
|
||||||
|
self._check_timeout(start, timeout, "compression")
|
||||||
|
|
||||||
|
# 4. Format output
|
||||||
|
formatting_start = time.perf_counter()
|
||||||
|
if format_output:
|
||||||
|
formatted_content = self._format_contexts(selected_contexts, model)
|
||||||
|
else:
|
||||||
|
formatted_content = "\n\n".join(c.content for c in selected_contexts)
|
||||||
|
metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000
|
||||||
|
|
||||||
|
# Calculate final metrics
|
||||||
|
total_tokens = sum(c.token_count or 0 for c in selected_contexts)
|
||||||
|
metrics.total_tokens = total_tokens
|
||||||
|
metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
|
||||||
|
metrics.end_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
return AssembledContext(
|
||||||
|
content=formatted_content,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
context_count=len(selected_contexts),
|
||||||
|
assembly_time_ms=metrics.assembly_time_ms,
|
||||||
|
model=model,
|
||||||
|
contexts=selected_contexts,
|
||||||
|
excluded_count=metrics.excluded_contexts,
|
||||||
|
metadata={
|
||||||
|
"metrics": metrics.to_dict(),
|
||||||
|
"query": query,
|
||||||
|
"budget": budget.to_dict(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except AssemblyTimeoutError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Context assembly failed: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _ensure_token_counts(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
model: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Ensure all contexts have token counts."""
|
||||||
|
tasks = []
|
||||||
|
for context in contexts:
|
||||||
|
if context.token_count is None:
|
||||||
|
tasks.append(self._count_and_set(context, model))
|
||||||
|
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
async def _count_and_set(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Count tokens and set on context."""
|
||||||
|
count = await self._calculator.count_tokens(context.content, model)
|
||||||
|
context.token_count = count
|
||||||
|
|
||||||
|
def _needs_compression(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
budget: TokenBudget,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if any contexts exceed their type budget."""
|
||||||
|
# Group by type and check totals
|
||||||
|
by_type: dict[ContextType, int] = {}
|
||||||
|
for context in contexts:
|
||||||
|
ct = context.get_type()
|
||||||
|
by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)
|
||||||
|
|
||||||
|
for ct, total in by_type.items():
|
||||||
|
if total > budget.get_allocation(ct):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Also check if utilization exceeds threshold
|
||||||
|
return budget.utilization() > self._settings.compression_threshold
|
||||||
|
|
||||||
|
def _format_contexts(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
model: str,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts for the target model.
|
||||||
|
|
||||||
|
Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
|
||||||
|
to format contexts optimally for each model family.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to format
|
||||||
|
model: Target model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string
|
||||||
|
"""
|
||||||
|
adapter = get_adapter(model)
|
||||||
|
return adapter.format(contexts)
|
||||||
|
|
||||||
|
def _check_timeout(
|
||||||
|
self,
|
||||||
|
start: float,
|
||||||
|
timeout_ms: int,
|
||||||
|
phase: str,
|
||||||
|
) -> None:
|
||||||
|
"""Check if timeout exceeded and raise if so."""
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
if elapsed_ms >= timeout_ms:
|
||||||
|
raise AssemblyTimeoutError(
|
||||||
|
message=f"Context assembly timed out during {phase}",
|
||||||
|
elapsed_ms=elapsed_ms,
|
||||||
|
timeout_ms=timeout_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
|
||||||
|
"""
|
||||||
|
Calculate remaining timeout in seconds for asyncio.wait_for.
|
||||||
|
|
||||||
|
Returns at least a small positive value to avoid immediate timeout
|
||||||
|
edge cases with wait_for.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start: Start time from time.perf_counter()
|
||||||
|
timeout_ms: Total timeout in milliseconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Remaining timeout in seconds (minimum 0.001)
|
||||||
|
"""
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
remaining_ms = timeout_ms - elapsed_ms
|
||||||
|
# Return at least 1ms to avoid zero/negative timeout edge cases
|
||||||
|
return max(remaining_ms / 1000.0, 0.001)
|
||||||
14
backend/app/services/context/budget/__init__.py
Normal file
14
backend/app/services/context/budget/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Token Budget Management Module.
|
||||||
|
|
||||||
|
Provides token counting and budget allocation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .allocator import BudgetAllocator, TokenBudget
|
||||||
|
from .calculator import TokenCalculator
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BudgetAllocator",
|
||||||
|
"TokenBudget",
|
||||||
|
"TokenCalculator",
|
||||||
|
]
|
||||||
444
backend/app/services/context/budget/allocator.py
Normal file
444
backend/app/services/context/budget/allocator.py
Normal file
@@ -0,0 +1,444 @@
|
|||||||
|
"""
|
||||||
|
Token Budget Allocator for Context Management.
|
||||||
|
|
||||||
|
Manages token budget allocation across context types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..exceptions import BudgetExceededError
|
||||||
|
from ..types import ContextType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenBudget:
|
||||||
|
"""
|
||||||
|
Token budget allocation and tracking.
|
||||||
|
|
||||||
|
Tracks allocated tokens per context type and
|
||||||
|
monitors usage to prevent overflows.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Total budget
|
||||||
|
total: int
|
||||||
|
|
||||||
|
# Allocated per type
|
||||||
|
system: int = 0
|
||||||
|
task: int = 0
|
||||||
|
knowledge: int = 0
|
||||||
|
conversation: int = 0
|
||||||
|
tools: int = 0
|
||||||
|
memory: int = 0 # Agent memory (working, episodic, semantic, procedural)
|
||||||
|
response_reserve: int = 0
|
||||||
|
buffer: int = 0
|
||||||
|
|
||||||
|
# Usage tracking
|
||||||
|
used: dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Initialize usage tracking."""
|
||||||
|
if not self.used:
|
||||||
|
self.used = {ct.value: 0 for ct in ContextType}
|
||||||
|
|
||||||
|
def get_allocation(self, context_type: ContextType | str) -> int:
|
||||||
|
"""
|
||||||
|
Get allocated tokens for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to get allocation for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Allocated token count
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
|
||||||
|
allocation_map = {
|
||||||
|
"system": self.system,
|
||||||
|
"task": self.task,
|
||||||
|
"knowledge": self.knowledge,
|
||||||
|
"conversation": self.conversation,
|
||||||
|
"tool": self.tools,
|
||||||
|
"memory": self.memory,
|
||||||
|
}
|
||||||
|
return allocation_map.get(context_type, 0)
|
||||||
|
|
||||||
|
def get_used(self, context_type: ContextType | str) -> int:
|
||||||
|
"""
|
||||||
|
Get used tokens for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Used token count
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
return self.used.get(context_type, 0)
|
||||||
|
|
||||||
|
def remaining(self, context_type: ContextType | str) -> int:
|
||||||
|
"""
|
||||||
|
Get remaining tokens for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Remaining token count
|
||||||
|
"""
|
||||||
|
allocated = self.get_allocation(context_type)
|
||||||
|
used = self.get_used(context_type)
|
||||||
|
return max(0, allocated - used)
|
||||||
|
|
||||||
|
def total_remaining(self) -> int:
|
||||||
|
"""
|
||||||
|
Get total remaining tokens across all types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total remaining tokens
|
||||||
|
"""
|
||||||
|
total_used = sum(self.used.values())
|
||||||
|
usable = self.total - self.response_reserve - self.buffer
|
||||||
|
return max(0, usable - total_used)
|
||||||
|
|
||||||
|
def total_used(self) -> int:
|
||||||
|
"""
|
||||||
|
Get total used tokens.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total used tokens
|
||||||
|
"""
|
||||||
|
return sum(self.used.values())
|
||||||
|
|
||||||
|
def can_fit(self, context_type: ContextType | str, tokens: int) -> bool:
|
||||||
|
"""
|
||||||
|
Check if tokens fit within budget for a type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to check
|
||||||
|
tokens: Number of tokens to fit
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if tokens fit within remaining budget
|
||||||
|
"""
|
||||||
|
return tokens <= self.remaining(context_type)
|
||||||
|
|
||||||
|
def allocate(
|
||||||
|
self,
|
||||||
|
context_type: ContextType | str,
|
||||||
|
tokens: int,
|
||||||
|
force: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Allocate (use) tokens from a context type's budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to allocate from
|
||||||
|
tokens: Number of tokens to allocate
|
||||||
|
force: If True, allow exceeding budget
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if allocation succeeded
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BudgetExceededError: If tokens exceed budget and force=False
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
|
||||||
|
if not force and not self.can_fit(context_type, tokens):
|
||||||
|
raise BudgetExceededError(
|
||||||
|
message=f"Token budget exceeded for {context_type}",
|
||||||
|
allocated=self.get_allocation(context_type),
|
||||||
|
requested=self.get_used(context_type) + tokens,
|
||||||
|
context_type=context_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.used[context_type] = self.used.get(context_type, 0) + tokens
|
||||||
|
return True
|
||||||
|
|
||||||
|
def deallocate(
|
||||||
|
self,
|
||||||
|
context_type: ContextType | str,
|
||||||
|
tokens: int,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Deallocate (return) tokens to a context type's budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to return to
|
||||||
|
tokens: Number of tokens to return
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
|
||||||
|
current = self.used.get(context_type, 0)
|
||||||
|
self.used[context_type] = max(0, current - tokens)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset all usage tracking."""
|
||||||
|
self.used = {ct.value: 0 for ct in ContextType}
|
||||||
|
|
||||||
|
def utilization(self, context_type: ContextType | str | None = None) -> float:
|
||||||
|
"""
|
||||||
|
Get budget utilization percentage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Specific type or None for total
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Utilization as a fraction (0.0 to 1.0+)
|
||||||
|
"""
|
||||||
|
if context_type is None:
|
||||||
|
usable = self.total - self.response_reserve - self.buffer
|
||||||
|
if usable <= 0:
|
||||||
|
return 0.0
|
||||||
|
return self.total_used() / usable
|
||||||
|
|
||||||
|
allocated = self.get_allocation(context_type)
|
||||||
|
if allocated <= 0:
|
||||||
|
return 0.0
|
||||||
|
return self.get_used(context_type) / allocated
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert budget to dictionary."""
|
||||||
|
return {
|
||||||
|
"total": self.total,
|
||||||
|
"allocations": {
|
||||||
|
"system": self.system,
|
||||||
|
"task": self.task,
|
||||||
|
"knowledge": self.knowledge,
|
||||||
|
"conversation": self.conversation,
|
||||||
|
"tools": self.tools,
|
||||||
|
"memory": self.memory,
|
||||||
|
"response_reserve": self.response_reserve,
|
||||||
|
"buffer": self.buffer,
|
||||||
|
},
|
||||||
|
"used": dict(self.used),
|
||||||
|
"remaining": {ct.value: self.remaining(ct) for ct in ContextType},
|
||||||
|
"total_used": self.total_used(),
|
||||||
|
"total_remaining": self.total_remaining(),
|
||||||
|
"utilization": round(self.utilization(), 3),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BudgetAllocator:
|
||||||
|
"""
|
||||||
|
Budget allocator for context management.
|
||||||
|
|
||||||
|
Creates token budgets based on configuration and
|
||||||
|
model context window sizes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, settings: ContextSettings | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Initialize budget allocator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
settings: Context settings (uses default if None)
|
||||||
|
"""
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
|
||||||
|
def create_budget(
|
||||||
|
self,
|
||||||
|
total_tokens: int,
|
||||||
|
custom_allocations: dict[str, float] | None = None,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Create a token budget with allocations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total_tokens: Total available tokens
|
||||||
|
custom_allocations: Optional custom allocation percentages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenBudget with allocations set
|
||||||
|
"""
|
||||||
|
# Use custom or default allocations
|
||||||
|
if custom_allocations:
|
||||||
|
alloc = custom_allocations
|
||||||
|
else:
|
||||||
|
alloc = self._settings.get_budget_allocation()
|
||||||
|
|
||||||
|
return TokenBudget(
|
||||||
|
total=total_tokens,
|
||||||
|
system=int(total_tokens * alloc.get("system", 0.05)),
|
||||||
|
task=int(total_tokens * alloc.get("task", 0.10)),
|
||||||
|
knowledge=int(total_tokens * alloc.get("knowledge", 0.30)),
|
||||||
|
conversation=int(total_tokens * alloc.get("conversation", 0.15)),
|
||||||
|
tools=int(total_tokens * alloc.get("tools", 0.05)),
|
||||||
|
memory=int(total_tokens * alloc.get("memory", 0.15)),
|
||||||
|
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
|
||||||
|
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def adjust_budget(
|
||||||
|
self,
|
||||||
|
budget: TokenBudget,
|
||||||
|
context_type: ContextType | str,
|
||||||
|
adjustment: int,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Adjust a specific allocation in a budget.
|
||||||
|
|
||||||
|
Takes tokens from buffer and adds to specified type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
budget: Budget to adjust
|
||||||
|
context_type: Type to adjust
|
||||||
|
adjustment: Positive to increase, negative to decrease
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Adjusted budget
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
|
||||||
|
# Calculate adjustment (limited by buffer for increases, by current allocation for decreases)
|
||||||
|
if adjustment > 0:
|
||||||
|
# Taking from buffer - limited by available buffer
|
||||||
|
actual_adjustment = min(adjustment, budget.buffer)
|
||||||
|
budget.buffer -= actual_adjustment
|
||||||
|
else:
|
||||||
|
# Returning to buffer - limited by current allocation of target type
|
||||||
|
current_allocation = budget.get_allocation(context_type)
|
||||||
|
# Can't return more than current allocation
|
||||||
|
actual_adjustment = max(adjustment, -current_allocation)
|
||||||
|
# Add returned tokens back to buffer (adjustment is negative, so subtract)
|
||||||
|
budget.buffer -= actual_adjustment
|
||||||
|
|
||||||
|
# Apply to target type
|
||||||
|
if context_type == "system":
|
||||||
|
budget.system = max(0, budget.system + actual_adjustment)
|
||||||
|
elif context_type == "task":
|
||||||
|
budget.task = max(0, budget.task + actual_adjustment)
|
||||||
|
elif context_type == "knowledge":
|
||||||
|
budget.knowledge = max(0, budget.knowledge + actual_adjustment)
|
||||||
|
elif context_type == "conversation":
|
||||||
|
budget.conversation = max(0, budget.conversation + actual_adjustment)
|
||||||
|
elif context_type == "tool":
|
||||||
|
budget.tools = max(0, budget.tools + actual_adjustment)
|
||||||
|
elif context_type == "memory":
|
||||||
|
budget.memory = max(0, budget.memory + actual_adjustment)
|
||||||
|
|
||||||
|
return budget
|
||||||
|
|
||||||
|
def rebalance_budget(
|
||||||
|
self,
|
||||||
|
budget: TokenBudget,
|
||||||
|
prioritize: list[ContextType] | None = None,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Rebalance budget based on actual usage.
|
||||||
|
|
||||||
|
Moves unused allocations to prioritized types.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
budget: Budget to rebalance
|
||||||
|
prioritize: Types to prioritize (in order)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rebalanced budget
|
||||||
|
"""
|
||||||
|
if prioritize is None:
|
||||||
|
prioritize = [
|
||||||
|
ContextType.KNOWLEDGE,
|
||||||
|
ContextType.MEMORY,
|
||||||
|
ContextType.TASK,
|
||||||
|
ContextType.SYSTEM,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Calculate unused tokens per type
|
||||||
|
unused: dict[str, int] = {}
|
||||||
|
for ct in ContextType:
|
||||||
|
remaining = budget.remaining(ct)
|
||||||
|
if remaining > 0:
|
||||||
|
unused[ct.value] = remaining
|
||||||
|
|
||||||
|
# Calculate total reclaimable (excluding prioritized types)
|
||||||
|
prioritize_values = {ct.value for ct in prioritize}
|
||||||
|
reclaimable = sum(
|
||||||
|
tokens for ct, tokens in unused.items() if ct not in prioritize_values
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redistribute to prioritized types that are near capacity
|
||||||
|
for ct in prioritize:
|
||||||
|
utilization = budget.utilization(ct)
|
||||||
|
|
||||||
|
if utilization > 0.8: # Near capacity
|
||||||
|
# Give more tokens from reclaimable pool
|
||||||
|
bonus = min(reclaimable, budget.get_allocation(ct) // 2)
|
||||||
|
self.adjust_budget(budget, ct, bonus)
|
||||||
|
reclaimable -= bonus
|
||||||
|
|
||||||
|
if reclaimable <= 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
return budget
|
||||||
|
|
||||||
|
def get_model_context_size(self, model: str) -> int:
|
||||||
|
"""
|
||||||
|
Get context window size for a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Context window size in tokens
|
||||||
|
"""
|
||||||
|
# Common model context sizes
|
||||||
|
context_sizes = {
|
||||||
|
"claude-3-opus": 200000,
|
||||||
|
"claude-3-sonnet": 200000,
|
||||||
|
"claude-3-haiku": 200000,
|
||||||
|
"claude-3-5-sonnet": 200000,
|
||||||
|
"claude-3-5-haiku": 200000,
|
||||||
|
"claude-opus-4": 200000,
|
||||||
|
"gpt-4-turbo": 128000,
|
||||||
|
"gpt-4": 8192,
|
||||||
|
"gpt-4-32k": 32768,
|
||||||
|
"gpt-4o": 128000,
|
||||||
|
"gpt-4o-mini": 128000,
|
||||||
|
"gpt-3.5-turbo": 16385,
|
||||||
|
"gemini-1.5-pro": 2000000,
|
||||||
|
"gemini-1.5-flash": 1000000,
|
||||||
|
"gemini-2.0-flash": 1000000,
|
||||||
|
"qwen-plus": 32000,
|
||||||
|
"qwen-turbo": 8000,
|
||||||
|
"deepseek-chat": 64000,
|
||||||
|
"deepseek-reasoner": 64000,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check exact match first
|
||||||
|
model_lower = model.lower()
|
||||||
|
if model_lower in context_sizes:
|
||||||
|
return context_sizes[model_lower]
|
||||||
|
|
||||||
|
# Check prefix match
|
||||||
|
for model_name, size in context_sizes.items():
|
||||||
|
if model_lower.startswith(model_name):
|
||||||
|
return size
|
||||||
|
|
||||||
|
# Default fallback
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
def create_budget_for_model(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
custom_allocations: dict[str, float] | None = None,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Create a budget based on model's context window.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
custom_allocations: Optional custom allocation percentages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenBudget sized for the model
|
||||||
|
"""
|
||||||
|
context_size = self.get_model_context_size(model)
|
||||||
|
return self.create_budget(context_size, custom_allocations)
|
||||||
285
backend/app/services/context/budget/calculator.py
Normal file
285
backend/app/services/context/budget/calculator.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
"""
|
||||||
|
Token Calculator for Context Management.
|
||||||
|
|
||||||
|
Provides token counting with caching and fallback estimation.
|
||||||
|
Integrates with LLM Gateway for accurate counts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.mcp.client_manager import MCPClientManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCounterProtocol(Protocol):
|
||||||
|
"""Protocol for token counting implementations."""
|
||||||
|
|
||||||
|
async def count_tokens(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Count tokens in text."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCalculator:
|
||||||
|
"""
|
||||||
|
Token calculator with LLM Gateway integration.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- In-memory caching for repeated text
|
||||||
|
- Fallback to character-based estimation
|
||||||
|
- Model-specific counting when possible
|
||||||
|
|
||||||
|
The calculator uses the LLM Gateway's count_tokens tool
|
||||||
|
for accurate counting, with a local cache to avoid
|
||||||
|
repeated calls for the same content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default characters per token ratio for estimation
|
||||||
|
DEFAULT_CHARS_PER_TOKEN: ClassVar[float] = 4.0
|
||||||
|
|
||||||
|
# Model-specific ratios (more accurate estimation)
|
||||||
|
MODEL_CHAR_RATIOS: ClassVar[dict[str, float]] = {
|
||||||
|
"claude": 3.5,
|
||||||
|
"gpt-4": 4.0,
|
||||||
|
"gpt-3.5": 4.0,
|
||||||
|
"gemini": 4.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
project_id: str = "system",
|
||||||
|
agent_id: str = "context-engine",
|
||||||
|
cache_enabled: bool = True,
|
||||||
|
cache_max_size: int = 10000,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize token calculator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager for LLM Gateway calls
|
||||||
|
project_id: Project ID for LLM Gateway calls
|
||||||
|
agent_id: Agent ID for LLM Gateway calls
|
||||||
|
cache_enabled: Whether to enable in-memory caching
|
||||||
|
cache_max_size: Maximum cache entries
|
||||||
|
"""
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
self._project_id = project_id
|
||||||
|
self._agent_id = agent_id
|
||||||
|
self._cache_enabled = cache_enabled
|
||||||
|
self._cache_max_size = cache_max_size
|
||||||
|
|
||||||
|
# In-memory cache: hash(model:text) -> token_count
|
||||||
|
self._cache: dict[str, int] = {}
|
||||||
|
self._cache_hits = 0
|
||||||
|
self._cache_misses = 0
|
||||||
|
|
||||||
|
def _get_cache_key(self, text: str, model: str | None) -> str:
|
||||||
|
"""Generate cache key from text and model."""
|
||||||
|
# Use hash for efficient storage
|
||||||
|
content = f"{model or 'default'}:{text}"
|
||||||
|
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||||
|
|
||||||
|
def _check_cache(self, cache_key: str) -> int | None:
|
||||||
|
"""Check cache for existing count."""
|
||||||
|
if not self._cache_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if cache_key in self._cache:
|
||||||
|
self._cache_hits += 1
|
||||||
|
return self._cache[cache_key]
|
||||||
|
|
||||||
|
self._cache_misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _store_cache(self, cache_key: str, count: int) -> None:
|
||||||
|
"""Store count in cache."""
|
||||||
|
if not self._cache_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Simple LRU-like eviction: remove oldest entries when full
|
||||||
|
if len(self._cache) >= self._cache_max_size:
|
||||||
|
# Remove first 10% of entries
|
||||||
|
entries_to_remove = self._cache_max_size // 10
|
||||||
|
keys_to_remove = list(self._cache.keys())[:entries_to_remove]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
del self._cache[key]
|
||||||
|
|
||||||
|
self._cache[cache_key] = count
|
||||||
|
|
||||||
|
def estimate_tokens(self, text: str, model: str | None = None) -> int:
|
||||||
|
"""
|
||||||
|
Estimate token count based on character count.
|
||||||
|
|
||||||
|
This is a fast fallback when LLM Gateway is unavailable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to count
|
||||||
|
model: Optional model for more accurate ratio
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Get model-specific ratio
|
||||||
|
ratio = self.DEFAULT_CHARS_PER_TOKEN
|
||||||
|
if model:
|
||||||
|
model_lower = model.lower()
|
||||||
|
for model_prefix, model_ratio in self.MODEL_CHAR_RATIOS.items():
|
||||||
|
if model_prefix in model_lower:
|
||||||
|
ratio = model_ratio
|
||||||
|
break
|
||||||
|
|
||||||
|
return max(1, int(len(text) / ratio))
|
||||||
|
|
||||||
|
async def count_tokens(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count tokens in text.
|
||||||
|
|
||||||
|
Uses LLM Gateway for accurate counts with fallback to estimation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to count
|
||||||
|
model: Optional model for accurate counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token count
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Check cache first
|
||||||
|
cache_key = self._get_cache_key(text, model)
|
||||||
|
cached = self._check_cache(cache_key)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# Try LLM Gateway
|
||||||
|
if self._mcp is not None:
|
||||||
|
try:
|
||||||
|
result = await self._mcp.call_tool(
|
||||||
|
server="llm-gateway",
|
||||||
|
tool="count_tokens",
|
||||||
|
args={
|
||||||
|
"project_id": self._project_id,
|
||||||
|
"agent_id": self._agent_id,
|
||||||
|
"text": text,
|
||||||
|
"model": model,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse result
|
||||||
|
if result.success and result.data:
|
||||||
|
count = self._parse_token_count(result.data)
|
||||||
|
if count is not None:
|
||||||
|
self._store_cache(cache_key, count)
|
||||||
|
return count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"LLM Gateway token count failed, using estimation: {e}")
|
||||||
|
|
||||||
|
# Fallback to estimation
|
||||||
|
count = self.estimate_tokens(text, model)
|
||||||
|
self._store_cache(cache_key, count)
|
||||||
|
return count
|
||||||
|
|
||||||
|
def _parse_token_count(self, data: Any) -> int | None:
|
||||||
|
"""Parse token count from LLM Gateway response."""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
if "token_count" in data:
|
||||||
|
return int(data["token_count"])
|
||||||
|
if "tokens" in data:
|
||||||
|
return int(data["tokens"])
|
||||||
|
if "count" in data:
|
||||||
|
return int(data["count"])
|
||||||
|
|
||||||
|
if isinstance(data, int):
|
||||||
|
return data
|
||||||
|
|
||||||
|
if isinstance(data, str):
|
||||||
|
# Try to parse from text content
|
||||||
|
try:
|
||||||
|
# Handle {"token_count": 123} or just "123"
|
||||||
|
import json
|
||||||
|
|
||||||
|
parsed = json.loads(data)
|
||||||
|
if isinstance(parsed, dict) and "token_count" in parsed:
|
||||||
|
return int(parsed["token_count"])
|
||||||
|
if isinstance(parsed, int):
|
||||||
|
return parsed
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
# Try direct int conversion
|
||||||
|
try:
|
||||||
|
return int(data)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def count_tokens_batch(
|
||||||
|
self,
|
||||||
|
texts: list[str],
|
||||||
|
model: str | None = None,
|
||||||
|
) -> list[int]:
|
||||||
|
"""
|
||||||
|
Count tokens for multiple texts.
|
||||||
|
|
||||||
|
Efficient batch counting with caching and parallel execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to count
|
||||||
|
model: Optional model for accurate counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of token counts (same order as input)
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Execute all token counts in parallel for better performance
|
||||||
|
tasks = [self.count_tokens(text, model) for text in texts]
|
||||||
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear the token count cache."""
|
||||||
|
self._cache.clear()
|
||||||
|
self._cache_hits = 0
|
||||||
|
self._cache_misses = 0
|
||||||
|
|
||||||
|
def get_cache_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get cache statistics."""
|
||||||
|
total = self._cache_hits + self._cache_misses
|
||||||
|
hit_rate = self._cache_hits / total if total > 0 else 0.0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"enabled": self._cache_enabled,
|
||||||
|
"size": len(self._cache),
|
||||||
|
"max_size": self._cache_max_size,
|
||||||
|
"hits": self._cache_hits,
|
||||||
|
"misses": self._cache_misses,
|
||||||
|
"hit_rate": round(hit_rate, 3),
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
|
"""
|
||||||
|
Set the MCP manager (for lazy initialization).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager instance
|
||||||
|
"""
|
||||||
|
self._mcp = mcp_manager
|
||||||
11
backend/app/services/context/cache/__init__.py
vendored
Normal file
11
backend/app/services/context/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""
|
||||||
|
Context Cache Module.
|
||||||
|
|
||||||
|
Provides Redis-based caching for assembled contexts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .context_cache import ContextCache
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContextCache",
|
||||||
|
]
|
||||||
434
backend/app/services/context/cache/context_cache.py
vendored
Normal file
434
backend/app/services/context/cache/context_cache.py
vendored
Normal file
@@ -0,0 +1,434 @@
|
|||||||
|
"""
|
||||||
|
Context Cache Implementation.
|
||||||
|
|
||||||
|
Provides Redis-based caching for context operations including
|
||||||
|
assembled contexts, token counts, and scoring results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..exceptions import CacheError
|
||||||
|
from ..types import AssembledContext, BaseContext
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextCache:
|
||||||
|
"""
|
||||||
|
Redis-based caching for context operations.
|
||||||
|
|
||||||
|
Provides caching for:
|
||||||
|
- Assembled contexts (fingerprint-based)
|
||||||
|
- Token counts (content hash-based)
|
||||||
|
- Scoring results (context + query hash-based)
|
||||||
|
|
||||||
|
Cache keys use a hierarchical structure:
|
||||||
|
- ctx:assembled:{fingerprint}
|
||||||
|
- ctx:tokens:{model}:{content_hash}
|
||||||
|
- ctx:score:{scorer}:{context_hash}:{query_hash}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
redis: "Redis | None" = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the context cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis: Redis connection (optional for testing)
|
||||||
|
settings: Cache settings
|
||||||
|
"""
|
||||||
|
self._redis = redis
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
self._prefix = self._settings.cache_prefix
|
||||||
|
self._ttl = self._settings.cache_ttl_seconds
|
||||||
|
|
||||||
|
# In-memory fallback cache when Redis unavailable
|
||||||
|
self._memory_cache: dict[str, tuple[str, float]] = {}
|
||||||
|
self._max_memory_items = self._settings.cache_memory_max_items
|
||||||
|
|
||||||
|
def set_redis(self, redis: "Redis") -> None:
|
||||||
|
"""Set Redis connection."""
|
||||||
|
self._redis = redis
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_enabled(self) -> bool:
|
||||||
|
"""Check if caching is enabled and available."""
|
||||||
|
return self._settings.cache_enabled and self._redis is not None
|
||||||
|
|
||||||
|
def _cache_key(self, *parts: str) -> str:
|
||||||
|
"""
|
||||||
|
Build a cache key from parts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*parts: Key components
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Colon-separated cache key
|
||||||
|
"""
|
||||||
|
return f"{self._prefix}:{':'.join(parts)}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _hash_content(content: str) -> str:
|
||||||
|
"""
|
||||||
|
Compute hash of content for cache key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content to hash
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
32-character hex hash
|
||||||
|
"""
|
||||||
|
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||||
|
|
||||||
|
def compute_fingerprint(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
model: str,
|
||||||
|
project_id: str | None = None,
|
||||||
|
agent_id: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Compute a fingerprint for a context assembly request.
|
||||||
|
|
||||||
|
The fingerprint is based on:
|
||||||
|
- Project and agent IDs (for tenant isolation)
|
||||||
|
- Context content hash and metadata (not full content for performance)
|
||||||
|
- Query string
|
||||||
|
- Target model
|
||||||
|
|
||||||
|
SECURITY: project_id and agent_id MUST be included to prevent
|
||||||
|
cross-tenant cache pollution. Without these, one tenant could
|
||||||
|
receive cached contexts from another tenant with the same query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts
|
||||||
|
query: Query string
|
||||||
|
model: Model name
|
||||||
|
project_id: Project ID for tenant isolation
|
||||||
|
agent_id: Agent ID for tenant isolation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
32-character hex fingerprint
|
||||||
|
"""
|
||||||
|
# Build a deterministic representation using content hashes for performance
|
||||||
|
# This avoids JSON serializing potentially large content strings
|
||||||
|
context_data = []
|
||||||
|
for ctx in contexts:
|
||||||
|
context_data.append(
|
||||||
|
{
|
||||||
|
"type": ctx.get_type().value,
|
||||||
|
"content_hash": self._hash_content(
|
||||||
|
ctx.content
|
||||||
|
), # Hash instead of full content
|
||||||
|
"source": ctx.source,
|
||||||
|
"priority": ctx.priority, # Already an int
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
# CRITICAL: Include tenant identifiers for cache isolation
|
||||||
|
"project_id": project_id or "",
|
||||||
|
"agent_id": agent_id or "",
|
||||||
|
"contexts": context_data,
|
||||||
|
"query": query,
|
||||||
|
"model": model,
|
||||||
|
}
|
||||||
|
|
||||||
|
content = json.dumps(data, sort_keys=True)
|
||||||
|
return self._hash_content(content)
|
||||||
|
|
||||||
|
async def get_assembled(
|
||||||
|
self,
|
||||||
|
fingerprint: str,
|
||||||
|
) -> AssembledContext | None:
|
||||||
|
"""
|
||||||
|
Get cached assembled context by fingerprint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fingerprint: Assembly fingerprint
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached AssembledContext or None if not found
|
||||||
|
"""
|
||||||
|
if not self.is_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
key = self._cache_key("assembled", fingerprint)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await self._redis.get(key) # type: ignore
|
||||||
|
if data:
|
||||||
|
logger.debug(f"Cache hit for assembled context: {fingerprint}")
|
||||||
|
result = AssembledContext.from_json(data)
|
||||||
|
result.cache_hit = True
|
||||||
|
result.cache_key = fingerprint
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache get error: {e}")
|
||||||
|
raise CacheError(f"Failed to get assembled context: {e}") from e
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def set_assembled(
|
||||||
|
self,
|
||||||
|
fingerprint: str,
|
||||||
|
context: AssembledContext,
|
||||||
|
ttl: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Cache an assembled context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fingerprint: Assembly fingerprint
|
||||||
|
context: Assembled context to cache
|
||||||
|
ttl: Optional TTL override in seconds
|
||||||
|
"""
|
||||||
|
if not self.is_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
key = self._cache_key("assembled", fingerprint)
|
||||||
|
expire = ttl or self._ttl
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._redis.setex(key, expire, context.to_json()) # type: ignore
|
||||||
|
logger.debug(f"Cached assembled context: {fingerprint}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache set error: {e}")
|
||||||
|
raise CacheError(f"Failed to cache assembled context: {e}") from e
|
||||||
|
|
||||||
|
async def get_token_count(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> int | None:
|
||||||
|
"""
|
||||||
|
Get cached token count.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content to look up
|
||||||
|
model: Model name for model-specific tokenization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached token count or None if not found
|
||||||
|
"""
|
||||||
|
model_key = model or "default"
|
||||||
|
content_hash = self._hash_content(content)
|
||||||
|
key = self._cache_key("tokens", model_key, content_hash)
|
||||||
|
|
||||||
|
# Try in-memory first
|
||||||
|
if key in self._memory_cache:
|
||||||
|
return int(self._memory_cache[key][0])
|
||||||
|
|
||||||
|
if not self.is_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await self._redis.get(key) # type: ignore
|
||||||
|
if data:
|
||||||
|
count = int(data)
|
||||||
|
# Store in memory for faster subsequent access
|
||||||
|
self._set_memory(key, str(count))
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache get error for tokens: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def set_token_count(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
count: int,
|
||||||
|
model: str | None = None,
|
||||||
|
ttl: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Cache a token count.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content that was counted
|
||||||
|
count: Token count
|
||||||
|
model: Model name
|
||||||
|
ttl: Optional TTL override in seconds
|
||||||
|
"""
|
||||||
|
model_key = model or "default"
|
||||||
|
content_hash = self._hash_content(content)
|
||||||
|
key = self._cache_key("tokens", model_key, content_hash)
|
||||||
|
expire = ttl or self._ttl
|
||||||
|
|
||||||
|
# Always store in memory
|
||||||
|
self._set_memory(key, str(count))
|
||||||
|
|
||||||
|
if not self.is_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._redis.setex(key, expire, str(count)) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache set error for tokens: {e}")
|
||||||
|
|
||||||
|
async def get_score(
|
||||||
|
self,
|
||||||
|
scorer_name: str,
|
||||||
|
context_id: str,
|
||||||
|
query: str,
|
||||||
|
) -> float | None:
|
||||||
|
"""
|
||||||
|
Get cached score.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scorer_name: Name of the scorer
|
||||||
|
context_id: Context identifier
|
||||||
|
query: Query string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached score or None if not found
|
||||||
|
"""
|
||||||
|
query_hash = self._hash_content(query)[:16]
|
||||||
|
key = self._cache_key("score", scorer_name, context_id, query_hash)
|
||||||
|
|
||||||
|
# Try in-memory first
|
||||||
|
if key in self._memory_cache:
|
||||||
|
return float(self._memory_cache[key][0])
|
||||||
|
|
||||||
|
if not self.is_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await self._redis.get(key) # type: ignore
|
||||||
|
if data:
|
||||||
|
score = float(data)
|
||||||
|
self._set_memory(key, str(score))
|
||||||
|
return score
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache get error for score: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def set_score(
|
||||||
|
self,
|
||||||
|
scorer_name: str,
|
||||||
|
context_id: str,
|
||||||
|
query: str,
|
||||||
|
score: float,
|
||||||
|
ttl: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Cache a score.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scorer_name: Name of the scorer
|
||||||
|
context_id: Context identifier
|
||||||
|
query: Query string
|
||||||
|
score: Score value
|
||||||
|
ttl: Optional TTL override in seconds
|
||||||
|
"""
|
||||||
|
query_hash = self._hash_content(query)[:16]
|
||||||
|
key = self._cache_key("score", scorer_name, context_id, query_hash)
|
||||||
|
expire = ttl or self._ttl
|
||||||
|
|
||||||
|
# Always store in memory
|
||||||
|
self._set_memory(key, str(score))
|
||||||
|
|
||||||
|
if not self.is_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._redis.setex(key, expire, str(score)) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache set error for score: {e}")
|
||||||
|
|
||||||
|
async def invalidate(self, pattern: str) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate cache entries matching a pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: Key pattern (supports * wildcard)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of keys deleted
|
||||||
|
"""
|
||||||
|
if not self.is_enabled:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
full_pattern = self._cache_key(pattern)
|
||||||
|
deleted = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for key in self._redis.scan_iter(match=full_pattern): # type: ignore
|
||||||
|
await self._redis.delete(key) # type: ignore
|
||||||
|
deleted += 1
|
||||||
|
|
||||||
|
logger.info(f"Invalidated {deleted} cache entries matching {pattern}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache invalidation error: {e}")
|
||||||
|
raise CacheError(f"Failed to invalidate cache: {e}") from e
|
||||||
|
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
async def clear_all(self) -> int:
|
||||||
|
"""
|
||||||
|
Clear all context cache entries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of keys deleted
|
||||||
|
"""
|
||||||
|
self._memory_cache.clear()
|
||||||
|
return await self.invalidate("*")
|
||||||
|
|
||||||
|
def _set_memory(self, key: str, value: str) -> None:
|
||||||
|
"""
|
||||||
|
Set a value in the memory cache.
|
||||||
|
|
||||||
|
Uses LRU-style eviction when max items reached.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
value: Value to store
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
if len(self._memory_cache) >= self._max_memory_items:
|
||||||
|
# Evict oldest entries
|
||||||
|
sorted_keys = sorted(
|
||||||
|
self._memory_cache.keys(),
|
||||||
|
key=lambda k: self._memory_cache[k][1],
|
||||||
|
)
|
||||||
|
for k in sorted_keys[: len(sorted_keys) // 2]:
|
||||||
|
del self._memory_cache[k]
|
||||||
|
|
||||||
|
self._memory_cache[key] = (value, time.time())
|
||||||
|
|
||||||
|
async def get_stats(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get cache statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with cache stats
|
||||||
|
"""
|
||||||
|
stats = {
|
||||||
|
"enabled": self._settings.cache_enabled,
|
||||||
|
"redis_available": self._redis is not None,
|
||||||
|
"memory_items": len(self._memory_cache),
|
||||||
|
"ttl_seconds": self._ttl,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.is_enabled:
|
||||||
|
try:
|
||||||
|
# Get Redis info
|
||||||
|
info = await self._redis.info("memory") # type: ignore
|
||||||
|
stats["redis_memory_used"] = info.get("used_memory_human", "unknown")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to get Redis stats: {e}")
|
||||||
|
|
||||||
|
return stats
|
||||||
13
backend/app/services/context/compression/__init__.py
Normal file
13
backend/app/services/context/compression/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""
|
||||||
|
Context Compression Module.
|
||||||
|
|
||||||
|
Provides truncation and compression strategies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .truncation import ContextCompressor, TruncationResult, TruncationStrategy
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContextCompressor",
|
||||||
|
"TruncationResult",
|
||||||
|
"TruncationStrategy",
|
||||||
|
]
|
||||||
453
backend/app/services/context/compression/truncation.py
Normal file
453
backend/app/services/context/compression/truncation.py
Normal file
@@ -0,0 +1,453 @@
|
|||||||
|
"""
|
||||||
|
Smart Truncation for Context Compression.
|
||||||
|
|
||||||
|
Provides intelligent truncation strategies to reduce context size
|
||||||
|
while preserving the most important information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..budget import TokenBudget, TokenCalculator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_tokens(text: str, model: str | None = None) -> int:
|
||||||
|
"""
|
||||||
|
Estimate token count using model-specific character ratios.
|
||||||
|
|
||||||
|
Module-level function for reuse across classes. Uses the same ratios
|
||||||
|
as TokenCalculator for consistency.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to estimate tokens for
|
||||||
|
model: Optional model name for model-specific ratios
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count (minimum 1)
|
||||||
|
"""
|
||||||
|
# Model-specific character ratios (chars per token)
|
||||||
|
model_ratios = {
|
||||||
|
"claude": 3.5,
|
||||||
|
"gpt-4": 4.0,
|
||||||
|
"gpt-3.5": 4.0,
|
||||||
|
"gemini": 4.0,
|
||||||
|
}
|
||||||
|
default_ratio = 4.0
|
||||||
|
|
||||||
|
ratio = default_ratio
|
||||||
|
if model:
|
||||||
|
model_lower = model.lower()
|
||||||
|
for model_prefix, model_ratio in model_ratios.items():
|
||||||
|
if model_prefix in model_lower:
|
||||||
|
ratio = model_ratio
|
||||||
|
break
|
||||||
|
|
||||||
|
return max(1, int(len(text) / ratio))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TruncationResult:
|
||||||
|
"""Result of truncation operation."""
|
||||||
|
|
||||||
|
original_tokens: int
|
||||||
|
truncated_tokens: int
|
||||||
|
content: str
|
||||||
|
truncated: bool
|
||||||
|
truncation_ratio: float # 0.0 = no truncation, 1.0 = completely removed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokens_saved(self) -> int:
|
||||||
|
"""Calculate tokens saved by truncation."""
|
||||||
|
return self.original_tokens - self.truncated_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class TruncationStrategy:
|
||||||
|
"""
|
||||||
|
Smart truncation strategies for context compression.
|
||||||
|
|
||||||
|
Strategies:
|
||||||
|
1. End truncation: Cut from end (for knowledge/docs)
|
||||||
|
2. Middle truncation: Keep start and end (for code)
|
||||||
|
3. Sentence-aware: Truncate at sentence boundaries
|
||||||
|
4. Semantic chunking: Keep most relevant chunks
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
calculator: "TokenCalculator | None" = None,
|
||||||
|
preserve_ratio_start: float | None = None,
|
||||||
|
min_content_length: int | None = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize truncation strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
calculator: Token calculator for accurate counting
|
||||||
|
preserve_ratio_start: Ratio of content to keep from start (overrides settings)
|
||||||
|
min_content_length: Minimum characters to preserve (overrides settings)
|
||||||
|
settings: Context settings (uses global if None)
|
||||||
|
"""
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
self._calculator = calculator
|
||||||
|
|
||||||
|
# Use provided values or fall back to settings
|
||||||
|
self._preserve_ratio_start = (
|
||||||
|
preserve_ratio_start
|
||||||
|
if preserve_ratio_start is not None
|
||||||
|
else self._settings.truncation_preserve_ratio
|
||||||
|
)
|
||||||
|
self._min_content_length = (
|
||||||
|
min_content_length
|
||||||
|
if min_content_length is not None
|
||||||
|
else self._settings.truncation_min_content_length
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def truncation_marker(self) -> str:
|
||||||
|
"""Get truncation marker from settings."""
|
||||||
|
return self._settings.truncation_marker
|
||||||
|
|
||||||
|
def set_calculator(self, calculator: "TokenCalculator") -> None:
|
||||||
|
"""Set token calculator."""
|
||||||
|
self._calculator = calculator
|
||||||
|
|
||||||
|
async def truncate_to_tokens(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
max_tokens: int,
|
||||||
|
strategy: str = "end",
|
||||||
|
model: str | None = None,
|
||||||
|
) -> TruncationResult:
|
||||||
|
"""
|
||||||
|
Truncate content to fit within token limit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content to truncate
|
||||||
|
max_tokens: Maximum tokens allowed
|
||||||
|
strategy: Truncation strategy ('end', 'middle', 'sentence')
|
||||||
|
model: Model for token counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TruncationResult with truncated content
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return TruncationResult(
|
||||||
|
original_tokens=0,
|
||||||
|
truncated_tokens=0,
|
||||||
|
content="",
|
||||||
|
truncated=False,
|
||||||
|
truncation_ratio=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get original token count
|
||||||
|
original_tokens = await self._count_tokens(content, model)
|
||||||
|
|
||||||
|
if original_tokens <= max_tokens:
|
||||||
|
return TruncationResult(
|
||||||
|
original_tokens=original_tokens,
|
||||||
|
truncated_tokens=original_tokens,
|
||||||
|
content=content,
|
||||||
|
truncated=False,
|
||||||
|
truncation_ratio=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply truncation strategy
|
||||||
|
if strategy == "middle":
|
||||||
|
truncated = await self._truncate_middle(content, max_tokens, model)
|
||||||
|
elif strategy == "sentence":
|
||||||
|
truncated = await self._truncate_sentence(content, max_tokens, model)
|
||||||
|
else: # "end"
|
||||||
|
truncated = await self._truncate_end(content, max_tokens, model)
|
||||||
|
|
||||||
|
truncated_tokens = await self._count_tokens(truncated, model)
|
||||||
|
|
||||||
|
return TruncationResult(
|
||||||
|
original_tokens=original_tokens,
|
||||||
|
truncated_tokens=truncated_tokens,
|
||||||
|
content=truncated,
|
||||||
|
truncated=True,
|
||||||
|
truncation_ratio=0.0
|
||||||
|
if original_tokens == 0
|
||||||
|
else 1 - (truncated_tokens / original_tokens),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _truncate_end(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Truncate from end of content.
|
||||||
|
|
||||||
|
Simple but effective for most content types.
|
||||||
|
"""
|
||||||
|
# Binary search for optimal truncation point
|
||||||
|
marker_tokens = await self._count_tokens(self.truncation_marker, model)
|
||||||
|
available_tokens = max(0, max_tokens - marker_tokens)
|
||||||
|
|
||||||
|
# Edge case: if no tokens available for content, return just the marker
|
||||||
|
if available_tokens <= 0:
|
||||||
|
return self.truncation_marker
|
||||||
|
|
||||||
|
# Estimate characters per token (guard against division by zero)
|
||||||
|
content_tokens = await self._count_tokens(content, model)
|
||||||
|
if content_tokens == 0:
|
||||||
|
return content + self.truncation_marker
|
||||||
|
chars_per_token = len(content) / content_tokens
|
||||||
|
|
||||||
|
# Start with estimated position
|
||||||
|
estimated_chars = int(available_tokens * chars_per_token)
|
||||||
|
truncated = content[:estimated_chars]
|
||||||
|
|
||||||
|
# Refine with binary search
|
||||||
|
low, high = len(truncated) // 2, len(truncated)
|
||||||
|
best = truncated
|
||||||
|
|
||||||
|
for _ in range(5): # Max 5 iterations
|
||||||
|
mid = (low + high) // 2
|
||||||
|
candidate = content[:mid]
|
||||||
|
tokens = await self._count_tokens(candidate, model)
|
||||||
|
|
||||||
|
if tokens <= available_tokens:
|
||||||
|
best = candidate
|
||||||
|
low = mid + 1
|
||||||
|
else:
|
||||||
|
high = mid - 1
|
||||||
|
|
||||||
|
return best + self.truncation_marker
|
||||||
|
|
||||||
|
async def _truncate_middle(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Truncate from middle, keeping start and end.
|
||||||
|
|
||||||
|
Good for code or content where context at boundaries matters.
|
||||||
|
"""
|
||||||
|
marker_tokens = await self._count_tokens(self.truncation_marker, model)
|
||||||
|
available_tokens = max_tokens - marker_tokens
|
||||||
|
|
||||||
|
# Split between start and end
|
||||||
|
start_tokens = int(available_tokens * self._preserve_ratio_start)
|
||||||
|
end_tokens = available_tokens - start_tokens
|
||||||
|
|
||||||
|
# Get start portion
|
||||||
|
start_content = await self._get_content_for_tokens(
|
||||||
|
content, start_tokens, from_start=True, model=model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get end portion
|
||||||
|
end_content = await self._get_content_for_tokens(
|
||||||
|
content, end_tokens, from_start=False, model=model
|
||||||
|
)
|
||||||
|
|
||||||
|
return start_content + self.truncation_marker + end_content
|
||||||
|
|
||||||
|
async def _truncate_sentence(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Truncate at sentence boundaries.
|
||||||
|
|
||||||
|
Produces cleaner output by not cutting mid-sentence.
|
||||||
|
"""
|
||||||
|
# Split into sentences
|
||||||
|
sentences = re.split(r"(?<=[.!?])\s+", content)
|
||||||
|
|
||||||
|
result: list[str] = []
|
||||||
|
total_tokens = 0
|
||||||
|
marker_tokens = await self._count_tokens(self.truncation_marker, model)
|
||||||
|
available = max_tokens - marker_tokens
|
||||||
|
|
||||||
|
for sentence in sentences:
|
||||||
|
sentence_tokens = await self._count_tokens(sentence, model)
|
||||||
|
if total_tokens + sentence_tokens <= available:
|
||||||
|
result.append(sentence)
|
||||||
|
total_tokens += sentence_tokens
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if len(result) < len(sentences):
|
||||||
|
return " ".join(result) + self.truncation_marker
|
||||||
|
return " ".join(result)
|
||||||
|
|
||||||
|
async def _get_content_for_tokens(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
target_tokens: int,
|
||||||
|
from_start: bool = True,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Get portion of content fitting within token limit."""
|
||||||
|
if target_tokens <= 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
current_tokens = await self._count_tokens(content, model)
|
||||||
|
if current_tokens <= target_tokens:
|
||||||
|
return content
|
||||||
|
|
||||||
|
# Estimate characters (guard against division by zero)
|
||||||
|
if current_tokens == 0:
|
||||||
|
return content
|
||||||
|
chars_per_token = len(content) / current_tokens
|
||||||
|
estimated_chars = int(target_tokens * chars_per_token)
|
||||||
|
|
||||||
|
if from_start:
|
||||||
|
return content[:estimated_chars]
|
||||||
|
else:
|
||||||
|
return content[-estimated_chars:]
|
||||||
|
|
||||||
|
async def _count_tokens(self, text: str, model: str | None = None) -> int:
|
||||||
|
"""Count tokens using calculator or estimation."""
|
||||||
|
if self._calculator is not None:
|
||||||
|
return await self._calculator.count_tokens(text, model)
|
||||||
|
|
||||||
|
# Fallback estimation with model-specific ratios
|
||||||
|
return _estimate_tokens(text, model)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextCompressor:
|
||||||
|
"""
|
||||||
|
Compresses contexts to fit within budget constraints.
|
||||||
|
|
||||||
|
Uses truncation strategies to reduce context size while
|
||||||
|
preserving the most important information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
truncation: TruncationStrategy | None = None,
|
||||||
|
calculator: "TokenCalculator | None" = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize context compressor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
truncation: Truncation strategy to use
|
||||||
|
calculator: Token calculator for counting
|
||||||
|
"""
|
||||||
|
self._truncation = truncation or TruncationStrategy(calculator)
|
||||||
|
self._calculator = calculator
|
||||||
|
|
||||||
|
if calculator:
|
||||||
|
self._truncation.set_calculator(calculator)
|
||||||
|
|
||||||
|
def set_calculator(self, calculator: "TokenCalculator") -> None:
|
||||||
|
"""Set token calculator."""
|
||||||
|
self._calculator = calculator
|
||||||
|
self._truncation.set_calculator(calculator)
|
||||||
|
|
||||||
|
async def compress_context(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> BaseContext:
|
||||||
|
"""
|
||||||
|
Compress a single context to fit token limit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to compress
|
||||||
|
max_tokens: Maximum tokens allowed
|
||||||
|
model: Model for token counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compressed context (may be same object if no compression needed)
|
||||||
|
"""
|
||||||
|
current_tokens = context.token_count or await self._count_tokens(
|
||||||
|
context.content, model
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_tokens <= max_tokens:
|
||||||
|
return context
|
||||||
|
|
||||||
|
# Choose strategy based on context type
|
||||||
|
strategy = self._get_strategy_for_type(context.get_type())
|
||||||
|
|
||||||
|
result = await self._truncation.truncate_to_tokens(
|
||||||
|
content=context.content,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
strategy=strategy,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update context with truncated content
|
||||||
|
context.content = result.content
|
||||||
|
context.token_count = result.truncated_tokens
|
||||||
|
context.metadata["truncated"] = True
|
||||||
|
context.metadata["original_tokens"] = result.original_tokens
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
|
async def compress_contexts(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
budget: "TokenBudget",
|
||||||
|
model: str | None = None,
|
||||||
|
) -> list[BaseContext]:
|
||||||
|
"""
|
||||||
|
Compress multiple contexts to fit within budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to potentially compress
|
||||||
|
budget: Token budget constraints
|
||||||
|
model: Model for token counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of contexts (compressed as needed)
|
||||||
|
"""
|
||||||
|
result: list[BaseContext] = []
|
||||||
|
|
||||||
|
for context in contexts:
|
||||||
|
context_type = context.get_type()
|
||||||
|
remaining = budget.remaining(context_type)
|
||||||
|
current_tokens = context.token_count or await self._count_tokens(
|
||||||
|
context.content, model
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_tokens > remaining:
|
||||||
|
# Need to compress
|
||||||
|
compressed = await self.compress_context(context, remaining, model)
|
||||||
|
result.append(compressed)
|
||||||
|
logger.debug(
|
||||||
|
f"Compressed {context_type.value} context from "
|
||||||
|
f"{current_tokens} to {compressed.token_count} tokens"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result.append(context)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_strategy_for_type(self, context_type: ContextType) -> str:
|
||||||
|
"""Get optimal truncation strategy for context type."""
|
||||||
|
strategies = {
|
||||||
|
ContextType.SYSTEM: "end", # Keep instructions at start
|
||||||
|
ContextType.TASK: "end", # Keep task description start
|
||||||
|
ContextType.KNOWLEDGE: "sentence", # Clean sentence boundaries
|
||||||
|
ContextType.CONVERSATION: "end", # Keep recent conversation
|
||||||
|
ContextType.TOOL: "middle", # Keep command and result summary
|
||||||
|
}
|
||||||
|
return strategies.get(context_type, "end")
|
||||||
|
|
||||||
|
async def _count_tokens(self, text: str, model: str | None = None) -> int:
|
||||||
|
"""Count tokens using calculator or estimation."""
|
||||||
|
if self._calculator is not None:
|
||||||
|
return await self._calculator.count_tokens(text, model)
|
||||||
|
# Use model-specific estimation for consistency
|
||||||
|
return _estimate_tokens(text, model)
|
||||||
380
backend/app/services/context/config.py
Normal file
380
backend/app/services/context/config.py
Normal file
@@ -0,0 +1,380 @@
|
|||||||
|
"""
|
||||||
|
Context Management Engine Configuration.
|
||||||
|
|
||||||
|
Provides Pydantic settings for context assembly,
|
||||||
|
token budget allocation, and caching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import Field, field_validator, model_validator
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class ContextSettings(BaseSettings):
|
||||||
|
"""
|
||||||
|
Configuration for the Context Management Engine.
|
||||||
|
|
||||||
|
All settings can be overridden via environment variables
|
||||||
|
with the CTX_ prefix.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Budget allocation percentages (must sum to 1.0)
|
||||||
|
budget_system: float = Field(
|
||||||
|
default=0.05,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage of budget for system prompts (5%)",
|
||||||
|
)
|
||||||
|
budget_task: float = Field(
|
||||||
|
default=0.10,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage of budget for task context (10%)",
|
||||||
|
)
|
||||||
|
budget_knowledge: float = Field(
|
||||||
|
default=0.40,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage of budget for RAG/knowledge (40%)",
|
||||||
|
)
|
||||||
|
budget_conversation: float = Field(
|
||||||
|
default=0.20,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage of budget for conversation history (20%)",
|
||||||
|
)
|
||||||
|
budget_tools: float = Field(
|
||||||
|
default=0.05,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage of budget for tool descriptions (5%)",
|
||||||
|
)
|
||||||
|
budget_response: float = Field(
|
||||||
|
default=0.15,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage reserved for response (15%)",
|
||||||
|
)
|
||||||
|
budget_buffer: float = Field(
|
||||||
|
default=0.05,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage buffer for safety margin (5%)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scoring weights
|
||||||
|
scoring_relevance_weight: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Weight for relevance scoring",
|
||||||
|
)
|
||||||
|
scoring_recency_weight: float = Field(
|
||||||
|
default=0.3,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Weight for recency scoring",
|
||||||
|
)
|
||||||
|
scoring_priority_weight: float = Field(
|
||||||
|
default=0.2,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Weight for priority scoring",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recency decay settings
|
||||||
|
recency_decay_hours: float = Field(
|
||||||
|
default=24.0,
|
||||||
|
gt=0.0,
|
||||||
|
description="Hours until recency score decays to 50%",
|
||||||
|
)
|
||||||
|
recency_max_age_hours: float = Field(
|
||||||
|
default=168.0,
|
||||||
|
gt=0.0,
|
||||||
|
description="Hours until context is considered stale (7 days)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compression settings
|
||||||
|
compression_threshold: float = Field(
|
||||||
|
default=0.8,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Compress when budget usage exceeds this percentage",
|
||||||
|
)
|
||||||
|
truncation_marker: str = Field(
|
||||||
|
default="\n\n[...content truncated...]\n\n",
|
||||||
|
description="Marker text to insert where content was truncated",
|
||||||
|
)
|
||||||
|
truncation_preserve_ratio: float = Field(
|
||||||
|
default=0.7,
|
||||||
|
ge=0.1,
|
||||||
|
le=0.9,
|
||||||
|
description="Ratio of content to preserve from start in middle truncation (0.7 = 70% start, 30% end)",
|
||||||
|
)
|
||||||
|
truncation_min_content_length: int = Field(
|
||||||
|
default=100,
|
||||||
|
ge=10,
|
||||||
|
le=1000,
|
||||||
|
description="Minimum content length in characters before truncation applies",
|
||||||
|
)
|
||||||
|
summary_model_group: str = Field(
|
||||||
|
default="fast",
|
||||||
|
description="Model group to use for summarization",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Caching settings
|
||||||
|
cache_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable Redis caching for assembled contexts",
|
||||||
|
)
|
||||||
|
cache_ttl_seconds: int = Field(
|
||||||
|
default=3600,
|
||||||
|
ge=60,
|
||||||
|
le=86400,
|
||||||
|
description="Cache TTL in seconds (1 hour default, max 24 hours)",
|
||||||
|
)
|
||||||
|
cache_prefix: str = Field(
|
||||||
|
default="ctx",
|
||||||
|
description="Redis key prefix for context cache",
|
||||||
|
)
|
||||||
|
cache_memory_max_items: int = Field(
|
||||||
|
default=1000,
|
||||||
|
ge=100,
|
||||||
|
le=100000,
|
||||||
|
description="Maximum items in memory fallback cache when Redis unavailable",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance settings
|
||||||
|
max_assembly_time_ms: int = Field(
|
||||||
|
default=2000,
|
||||||
|
ge=10,
|
||||||
|
le=30000,
|
||||||
|
description="Maximum time for context assembly in milliseconds. "
|
||||||
|
"Should be high enough to accommodate MCP calls for knowledge retrieval.",
|
||||||
|
)
|
||||||
|
parallel_scoring: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Score contexts in parallel for better performance",
|
||||||
|
)
|
||||||
|
max_parallel_scores: int = Field(
|
||||||
|
default=10,
|
||||||
|
ge=1,
|
||||||
|
le=50,
|
||||||
|
description="Maximum number of contexts to score in parallel",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Knowledge retrieval settings
|
||||||
|
knowledge_search_type: str = Field(
|
||||||
|
default="hybrid",
|
||||||
|
description="Default search type for knowledge retrieval",
|
||||||
|
)
|
||||||
|
knowledge_max_results: int = Field(
|
||||||
|
default=10,
|
||||||
|
ge=1,
|
||||||
|
le=50,
|
||||||
|
description="Maximum knowledge chunks to retrieve",
|
||||||
|
)
|
||||||
|
knowledge_min_score: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Minimum relevance score for knowledge",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relevance scoring settings
|
||||||
|
relevance_keyword_fallback_weight: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Maximum score for keyword-based fallback scoring (when semantic unavailable)",
|
||||||
|
)
|
||||||
|
relevance_semantic_max_chars: int = Field(
|
||||||
|
default=2000,
|
||||||
|
ge=100,
|
||||||
|
le=10000,
|
||||||
|
description="Maximum content length in chars for semantic similarity computation",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Diversity/ranking settings
|
||||||
|
diversity_max_per_source: int = Field(
|
||||||
|
default=3,
|
||||||
|
ge=1,
|
||||||
|
le=20,
|
||||||
|
description="Maximum contexts from the same source in diversity reranking",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Conversation history settings
|
||||||
|
conversation_max_turns: int = Field(
|
||||||
|
default=20,
|
||||||
|
ge=1,
|
||||||
|
le=100,
|
||||||
|
description="Maximum conversation turns to include",
|
||||||
|
)
|
||||||
|
conversation_recent_priority: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Prioritize recent conversation turns",
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("knowledge_search_type")
|
||||||
|
@classmethod
|
||||||
|
def validate_search_type(cls, v: str) -> str:
|
||||||
|
"""Validate search type is valid."""
|
||||||
|
valid_types = {"semantic", "keyword", "hybrid"}
|
||||||
|
if v not in valid_types:
|
||||||
|
raise ValueError(f"search_type must be one of: {valid_types}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_budget_allocation(self) -> "ContextSettings":
|
||||||
|
"""Validate that budget percentages sum to 1.0."""
|
||||||
|
total = (
|
||||||
|
self.budget_system
|
||||||
|
+ self.budget_task
|
||||||
|
+ self.budget_knowledge
|
||||||
|
+ self.budget_conversation
|
||||||
|
+ self.budget_tools
|
||||||
|
+ self.budget_response
|
||||||
|
+ self.budget_buffer
|
||||||
|
)
|
||||||
|
# Allow small floating point error
|
||||||
|
if abs(total - 1.0) > 0.001:
|
||||||
|
raise ValueError(
|
||||||
|
f"Budget percentages must sum to 1.0, got {total:.3f}. "
|
||||||
|
f"Current allocation: system={self.budget_system}, task={self.budget_task}, "
|
||||||
|
f"knowledge={self.budget_knowledge}, conversation={self.budget_conversation}, "
|
||||||
|
f"tools={self.budget_tools}, response={self.budget_response}, buffer={self.budget_buffer}"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_scoring_weights(self) -> "ContextSettings":
|
||||||
|
"""Validate that scoring weights sum to 1.0."""
|
||||||
|
total = (
|
||||||
|
self.scoring_relevance_weight
|
||||||
|
+ self.scoring_recency_weight
|
||||||
|
+ self.scoring_priority_weight
|
||||||
|
)
|
||||||
|
# Allow small floating point error
|
||||||
|
if abs(total - 1.0) > 0.001:
|
||||||
|
raise ValueError(
|
||||||
|
f"Scoring weights must sum to 1.0, got {total:.3f}. "
|
||||||
|
f"Current weights: relevance={self.scoring_relevance_weight}, "
|
||||||
|
f"recency={self.scoring_recency_weight}, priority={self.scoring_priority_weight}"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_budget_allocation(self) -> dict[str, float]:
|
||||||
|
"""Get budget allocation as a dictionary."""
|
||||||
|
return {
|
||||||
|
"system": self.budget_system,
|
||||||
|
"task": self.budget_task,
|
||||||
|
"knowledge": self.budget_knowledge,
|
||||||
|
"conversation": self.budget_conversation,
|
||||||
|
"tools": self.budget_tools,
|
||||||
|
"response": self.budget_response,
|
||||||
|
"buffer": self.budget_buffer,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_scoring_weights(self) -> dict[str, float]:
|
||||||
|
"""Get scoring weights as a dictionary."""
|
||||||
|
return {
|
||||||
|
"relevance": self.scoring_relevance_weight,
|
||||||
|
"recency": self.scoring_recency_weight,
|
||||||
|
"priority": self.scoring_priority_weight,
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert settings to dictionary for logging/debugging."""
|
||||||
|
return {
|
||||||
|
"budget": self.get_budget_allocation(),
|
||||||
|
"scoring": self.get_scoring_weights(),
|
||||||
|
"compression": {
|
||||||
|
"threshold": self.compression_threshold,
|
||||||
|
"summary_model_group": self.summary_model_group,
|
||||||
|
"truncation_marker": self.truncation_marker,
|
||||||
|
"truncation_preserve_ratio": self.truncation_preserve_ratio,
|
||||||
|
"truncation_min_content_length": self.truncation_min_content_length,
|
||||||
|
},
|
||||||
|
"cache": {
|
||||||
|
"enabled": self.cache_enabled,
|
||||||
|
"ttl_seconds": self.cache_ttl_seconds,
|
||||||
|
"prefix": self.cache_prefix,
|
||||||
|
"memory_max_items": self.cache_memory_max_items,
|
||||||
|
},
|
||||||
|
"performance": {
|
||||||
|
"max_assembly_time_ms": self.max_assembly_time_ms,
|
||||||
|
"parallel_scoring": self.parallel_scoring,
|
||||||
|
"max_parallel_scores": self.max_parallel_scores,
|
||||||
|
},
|
||||||
|
"knowledge": {
|
||||||
|
"search_type": self.knowledge_search_type,
|
||||||
|
"max_results": self.knowledge_max_results,
|
||||||
|
"min_score": self.knowledge_min_score,
|
||||||
|
},
|
||||||
|
"relevance": {
|
||||||
|
"keyword_fallback_weight": self.relevance_keyword_fallback_weight,
|
||||||
|
"semantic_max_chars": self.relevance_semantic_max_chars,
|
||||||
|
},
|
||||||
|
"diversity": {
|
||||||
|
"max_per_source": self.diversity_max_per_source,
|
||||||
|
},
|
||||||
|
"conversation": {
|
||||||
|
"max_turns": self.conversation_max_turns,
|
||||||
|
"recent_priority": self.conversation_recent_priority,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"env_prefix": "CTX_",
|
||||||
|
"env_file": "../.env",
|
||||||
|
"env_file_encoding": "utf-8",
|
||||||
|
"case_sensitive": False,
|
||||||
|
"extra": "ignore",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Thread-safe singleton pattern
|
||||||
|
_settings: ContextSettings | None = None
|
||||||
|
_settings_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_context_settings() -> ContextSettings:
|
||||||
|
"""
|
||||||
|
Get the global ContextSettings instance.
|
||||||
|
|
||||||
|
Thread-safe with double-checked locking pattern.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ContextSettings instance
|
||||||
|
"""
|
||||||
|
global _settings
|
||||||
|
if _settings is None:
|
||||||
|
with _settings_lock:
|
||||||
|
if _settings is None:
|
||||||
|
_settings = ContextSettings()
|
||||||
|
return _settings
|
||||||
|
|
||||||
|
|
||||||
|
def reset_context_settings() -> None:
|
||||||
|
"""
|
||||||
|
Reset the global settings instance.
|
||||||
|
|
||||||
|
Primarily used for testing.
|
||||||
|
"""
|
||||||
|
global _settings
|
||||||
|
with _settings_lock:
|
||||||
|
_settings = None
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_default_settings() -> ContextSettings:
|
||||||
|
"""
|
||||||
|
Get default settings (cached).
|
||||||
|
|
||||||
|
Use this for read-only access to defaults.
|
||||||
|
For mutable access, use get_context_settings().
|
||||||
|
"""
|
||||||
|
return ContextSettings()
|
||||||
582
backend/app/services/context/engine.py
Normal file
582
backend/app/services/context/engine.py
Normal file
@@ -0,0 +1,582 @@
|
|||||||
|
"""
|
||||||
|
Context Management Engine.
|
||||||
|
|
||||||
|
Main orchestration layer for context assembly and optimization.
|
||||||
|
Provides a high-level API for assembling optimized context for LLM requests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from .assembly import ContextPipeline
|
||||||
|
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||||
|
from .cache import ContextCache
|
||||||
|
from .compression import ContextCompressor
|
||||||
|
from .config import ContextSettings, get_context_settings
|
||||||
|
from .prioritization import ContextRanker
|
||||||
|
from .scoring import CompositeScorer
|
||||||
|
from .types import (
|
||||||
|
AssembledContext,
|
||||||
|
BaseContext,
|
||||||
|
ConversationContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
MemoryContext,
|
||||||
|
MessageRole,
|
||||||
|
SystemContext,
|
||||||
|
TaskContext,
|
||||||
|
ToolContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from app.services.mcp.client_manager import MCPClientManager
|
||||||
|
from app.services.memory.integration import MemoryContextSource
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextEngine:
|
||||||
|
"""
|
||||||
|
Main context management engine.
|
||||||
|
|
||||||
|
Provides high-level API for context assembly and optimization.
|
||||||
|
Integrates all components: scoring, ranking, compression, formatting, and caching.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
engine = ContextEngine(mcp_manager=mcp, redis=redis)
|
||||||
|
|
||||||
|
# Assemble context for an LLM request
|
||||||
|
result = await engine.assemble_context(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="implement user authentication",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
system_prompt="You are an expert developer.",
|
||||||
|
knowledge_query="authentication best practices",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the assembled context
|
||||||
|
print(result.content)
|
||||||
|
print(f"Tokens: {result.total_tokens}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
redis: "Redis | None" = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
memory_source: "MemoryContextSource | None" = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the context engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
|
||||||
|
redis: Redis connection for caching
|
||||||
|
settings: Context settings
|
||||||
|
memory_source: Optional memory context source for agent memory
|
||||||
|
"""
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
self._memory_source = memory_source
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
|
||||||
|
self._scorer = CompositeScorer(mcp_manager=mcp_manager, settings=self._settings)
|
||||||
|
self._ranker = ContextRanker(scorer=self._scorer, calculator=self._calculator)
|
||||||
|
self._compressor = ContextCompressor(calculator=self._calculator)
|
||||||
|
self._allocator = BudgetAllocator(self._settings)
|
||||||
|
self._cache = ContextCache(redis=redis, settings=self._settings)
|
||||||
|
|
||||||
|
# Pipeline for assembly
|
||||||
|
self._pipeline = ContextPipeline(
|
||||||
|
mcp_manager=mcp_manager,
|
||||||
|
settings=self._settings,
|
||||||
|
calculator=self._calculator,
|
||||||
|
scorer=self._scorer,
|
||||||
|
ranker=self._ranker,
|
||||||
|
compressor=self._compressor,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
|
"""
|
||||||
|
Set MCP manager for all components.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager
|
||||||
|
"""
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
self._calculator.set_mcp_manager(mcp_manager)
|
||||||
|
self._scorer.set_mcp_manager(mcp_manager)
|
||||||
|
self._pipeline.set_mcp_manager(mcp_manager)
|
||||||
|
|
||||||
|
def set_redis(self, redis: "Redis") -> None:
|
||||||
|
"""
|
||||||
|
Set Redis connection for caching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis: Redis connection
|
||||||
|
"""
|
||||||
|
self._cache.set_redis(redis)
|
||||||
|
|
||||||
|
def set_memory_source(self, memory_source: "MemoryContextSource") -> None:
|
||||||
|
"""
|
||||||
|
Set memory context source for agent memory integration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_source: Memory context source
|
||||||
|
"""
|
||||||
|
self._memory_source = memory_source
|
||||||
|
|
||||||
|
async def assemble_context(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
query: str,
|
||||||
|
model: str,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
system_prompt: str | None = None,
|
||||||
|
task_description: str | None = None,
|
||||||
|
knowledge_query: str | None = None,
|
||||||
|
knowledge_limit: int = 10,
|
||||||
|
memory_query: str | None = None,
|
||||||
|
memory_limit: int = 20,
|
||||||
|
session_id: str | None = None,
|
||||||
|
agent_type_id: str | None = None,
|
||||||
|
conversation_history: list[dict[str, str]] | None = None,
|
||||||
|
tool_results: list[dict[str, Any]] | None = None,
|
||||||
|
custom_contexts: list[BaseContext] | None = None,
|
||||||
|
custom_budget: TokenBudget | None = None,
|
||||||
|
compress: bool = True,
|
||||||
|
format_output: bool = True,
|
||||||
|
use_cache: bool = True,
|
||||||
|
) -> AssembledContext:
|
||||||
|
"""
|
||||||
|
Assemble optimized context for an LLM request.
|
||||||
|
|
||||||
|
This is the main entry point for context management.
|
||||||
|
It gathers context from various sources, scores and ranks them,
|
||||||
|
compresses if needed, and formats for the target model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
query: User's query or current request
|
||||||
|
model: Target model name
|
||||||
|
max_tokens: Maximum context tokens (uses model default if None)
|
||||||
|
system_prompt: System prompt/instructions
|
||||||
|
task_description: Current task description
|
||||||
|
knowledge_query: Query for knowledge base search
|
||||||
|
knowledge_limit: Max number of knowledge results
|
||||||
|
memory_query: Query for agent memory search
|
||||||
|
memory_limit: Max number of memory results
|
||||||
|
session_id: Session ID for working memory access
|
||||||
|
agent_type_id: Agent type ID for procedural memory
|
||||||
|
conversation_history: List of {"role": str, "content": str}
|
||||||
|
tool_results: List of tool results to include
|
||||||
|
custom_contexts: Additional custom contexts
|
||||||
|
custom_budget: Custom token budget
|
||||||
|
compress: Whether to apply compression
|
||||||
|
format_output: Whether to format for the model
|
||||||
|
use_cache: Whether to use caching
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AssembledContext with optimized content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssemblyTimeoutError: If assembly exceeds timeout
|
||||||
|
BudgetExceededError: If context exceeds budget
|
||||||
|
"""
|
||||||
|
# Gather all contexts
|
||||||
|
contexts: list[BaseContext] = []
|
||||||
|
|
||||||
|
# 1. System context
|
||||||
|
if system_prompt:
|
||||||
|
contexts.append(
|
||||||
|
SystemContext(
|
||||||
|
content=system_prompt,
|
||||||
|
source="system_prompt",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Task context
|
||||||
|
if task_description:
|
||||||
|
contexts.append(
|
||||||
|
TaskContext(
|
||||||
|
content=task_description,
|
||||||
|
source=f"task:{project_id}:{agent_id}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Knowledge context from Knowledge Base
|
||||||
|
if knowledge_query and self._mcp:
|
||||||
|
knowledge_contexts = await self._fetch_knowledge(
|
||||||
|
project_id=project_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
query=knowledge_query,
|
||||||
|
limit=knowledge_limit,
|
||||||
|
)
|
||||||
|
contexts.extend(knowledge_contexts)
|
||||||
|
|
||||||
|
# 4. Memory context from Agent Memory System
|
||||||
|
if memory_query and self._memory_source:
|
||||||
|
memory_contexts = await self._fetch_memory(
|
||||||
|
project_id=project_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
query=memory_query,
|
||||||
|
limit=memory_limit,
|
||||||
|
session_id=session_id,
|
||||||
|
agent_type_id=agent_type_id,
|
||||||
|
)
|
||||||
|
contexts.extend(memory_contexts)
|
||||||
|
|
||||||
|
# 5. Conversation history
|
||||||
|
if conversation_history:
|
||||||
|
contexts.extend(self._convert_conversation(conversation_history))
|
||||||
|
|
||||||
|
# 6. Tool results
|
||||||
|
if tool_results:
|
||||||
|
contexts.extend(self._convert_tool_results(tool_results))
|
||||||
|
|
||||||
|
# 7. Custom contexts
|
||||||
|
if custom_contexts:
|
||||||
|
contexts.extend(custom_contexts)
|
||||||
|
|
||||||
|
# Check cache if enabled
|
||||||
|
fingerprint: str | None = None
|
||||||
|
if use_cache and self._cache.is_enabled:
|
||||||
|
# Include project_id and agent_id for tenant isolation
|
||||||
|
fingerprint = self._cache.compute_fingerprint(
|
||||||
|
contexts, query, model, project_id=project_id, agent_id=agent_id
|
||||||
|
)
|
||||||
|
cached = await self._cache.get_assembled(fingerprint)
|
||||||
|
if cached:
|
||||||
|
logger.debug(f"Cache hit for context assembly: {fingerprint}")
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# Run assembly pipeline
|
||||||
|
result = await self._pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query=query,
|
||||||
|
model=model,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
custom_budget=custom_budget,
|
||||||
|
compress=compress,
|
||||||
|
format_output=format_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache result if enabled (reuse fingerprint computed above)
|
||||||
|
if use_cache and self._cache.is_enabled and fingerprint is not None:
|
||||||
|
await self._cache.set_assembled(fingerprint, result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _fetch_knowledge(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
query: str,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> list[KnowledgeContext]:
|
||||||
|
"""
|
||||||
|
Fetch relevant knowledge from Knowledge Base via MCP.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
query: Search query
|
||||||
|
limit: Maximum results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of KnowledgeContext instances
|
||||||
|
"""
|
||||||
|
if not self._mcp:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self._mcp.call_tool(
|
||||||
|
"knowledge-base",
|
||||||
|
"search_knowledge",
|
||||||
|
{
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"query": query,
|
||||||
|
"search_type": "hybrid",
|
||||||
|
"limit": limit,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check both ToolResult.success AND response success
|
||||||
|
if not result.success:
|
||||||
|
logger.warning(f"Knowledge search failed: {result.error}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not isinstance(result.data, dict) or not result.data.get(
|
||||||
|
"success", True
|
||||||
|
):
|
||||||
|
logger.warning("Knowledge search returned unsuccessful response")
|
||||||
|
return []
|
||||||
|
|
||||||
|
contexts = []
|
||||||
|
results = result.data.get("results", [])
|
||||||
|
for chunk in results:
|
||||||
|
contexts.append(
|
||||||
|
KnowledgeContext(
|
||||||
|
content=chunk.get("content", ""),
|
||||||
|
source=chunk.get("source_path", "unknown"),
|
||||||
|
relevance_score=chunk.get("score", 0.0),
|
||||||
|
metadata={
|
||||||
|
"chunk_id": chunk.get(
|
||||||
|
"id"
|
||||||
|
), # Server returns 'id' not 'chunk_id'
|
||||||
|
"document_id": chunk.get("document_id"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Fetched {len(contexts)} knowledge chunks for query: {query}")
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch knowledge: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _fetch_memory(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
query: str,
|
||||||
|
limit: int = 20,
|
||||||
|
session_id: str | None = None,
|
||||||
|
agent_type_id: str | None = None,
|
||||||
|
) -> list[MemoryContext]:
|
||||||
|
"""
|
||||||
|
Fetch relevant memories from Agent Memory System.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
query: Search query
|
||||||
|
limit: Maximum results
|
||||||
|
session_id: Session ID for working memory
|
||||||
|
agent_type_id: Agent type ID for procedural memory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of MemoryContext instances
|
||||||
|
"""
|
||||||
|
if not self._memory_source:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
|
||||||
|
# Configure fetch limits
|
||||||
|
from app.services.memory.integration.context_source import MemoryFetchConfig
|
||||||
|
|
||||||
|
config = MemoryFetchConfig(
|
||||||
|
working_limit=min(limit // 4, 5),
|
||||||
|
episodic_limit=min(limit // 2, 10),
|
||||||
|
semantic_limit=min(limit // 2, 10),
|
||||||
|
procedural_limit=min(limit // 4, 5),
|
||||||
|
include_working=session_id is not None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self._memory_source.fetch_context(
|
||||||
|
query=query,
|
||||||
|
project_id=UUID(project_id),
|
||||||
|
agent_instance_id=UUID(agent_id) if agent_id else None,
|
||||||
|
agent_type_id=UUID(agent_type_id) if agent_type_id else None,
|
||||||
|
session_id=session_id,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Fetched {len(result.contexts)} memory contexts for query: {query}, "
|
||||||
|
f"by_type: {result.by_type}"
|
||||||
|
)
|
||||||
|
return result.contexts[:limit]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch memory: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _convert_conversation(
|
||||||
|
self,
|
||||||
|
history: list[dict[str, str]],
|
||||||
|
) -> list[ConversationContext]:
|
||||||
|
"""
|
||||||
|
Convert conversation history to ConversationContext instances.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
history: List of {"role": str, "content": str}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ConversationContext instances
|
||||||
|
"""
|
||||||
|
contexts = []
|
||||||
|
for i, turn in enumerate(history):
|
||||||
|
role_str = turn.get("role", "user").lower()
|
||||||
|
role = (
|
||||||
|
MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
|
||||||
|
)
|
||||||
|
|
||||||
|
contexts.append(
|
||||||
|
ConversationContext(
|
||||||
|
content=turn.get("content", ""),
|
||||||
|
source=f"conversation:{i}",
|
||||||
|
role=role,
|
||||||
|
metadata={"role": role_str, "turn": i},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
def _convert_tool_results(
|
||||||
|
self,
|
||||||
|
results: list[dict[str, Any]],
|
||||||
|
) -> list[ToolContext]:
|
||||||
|
"""
|
||||||
|
Convert tool results to ToolContext instances.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of tool result dictionaries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ToolContext instances
|
||||||
|
"""
|
||||||
|
contexts = []
|
||||||
|
for result in results:
|
||||||
|
tool_name = result.get("tool_name", "unknown")
|
||||||
|
content = result.get("content", result.get("result", ""))
|
||||||
|
|
||||||
|
# Handle dict content
|
||||||
|
if isinstance(content, dict):
|
||||||
|
import json
|
||||||
|
|
||||||
|
content = json.dumps(content, indent=2)
|
||||||
|
|
||||||
|
contexts.append(
|
||||||
|
ToolContext(
|
||||||
|
content=str(content),
|
||||||
|
source=f"tool:{tool_name}",
|
||||||
|
metadata={
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"status": result.get("status", "success"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
async def get_budget_for_model(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Get the token budget for a specific model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
max_tokens: Optional max tokens override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenBudget instance
|
||||||
|
"""
|
||||||
|
if max_tokens:
|
||||||
|
return self._allocator.create_budget(max_tokens)
|
||||||
|
return self._allocator.create_budget_for_model(model)
|
||||||
|
|
||||||
|
async def count_tokens(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count tokens in content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content to count
|
||||||
|
model: Model for model-specific tokenization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token count
|
||||||
|
"""
|
||||||
|
# Check cache first
|
||||||
|
cached = await self._cache.get_token_count(content, model)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
count = await self._calculator.count_tokens(content, model)
|
||||||
|
|
||||||
|
# Cache the result
|
||||||
|
await self._cache.set_token_count(content, count, model)
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def invalidate_cache(
|
||||||
|
self,
|
||||||
|
project_id: str | None = None,
|
||||||
|
pattern: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate cache entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Invalidate all cache for a project
|
||||||
|
pattern: Custom pattern to match
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries invalidated
|
||||||
|
"""
|
||||||
|
if pattern:
|
||||||
|
return await self._cache.invalidate(pattern)
|
||||||
|
elif project_id:
|
||||||
|
return await self._cache.invalidate(f"*{project_id}*")
|
||||||
|
else:
|
||||||
|
return await self._cache.clear_all()
|
||||||
|
|
||||||
|
async def get_stats(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get engine statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with engine stats
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"cache": await self._cache.get_stats(),
|
||||||
|
"settings": {
|
||||||
|
"compression_threshold": self._settings.compression_threshold,
|
||||||
|
"max_assembly_time_ms": self._settings.max_assembly_time_ms,
|
||||||
|
"cache_enabled": self._settings.cache_enabled,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience factory function
|
||||||
|
def create_context_engine(
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
redis: "Redis | None" = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
memory_source: "MemoryContextSource | None" = None,
|
||||||
|
) -> ContextEngine:
|
||||||
|
"""
|
||||||
|
Create a context engine instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager
|
||||||
|
redis: Redis connection
|
||||||
|
settings: Context settings
|
||||||
|
memory_source: Optional memory context source
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured ContextEngine instance
|
||||||
|
"""
|
||||||
|
return ContextEngine(
|
||||||
|
mcp_manager=mcp_manager,
|
||||||
|
redis=redis,
|
||||||
|
settings=settings,
|
||||||
|
memory_source=memory_source,
|
||||||
|
)
|
||||||
354
backend/app/services/context/exceptions.py
Normal file
354
backend/app/services/context/exceptions.py
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
"""
|
||||||
|
Context Management Engine Exceptions.
|
||||||
|
|
||||||
|
Provides a hierarchy of exceptions for context assembly,
|
||||||
|
token budget management, and related operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class ContextError(Exception):
|
||||||
|
"""
|
||||||
|
Base exception for all context management errors.
|
||||||
|
|
||||||
|
All context-related exceptions should inherit from this class
|
||||||
|
to allow for catch-all handling when needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, details: dict[str, Any] | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Initialize context error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Human-readable error message
|
||||||
|
details: Optional dict with additional error context
|
||||||
|
"""
|
||||||
|
self.message = message
|
||||||
|
self.details = details or {}
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert exception to dictionary for logging/serialization."""
|
||||||
|
return {
|
||||||
|
"error_type": self.__class__.__name__,
|
||||||
|
"message": self.message,
|
||||||
|
"details": self.details,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BudgetExceededError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when token budget is exceeded.
|
||||||
|
|
||||||
|
This occurs when the assembled context would exceed the
|
||||||
|
allocated token budget for a specific context type or total.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Token budget exceeded",
|
||||||
|
allocated: int = 0,
|
||||||
|
requested: int = 0,
|
||||||
|
context_type: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize budget exceeded error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
allocated: Tokens allocated for this context type
|
||||||
|
requested: Tokens requested
|
||||||
|
context_type: Type of context that exceeded budget
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {
|
||||||
|
"allocated": allocated,
|
||||||
|
"requested": requested,
|
||||||
|
"overage": requested - allocated,
|
||||||
|
}
|
||||||
|
if context_type:
|
||||||
|
details["context_type"] = context_type
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.allocated = allocated
|
||||||
|
self.requested = requested
|
||||||
|
self.context_type = context_type
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCountError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when token counting fails.
|
||||||
|
|
||||||
|
This typically occurs when the LLM Gateway token counting
|
||||||
|
service is unavailable or returns an error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Failed to count tokens",
|
||||||
|
model: str | None = None,
|
||||||
|
text_length: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize token count error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
model: Model for which counting was attempted
|
||||||
|
text_length: Length of text that failed to count
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if model:
|
||||||
|
details["model"] = model
|
||||||
|
if text_length is not None:
|
||||||
|
details["text_length"] = text_length
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.model = model
|
||||||
|
self.text_length = text_length
|
||||||
|
|
||||||
|
|
||||||
|
class CompressionError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when context compression fails.
|
||||||
|
|
||||||
|
This can occur when summarization or truncation cannot
|
||||||
|
reduce content to fit within the budget.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Failed to compress context",
|
||||||
|
original_tokens: int | None = None,
|
||||||
|
target_tokens: int | None = None,
|
||||||
|
achieved_tokens: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize compression error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
original_tokens: Tokens before compression
|
||||||
|
target_tokens: Target token count
|
||||||
|
achieved_tokens: Tokens achieved after compression attempt
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if original_tokens is not None:
|
||||||
|
details["original_tokens"] = original_tokens
|
||||||
|
if target_tokens is not None:
|
||||||
|
details["target_tokens"] = target_tokens
|
||||||
|
if achieved_tokens is not None:
|
||||||
|
details["achieved_tokens"] = achieved_tokens
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.original_tokens = original_tokens
|
||||||
|
self.target_tokens = target_tokens
|
||||||
|
self.achieved_tokens = achieved_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class AssemblyTimeoutError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when context assembly exceeds time limit.
|
||||||
|
|
||||||
|
Context assembly must complete within a configurable
|
||||||
|
time limit to maintain responsiveness.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Context assembly timed out",
|
||||||
|
timeout_ms: int = 0,
|
||||||
|
elapsed_ms: float = 0.0,
|
||||||
|
stage: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize assembly timeout error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
timeout_ms: Configured timeout in milliseconds
|
||||||
|
elapsed_ms: Actual elapsed time in milliseconds
|
||||||
|
stage: Pipeline stage where timeout occurred
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {
|
||||||
|
"timeout_ms": timeout_ms,
|
||||||
|
"elapsed_ms": round(elapsed_ms, 2),
|
||||||
|
}
|
||||||
|
if stage:
|
||||||
|
details["stage"] = stage
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.timeout_ms = timeout_ms
|
||||||
|
self.elapsed_ms = elapsed_ms
|
||||||
|
self.stage = stage
|
||||||
|
|
||||||
|
|
||||||
|
class ScoringError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when context scoring fails.
|
||||||
|
|
||||||
|
This occurs when relevance, recency, or priority scoring
|
||||||
|
encounters an error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Failed to score context",
|
||||||
|
scorer_type: str | None = None,
|
||||||
|
context_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize scoring error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
scorer_type: Type of scorer that failed
|
||||||
|
context_id: ID of context being scored
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if scorer_type:
|
||||||
|
details["scorer_type"] = scorer_type
|
||||||
|
if context_id:
|
||||||
|
details["context_id"] = context_id
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.scorer_type = scorer_type
|
||||||
|
self.context_id = context_id
|
||||||
|
|
||||||
|
|
||||||
|
class FormattingError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when context formatting fails.
|
||||||
|
|
||||||
|
This occurs when converting assembled context to
|
||||||
|
model-specific format fails.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Failed to format context",
|
||||||
|
model: str | None = None,
|
||||||
|
adapter: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize formatting error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
model: Target model
|
||||||
|
adapter: Adapter that failed
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if model:
|
||||||
|
details["model"] = model
|
||||||
|
if adapter:
|
||||||
|
details["adapter"] = adapter
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.model = model
|
||||||
|
self.adapter = adapter
|
||||||
|
|
||||||
|
|
||||||
|
class CacheError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when cache operations fail.
|
||||||
|
|
||||||
|
This is typically non-fatal and should be handled
|
||||||
|
gracefully by falling back to recomputation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Cache operation failed",
|
||||||
|
operation: str | None = None,
|
||||||
|
cache_key: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize cache error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
operation: Cache operation that failed (get, set, delete)
|
||||||
|
cache_key: Key involved in the failed operation
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if operation:
|
||||||
|
details["operation"] = operation
|
||||||
|
if cache_key:
|
||||||
|
details["cache_key"] = cache_key
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.operation = operation
|
||||||
|
self.cache_key = cache_key
|
||||||
|
|
||||||
|
|
||||||
|
class ContextNotFoundError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when expected context is not found.
|
||||||
|
|
||||||
|
This occurs when required context sources return
|
||||||
|
no results or are unavailable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Required context not found",
|
||||||
|
source: str | None = None,
|
||||||
|
query: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize context not found error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
source: Source that returned no results
|
||||||
|
query: Query used to search
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if source:
|
||||||
|
details["source"] = source
|
||||||
|
if query:
|
||||||
|
details["query"] = query
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.source = source
|
||||||
|
self.query = query
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidContextError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when context data is invalid.
|
||||||
|
|
||||||
|
This occurs when context content or metadata
|
||||||
|
fails validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Invalid context data",
|
||||||
|
field: str | None = None,
|
||||||
|
value: Any | None = None,
|
||||||
|
reason: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize invalid context error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
field: Field that is invalid
|
||||||
|
value: Invalid value (may be redacted for security)
|
||||||
|
reason: Reason for invalidity
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if field:
|
||||||
|
details["field"] = field
|
||||||
|
if value is not None:
|
||||||
|
# Avoid logging potentially sensitive values
|
||||||
|
details["value_type"] = type(value).__name__
|
||||||
|
if reason:
|
||||||
|
details["reason"] = reason
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.field = field
|
||||||
|
self.value = value
|
||||||
|
self.reason = reason
|
||||||
12
backend/app/services/context/prioritization/__init__.py
Normal file
12
backend/app/services/context/prioritization/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Context Prioritization Module.
|
||||||
|
|
||||||
|
Provides context ranking and selection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .ranker import ContextRanker, RankingResult
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContextRanker",
|
||||||
|
"RankingResult",
|
||||||
|
]
|
||||||
374
backend/app/services/context/prioritization/ranker.py
Normal file
374
backend/app/services/context/prioritization/ranker.py
Normal file
@@ -0,0 +1,374 @@
|
|||||||
|
"""
|
||||||
|
Context Ranker for Context Management.
|
||||||
|
|
||||||
|
Ranks and selects contexts based on scores and budget constraints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..budget import TokenBudget, TokenCalculator
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..exceptions import BudgetExceededError
|
||||||
|
from ..scoring.composite import CompositeScorer, ScoredContext
|
||||||
|
from ..types import BaseContext, ContextPriority
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RankingResult:
|
||||||
|
"""Result of context ranking and selection."""
|
||||||
|
|
||||||
|
selected: list[ScoredContext]
|
||||||
|
excluded: list[ScoredContext]
|
||||||
|
total_tokens: int
|
||||||
|
selection_stats: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def selected_contexts(self) -> list[BaseContext]:
|
||||||
|
"""Get just the context objects (not scored wrappers)."""
|
||||||
|
return [s.context for s in self.selected]
|
||||||
|
|
||||||
|
|
||||||
|
class ContextRanker:
|
||||||
|
"""
|
||||||
|
Ranks and selects contexts within budget constraints.
|
||||||
|
|
||||||
|
Uses greedy selection to maximize total score
|
||||||
|
while respecting token budgets per context type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scorer: CompositeScorer | None = None,
|
||||||
|
calculator: TokenCalculator | None = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize context ranker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scorer: Composite scorer for scoring contexts
|
||||||
|
calculator: Token calculator for counting tokens
|
||||||
|
settings: Context settings (uses global if None)
|
||||||
|
"""
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
self._scorer = scorer or CompositeScorer()
|
||||||
|
self._calculator = calculator or TokenCalculator()
|
||||||
|
|
||||||
|
def set_scorer(self, scorer: CompositeScorer) -> None:
|
||||||
|
"""Set the scorer."""
|
||||||
|
self._scorer = scorer
|
||||||
|
|
||||||
|
def set_calculator(self, calculator: TokenCalculator) -> None:
|
||||||
|
"""Set the token calculator."""
|
||||||
|
self._calculator = calculator
|
||||||
|
|
||||||
|
async def rank(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
budget: TokenBudget,
|
||||||
|
model: str | None = None,
|
||||||
|
ensure_required: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> RankingResult:
|
||||||
|
"""
|
||||||
|
Rank and select contexts within budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to rank
|
||||||
|
query: Query to rank against
|
||||||
|
budget: Token budget constraints
|
||||||
|
model: Model for token counting
|
||||||
|
ensure_required: If True, always include CRITICAL priority contexts
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RankingResult with selected and excluded contexts
|
||||||
|
"""
|
||||||
|
if not contexts:
|
||||||
|
return RankingResult(
|
||||||
|
selected=[],
|
||||||
|
excluded=[],
|
||||||
|
total_tokens=0,
|
||||||
|
selection_stats={"total_contexts": 0},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. Ensure all contexts have token counts
|
||||||
|
await self._ensure_token_counts(contexts, model)
|
||||||
|
|
||||||
|
# 2. Score all contexts
|
||||||
|
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
|
||||||
|
|
||||||
|
# 3. Separate required (CRITICAL priority) from optional
|
||||||
|
required: list[ScoredContext] = []
|
||||||
|
optional: list[ScoredContext] = []
|
||||||
|
|
||||||
|
if ensure_required:
|
||||||
|
for sc in scored_contexts:
|
||||||
|
# CRITICAL priority (150) contexts are always included
|
||||||
|
if sc.context.priority >= ContextPriority.CRITICAL.value:
|
||||||
|
required.append(sc)
|
||||||
|
else:
|
||||||
|
optional.append(sc)
|
||||||
|
else:
|
||||||
|
optional = list(scored_contexts)
|
||||||
|
|
||||||
|
# 4. Sort optional by score (highest first)
|
||||||
|
optional.sort(reverse=True)
|
||||||
|
|
||||||
|
# 5. Greedy selection
|
||||||
|
selected: list[ScoredContext] = []
|
||||||
|
excluded: list[ScoredContext] = []
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
|
# Calculate the usable budget (total minus reserved portions)
|
||||||
|
usable_budget = budget.total - budget.response_reserve - budget.buffer
|
||||||
|
|
||||||
|
# Guard against invalid budget configuration
|
||||||
|
if usable_budget <= 0:
|
||||||
|
raise BudgetExceededError(
|
||||||
|
message=(
|
||||||
|
f"Invalid budget configuration: no usable tokens available. "
|
||||||
|
f"total={budget.total}, response_reserve={budget.response_reserve}, "
|
||||||
|
f"buffer={budget.buffer}"
|
||||||
|
),
|
||||||
|
allocated=budget.total,
|
||||||
|
requested=0,
|
||||||
|
context_type="CONFIGURATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
# First, try to fit required contexts
|
||||||
|
for sc in required:
|
||||||
|
token_count = self._get_valid_token_count(sc.context)
|
||||||
|
context_type = sc.context.get_type()
|
||||||
|
|
||||||
|
if budget.can_fit(context_type, token_count):
|
||||||
|
budget.allocate(context_type, token_count)
|
||||||
|
selected.append(sc)
|
||||||
|
total_tokens += token_count
|
||||||
|
else:
|
||||||
|
# Force-fit CRITICAL contexts if needed, but check total budget first
|
||||||
|
if total_tokens + token_count > usable_budget:
|
||||||
|
# Even CRITICAL contexts cannot exceed total model context window
|
||||||
|
raise BudgetExceededError(
|
||||||
|
message=(
|
||||||
|
f"CRITICAL contexts exceed total budget. "
|
||||||
|
f"Context '{sc.context.source}' ({token_count} tokens) "
|
||||||
|
f"would exceed usable budget of {usable_budget} tokens."
|
||||||
|
),
|
||||||
|
allocated=usable_budget,
|
||||||
|
requested=total_tokens + token_count,
|
||||||
|
context_type="CRITICAL_OVERFLOW",
|
||||||
|
)
|
||||||
|
|
||||||
|
budget.allocate(context_type, token_count, force=True)
|
||||||
|
selected.append(sc)
|
||||||
|
total_tokens += token_count
|
||||||
|
logger.warning(
|
||||||
|
f"Force-fitted CRITICAL context: {sc.context.source} "
|
||||||
|
f"({token_count} tokens)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then, greedily add optional contexts
|
||||||
|
for sc in optional:
|
||||||
|
token_count = self._get_valid_token_count(sc.context)
|
||||||
|
context_type = sc.context.get_type()
|
||||||
|
|
||||||
|
if budget.can_fit(context_type, token_count):
|
||||||
|
budget.allocate(context_type, token_count)
|
||||||
|
selected.append(sc)
|
||||||
|
total_tokens += token_count
|
||||||
|
else:
|
||||||
|
excluded.append(sc)
|
||||||
|
|
||||||
|
# Build stats
|
||||||
|
stats = {
|
||||||
|
"total_contexts": len(contexts),
|
||||||
|
"required_count": len(required),
|
||||||
|
"selected_count": len(selected),
|
||||||
|
"excluded_count": len(excluded),
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
"by_type": self._count_by_type(selected),
|
||||||
|
}
|
||||||
|
|
||||||
|
return RankingResult(
|
||||||
|
selected=selected,
|
||||||
|
excluded=excluded,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
selection_stats=stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def rank_simple(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[BaseContext]:
|
||||||
|
"""
|
||||||
|
Simple ranking without budget per type.
|
||||||
|
|
||||||
|
Selects top contexts by score until max tokens reached.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to rank
|
||||||
|
query: Query to rank against
|
||||||
|
max_tokens: Maximum total tokens
|
||||||
|
model: Model for token counting
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Selected contexts (in score order)
|
||||||
|
"""
|
||||||
|
if not contexts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Ensure token counts
|
||||||
|
await self._ensure_token_counts(contexts, model)
|
||||||
|
|
||||||
|
# Score all contexts
|
||||||
|
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
|
||||||
|
|
||||||
|
# Sort by score (highest first)
|
||||||
|
scored_contexts.sort(reverse=True)
|
||||||
|
|
||||||
|
# Greedy selection
|
||||||
|
selected: list[BaseContext] = []
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
|
for sc in scored_contexts:
|
||||||
|
token_count = self._get_valid_token_count(sc.context)
|
||||||
|
if total_tokens + token_count <= max_tokens:
|
||||||
|
selected.append(sc.context)
|
||||||
|
total_tokens += token_count
|
||||||
|
|
||||||
|
return selected
|
||||||
|
|
||||||
|
def _get_valid_token_count(self, context: BaseContext) -> int:
|
||||||
|
"""
|
||||||
|
Get validated token count from a context.
|
||||||
|
|
||||||
|
Ensures token_count is set (not None) and non-negative to prevent
|
||||||
|
budget bypass attacks where:
|
||||||
|
- None would be treated as 0 (allowing huge contexts to slip through)
|
||||||
|
- Negative values would corrupt budget tracking
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to get token count from
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Valid non-negative token count
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If token_count is None or negative
|
||||||
|
"""
|
||||||
|
if context.token_count is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Context '{context.source}' has no token count. "
|
||||||
|
"Ensure _ensure_token_counts() is called before ranking."
|
||||||
|
)
|
||||||
|
if context.token_count < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Context '{context.source}' has invalid negative token count: "
|
||||||
|
f"{context.token_count}"
|
||||||
|
)
|
||||||
|
return context.token_count
|
||||||
|
|
||||||
|
async def _ensure_token_counts(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
model: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Ensure all contexts have token counts.
|
||||||
|
|
||||||
|
Counts tokens in parallel for contexts that don't have counts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to check
|
||||||
|
model: Model for token counting
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Find contexts needing counts
|
||||||
|
contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]
|
||||||
|
|
||||||
|
if not contexts_needing_counts:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Count all in parallel
|
||||||
|
tasks = [
|
||||||
|
self._calculator.count_tokens(ctx.content, model)
|
||||||
|
for ctx in contexts_needing_counts
|
||||||
|
]
|
||||||
|
counts = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Assign counts back
|
||||||
|
for ctx, count in zip(contexts_needing_counts, counts, strict=True):
|
||||||
|
ctx.token_count = count
|
||||||
|
|
||||||
|
def _count_by_type(
|
||||||
|
self, scored_contexts: list[ScoredContext]
|
||||||
|
) -> dict[str, dict[str, int]]:
|
||||||
|
"""Count selected contexts by type."""
|
||||||
|
by_type: dict[str, dict[str, int]] = {}
|
||||||
|
|
||||||
|
for sc in scored_contexts:
|
||||||
|
type_name = sc.context.get_type().value
|
||||||
|
if type_name not in by_type:
|
||||||
|
by_type[type_name] = {"count": 0, "tokens": 0}
|
||||||
|
by_type[type_name]["count"] += 1
|
||||||
|
# Use validated token count (already validated during ranking)
|
||||||
|
by_type[type_name]["tokens"] += sc.context.token_count or 0
|
||||||
|
|
||||||
|
return by_type
|
||||||
|
|
||||||
|
async def rerank_for_diversity(
|
||||||
|
self,
|
||||||
|
scored_contexts: list[ScoredContext],
|
||||||
|
max_per_source: int | None = None,
|
||||||
|
) -> list[ScoredContext]:
|
||||||
|
"""
|
||||||
|
Rerank to ensure source diversity.
|
||||||
|
|
||||||
|
Prevents too many items from the same source.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scored_contexts: Already scored contexts
|
||||||
|
max_per_source: Maximum items per source (uses settings if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Reranked contexts
|
||||||
|
"""
|
||||||
|
# Use provided value or fall back to settings
|
||||||
|
effective_max = (
|
||||||
|
max_per_source
|
||||||
|
if max_per_source is not None
|
||||||
|
else self._settings.diversity_max_per_source
|
||||||
|
)
|
||||||
|
|
||||||
|
source_counts: dict[str, int] = {}
|
||||||
|
result: list[ScoredContext] = []
|
||||||
|
deferred: list[ScoredContext] = []
|
||||||
|
|
||||||
|
for sc in scored_contexts:
|
||||||
|
source = sc.context.source
|
||||||
|
current_count = source_counts.get(source, 0)
|
||||||
|
|
||||||
|
if current_count < effective_max:
|
||||||
|
result.append(sc)
|
||||||
|
source_counts[source] = current_count + 1
|
||||||
|
else:
|
||||||
|
deferred.append(sc)
|
||||||
|
|
||||||
|
# Add deferred items at the end
|
||||||
|
result.extend(deferred)
|
||||||
|
return result
|
||||||
21
backend/app/services/context/scoring/__init__.py
Normal file
21
backend/app/services/context/scoring/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""
|
||||||
|
Context Scoring Module.
|
||||||
|
|
||||||
|
Provides scoring strategies for context prioritization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import BaseScorer, ScorerProtocol
|
||||||
|
from .composite import CompositeScorer, ScoredContext
|
||||||
|
from .priority import PriorityScorer
|
||||||
|
from .recency import RecencyScorer
|
||||||
|
from .relevance import RelevanceScorer
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseScorer",
|
||||||
|
"CompositeScorer",
|
||||||
|
"PriorityScorer",
|
||||||
|
"RecencyScorer",
|
||||||
|
"RelevanceScorer",
|
||||||
|
"ScoredContext",
|
||||||
|
"ScorerProtocol",
|
||||||
|
]
|
||||||
99
backend/app/services/context/scoring/base.py
Normal file
99
backend/app/services/context/scoring/base.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""
|
||||||
|
Base Scorer Protocol and Types.
|
||||||
|
|
||||||
|
Defines the interface for context scoring implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from ..types import BaseContext
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class ScorerProtocol(Protocol):
|
||||||
|
"""Protocol for context scorers."""
|
||||||
|
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Score a context item.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to score against
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class BaseScorer(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for context scorers.
|
||||||
|
|
||||||
|
Provides common functionality and interface for
|
||||||
|
different scoring strategies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, weight: float = 1.0) -> None:
|
||||||
|
"""
|
||||||
|
Initialize scorer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight: Weight for this scorer in composite scoring
|
||||||
|
"""
|
||||||
|
self._weight = weight
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weight(self) -> float:
|
||||||
|
"""Get scorer weight."""
|
||||||
|
return self._weight
|
||||||
|
|
||||||
|
@weight.setter
|
||||||
|
def weight(self, value: float) -> None:
|
||||||
|
"""Set scorer weight."""
|
||||||
|
if not 0.0 <= value <= 1.0:
|
||||||
|
raise ValueError("Weight must be between 0.0 and 1.0")
|
||||||
|
self._weight = value
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Score a context item.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to score against
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def normalize_score(self, score: float) -> float:
|
||||||
|
"""
|
||||||
|
Normalize score to [0.0, 1.0] range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
score: Raw score
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized score
|
||||||
|
"""
|
||||||
|
return max(0.0, min(1.0, score))
|
||||||
368
backend/app/services/context/scoring/composite.py
Normal file
368
backend/app/services/context/scoring/composite.py
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
"""
|
||||||
|
Composite Scorer for Context Management.
|
||||||
|
|
||||||
|
Combines multiple scoring strategies with configurable weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..types import BaseContext
|
||||||
|
from .priority import PriorityScorer
|
||||||
|
from .recency import RecencyScorer
|
||||||
|
from .relevance import RelevanceScorer
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.mcp.client_manager import MCPClientManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScoredContext:
|
||||||
|
"""Context with computed scores."""
|
||||||
|
|
||||||
|
context: BaseContext
|
||||||
|
composite_score: float
|
||||||
|
relevance_score: float = 0.0
|
||||||
|
recency_score: float = 0.0
|
||||||
|
priority_score: float = 0.0
|
||||||
|
|
||||||
|
def __lt__(self, other: "ScoredContext") -> bool:
|
||||||
|
"""Enable sorting by composite score."""
|
||||||
|
return self.composite_score < other.composite_score
|
||||||
|
|
||||||
|
def __gt__(self, other: "ScoredContext") -> bool:
|
||||||
|
"""Enable sorting by composite score."""
|
||||||
|
return self.composite_score > other.composite_score
|
||||||
|
|
||||||
|
|
||||||
|
class CompositeScorer:
|
||||||
|
"""
|
||||||
|
Combines multiple scoring strategies.
|
||||||
|
|
||||||
|
Weights:
|
||||||
|
- relevance: How well content matches the query
|
||||||
|
- recency: How recent the content is
|
||||||
|
- priority: Explicit priority assignments
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
relevance_weight: float | None = None,
|
||||||
|
recency_weight: float | None = None,
|
||||||
|
priority_weight: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize composite scorer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP manager for semantic scoring
|
||||||
|
settings: Context settings (uses default if None)
|
||||||
|
relevance_weight: Override relevance weight
|
||||||
|
recency_weight: Override recency weight
|
||||||
|
priority_weight: Override priority weight
|
||||||
|
"""
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
weights = self._settings.get_scoring_weights()
|
||||||
|
|
||||||
|
self._relevance_weight = (
|
||||||
|
relevance_weight if relevance_weight is not None else weights["relevance"]
|
||||||
|
)
|
||||||
|
self._recency_weight = (
|
||||||
|
recency_weight if recency_weight is not None else weights["recency"]
|
||||||
|
)
|
||||||
|
self._priority_weight = (
|
||||||
|
priority_weight if priority_weight is not None else weights["priority"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize scorers
|
||||||
|
self._relevance_scorer = RelevanceScorer(
|
||||||
|
mcp_manager=mcp_manager,
|
||||||
|
weight=self._relevance_weight,
|
||||||
|
)
|
||||||
|
self._recency_scorer = RecencyScorer(weight=self._recency_weight)
|
||||||
|
self._priority_scorer = PriorityScorer(weight=self._priority_weight)
|
||||||
|
|
||||||
|
# Per-context locks to prevent race conditions during parallel scoring
|
||||||
|
# Uses dict with (lock, last_used_time) tuples for cleanup
|
||||||
|
self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {}
|
||||||
|
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
|
||||||
|
self._max_locks = 1000 # Maximum locks to keep (prevent memory growth)
|
||||||
|
self._lock_ttl = 60.0 # Seconds before a lock can be cleaned up
|
||||||
|
|
||||||
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
|
"""Set MCP manager for semantic scoring."""
|
||||||
|
self._relevance_scorer.set_mcp_manager(mcp_manager)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights(self) -> dict[str, float]:
|
||||||
|
"""Get current scoring weights."""
|
||||||
|
return {
|
||||||
|
"relevance": self._relevance_weight,
|
||||||
|
"recency": self._recency_weight,
|
||||||
|
"priority": self._priority_weight,
|
||||||
|
}
|
||||||
|
|
||||||
|
def update_weights(
|
||||||
|
self,
|
||||||
|
relevance: float | None = None,
|
||||||
|
recency: float | None = None,
|
||||||
|
priority: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update scoring weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
relevance: New relevance weight
|
||||||
|
recency: New recency weight
|
||||||
|
priority: New priority weight
|
||||||
|
"""
|
||||||
|
if relevance is not None:
|
||||||
|
self._relevance_weight = max(0.0, min(1.0, relevance))
|
||||||
|
self._relevance_scorer.weight = self._relevance_weight
|
||||||
|
|
||||||
|
if recency is not None:
|
||||||
|
self._recency_weight = max(0.0, min(1.0, recency))
|
||||||
|
self._recency_scorer.weight = self._recency_weight
|
||||||
|
|
||||||
|
if priority is not None:
|
||||||
|
self._priority_weight = max(0.0, min(1.0, priority))
|
||||||
|
self._priority_scorer.weight = self._priority_weight
|
||||||
|
|
||||||
|
async def _get_context_lock(self, context_id: str) -> asyncio.Lock:
|
||||||
|
"""
|
||||||
|
Get or create a lock for a specific context.
|
||||||
|
|
||||||
|
Thread-safe access to per-context locks prevents race conditions
|
||||||
|
when the same context is scored concurrently. Includes automatic
|
||||||
|
cleanup of old locks to prevent memory growth.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_id: The context ID to get a lock for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
asyncio.Lock for the context
|
||||||
|
"""
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
# Fast path: check if lock exists without acquiring main lock
|
||||||
|
# NOTE: We only READ here - no writes to avoid race conditions
|
||||||
|
# with cleanup. The timestamp will be updated in the slow path
|
||||||
|
# if the lock is still valid.
|
||||||
|
lock_entry = self._context_locks.get(context_id)
|
||||||
|
if lock_entry is not None:
|
||||||
|
lock, _ = lock_entry
|
||||||
|
# Return the lock but defer timestamp update to avoid race
|
||||||
|
# The lock is still valid; timestamp update is best-effort
|
||||||
|
return lock
|
||||||
|
|
||||||
|
# Slow path: create lock or update timestamp while holding main lock
|
||||||
|
async with self._locks_lock:
|
||||||
|
# Double-check after acquiring lock - entry may have been
|
||||||
|
# created by another coroutine or deleted by cleanup
|
||||||
|
lock_entry = self._context_locks.get(context_id)
|
||||||
|
if lock_entry is not None:
|
||||||
|
lock, _ = lock_entry
|
||||||
|
# Safe to update timestamp here since we hold the lock
|
||||||
|
self._context_locks[context_id] = (lock, now)
|
||||||
|
return lock
|
||||||
|
|
||||||
|
# Cleanup old locks if we have too many
|
||||||
|
if len(self._context_locks) >= self._max_locks:
|
||||||
|
self._cleanup_old_locks(now)
|
||||||
|
|
||||||
|
# Create new lock
|
||||||
|
new_lock = asyncio.Lock()
|
||||||
|
self._context_locks[context_id] = (new_lock, now)
|
||||||
|
return new_lock
|
||||||
|
|
||||||
|
def _cleanup_old_locks(self, now: float) -> None:
|
||||||
|
"""
|
||||||
|
Remove old locks that haven't been used recently.
|
||||||
|
|
||||||
|
Called while holding _locks_lock. Removes locks older than _lock_ttl,
|
||||||
|
but only if they're not currently held.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
now: Current timestamp for age calculation
|
||||||
|
"""
|
||||||
|
cutoff = now - self._lock_ttl
|
||||||
|
to_remove = []
|
||||||
|
|
||||||
|
for context_id, (lock, last_used) in self._context_locks.items():
|
||||||
|
# Only remove if old AND not currently held
|
||||||
|
if last_used < cutoff and not lock.locked():
|
||||||
|
to_remove.append(context_id)
|
||||||
|
|
||||||
|
# Remove oldest 50% if still over limit after TTL filtering
|
||||||
|
if len(self._context_locks) - len(to_remove) >= self._max_locks:
|
||||||
|
# Sort by last used time and mark oldest for removal
|
||||||
|
sorted_entries = sorted(
|
||||||
|
self._context_locks.items(),
|
||||||
|
key=lambda x: x[1][1], # Sort by last_used time
|
||||||
|
)
|
||||||
|
# Remove oldest 50% that aren't locked
|
||||||
|
target_remove = len(self._context_locks) // 2
|
||||||
|
for context_id, (lock, _) in sorted_entries:
|
||||||
|
if len(to_remove) >= target_remove:
|
||||||
|
break
|
||||||
|
if context_id not in to_remove and not lock.locked():
|
||||||
|
to_remove.append(context_id)
|
||||||
|
|
||||||
|
for context_id in to_remove:
|
||||||
|
del self._context_locks[context_id]
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
logger.debug(f"Cleaned up {len(to_remove)} context locks")
|
||||||
|
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Compute composite score for a context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to score against
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Composite score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
scored = await self.score_with_details(context, query, **kwargs)
|
||||||
|
return scored.composite_score
|
||||||
|
|
||||||
|
async def score_with_details(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ScoredContext:
|
||||||
|
"""
|
||||||
|
Compute composite score with individual scores.
|
||||||
|
|
||||||
|
Uses per-context locking to prevent race conditions when the same
|
||||||
|
context is scored concurrently in parallel scoring operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to score against
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ScoredContext with all scores
|
||||||
|
"""
|
||||||
|
# Get lock for this specific context to prevent race conditions
|
||||||
|
# within concurrent scoring operations for the same query
|
||||||
|
context_lock = await self._get_context_lock(context.id)
|
||||||
|
|
||||||
|
async with context_lock:
|
||||||
|
# Compute individual scores in parallel
|
||||||
|
# Note: We do NOT cache scores on the context because scores are
|
||||||
|
# query-dependent. Caching without considering the query would
|
||||||
|
# return incorrect scores for different queries.
|
||||||
|
relevance_task = self._relevance_scorer.score(context, query, **kwargs)
|
||||||
|
recency_task = self._recency_scorer.score(context, query, **kwargs)
|
||||||
|
priority_task = self._priority_scorer.score(context, query, **kwargs)
|
||||||
|
|
||||||
|
relevance_score, recency_score, priority_score = await asyncio.gather(
|
||||||
|
relevance_task, recency_task, priority_task
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute weighted composite
|
||||||
|
total_weight = (
|
||||||
|
self._relevance_weight + self._recency_weight + self._priority_weight
|
||||||
|
)
|
||||||
|
|
||||||
|
if total_weight > 0:
|
||||||
|
composite = (
|
||||||
|
relevance_score * self._relevance_weight
|
||||||
|
+ recency_score * self._recency_weight
|
||||||
|
+ priority_score * self._priority_weight
|
||||||
|
) / total_weight
|
||||||
|
else:
|
||||||
|
composite = 0.0
|
||||||
|
|
||||||
|
return ScoredContext(
|
||||||
|
context=context,
|
||||||
|
composite_score=composite,
|
||||||
|
relevance_score=relevance_score,
|
||||||
|
recency_score=recency_score,
|
||||||
|
priority_score=priority_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def score_batch(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
parallel: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[ScoredContext]:
|
||||||
|
"""
|
||||||
|
Score multiple contexts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to score
|
||||||
|
query: Query to score against
|
||||||
|
parallel: Whether to score in parallel
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ScoredContext (same order as input)
|
||||||
|
"""
|
||||||
|
if parallel:
|
||||||
|
tasks = [self.score_with_details(ctx, query, **kwargs) for ctx in contexts]
|
||||||
|
return await asyncio.gather(*tasks)
|
||||||
|
else:
|
||||||
|
results = []
|
||||||
|
for ctx in contexts:
|
||||||
|
scored = await self.score_with_details(ctx, query, **kwargs)
|
||||||
|
results.append(scored)
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def rank(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
limit: int | None = None,
|
||||||
|
min_score: float = 0.0,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[ScoredContext]:
|
||||||
|
"""
|
||||||
|
Score and rank contexts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to rank
|
||||||
|
query: Query to rank against
|
||||||
|
limit: Maximum number of results
|
||||||
|
min_score: Minimum score threshold
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sorted list of ScoredContext (highest first)
|
||||||
|
"""
|
||||||
|
# Score all contexts
|
||||||
|
scored = await self.score_batch(contexts, query, **kwargs)
|
||||||
|
|
||||||
|
# Filter by minimum score
|
||||||
|
if min_score > 0:
|
||||||
|
scored = [s for s in scored if s.composite_score >= min_score]
|
||||||
|
|
||||||
|
# Sort by score (highest first)
|
||||||
|
scored.sort(reverse=True)
|
||||||
|
|
||||||
|
# Apply limit
|
||||||
|
if limit is not None:
|
||||||
|
scored = scored[:limit]
|
||||||
|
|
||||||
|
return scored
|
||||||
135
backend/app/services/context/scoring/priority.py
Normal file
135
backend/app/services/context/scoring/priority.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""
|
||||||
|
Priority Scorer for Context Management.
|
||||||
|
|
||||||
|
Scores context based on assigned priority levels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
from .base import BaseScorer
|
||||||
|
|
||||||
|
|
||||||
|
class PriorityScorer(BaseScorer):
|
||||||
|
"""
|
||||||
|
Scores context based on priority levels.
|
||||||
|
|
||||||
|
Converts priority enum values to normalized scores.
|
||||||
|
Also applies type-based priority bonuses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default priority bonuses by context type
|
||||||
|
DEFAULT_TYPE_BONUSES: ClassVar[dict[ContextType, float]] = {
|
||||||
|
ContextType.SYSTEM: 0.2, # System prompts get a boost
|
||||||
|
ContextType.TASK: 0.15, # Current task is important
|
||||||
|
ContextType.TOOL: 0.1, # Recent tool results matter
|
||||||
|
ContextType.KNOWLEDGE: 0.0, # Knowledge scored by relevance
|
||||||
|
ContextType.CONVERSATION: 0.0, # Conversation scored by recency
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight: float = 1.0,
|
||||||
|
type_bonuses: dict[ContextType, float] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize priority scorer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight: Scorer weight for composite scoring
|
||||||
|
type_bonuses: Optional context-type priority bonuses
|
||||||
|
"""
|
||||||
|
super().__init__(weight)
|
||||||
|
self._type_bonuses = type_bonuses or self.DEFAULT_TYPE_BONUSES.copy()
|
||||||
|
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Score context based on priority.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query (not used for priority, kept for interface)
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Priority score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
# Get base priority score
|
||||||
|
priority_value = context.priority
|
||||||
|
base_score = self._priority_to_score(priority_value)
|
||||||
|
|
||||||
|
# Apply type bonus
|
||||||
|
context_type = context.get_type()
|
||||||
|
bonus = self._type_bonuses.get(context_type, 0.0)
|
||||||
|
|
||||||
|
return self.normalize_score(base_score + bonus)
|
||||||
|
|
||||||
|
def _priority_to_score(self, priority: int) -> float:
|
||||||
|
"""
|
||||||
|
Convert priority value to normalized score.
|
||||||
|
|
||||||
|
Priority values (from ContextPriority):
|
||||||
|
- CRITICAL (100) -> 1.0
|
||||||
|
- HIGH (80) -> 0.8
|
||||||
|
- NORMAL (50) -> 0.5
|
||||||
|
- LOW (20) -> 0.2
|
||||||
|
- MINIMAL (0) -> 0.0
|
||||||
|
|
||||||
|
Args:
|
||||||
|
priority: Priority value (0-100)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized score (0.0-1.0)
|
||||||
|
"""
|
||||||
|
# Clamp to valid range
|
||||||
|
clamped = max(0, min(100, priority))
|
||||||
|
return clamped / 100.0
|
||||||
|
|
||||||
|
def get_type_bonus(self, context_type: ContextType) -> float:
|
||||||
|
"""
|
||||||
|
Get priority bonus for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Bonus value
|
||||||
|
"""
|
||||||
|
return self._type_bonuses.get(context_type, 0.0)
|
||||||
|
|
||||||
|
def set_type_bonus(self, context_type: ContextType, bonus: float) -> None:
|
||||||
|
"""
|
||||||
|
Set priority bonus for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type
|
||||||
|
bonus: Bonus value (0.0-1.0)
|
||||||
|
"""
|
||||||
|
if not 0.0 <= bonus <= 1.0:
|
||||||
|
raise ValueError("Bonus must be between 0.0 and 1.0")
|
||||||
|
self._type_bonuses[context_type] = bonus
|
||||||
|
|
||||||
|
async def score_batch(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[float]:
|
||||||
|
"""
|
||||||
|
Score multiple contexts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to score
|
||||||
|
query: Query (not used)
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of scores (same order as input)
|
||||||
|
"""
|
||||||
|
# Priority scoring is fast, no async needed
|
||||||
|
return [await self.score(ctx, query, **kwargs) for ctx in contexts]
|
||||||
141
backend/app/services/context/scoring/recency.py
Normal file
141
backend/app/services/context/scoring/recency.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""
|
||||||
|
Recency Scorer for Context Management.
|
||||||
|
|
||||||
|
Scores context based on how recent it is.
|
||||||
|
More recent content gets higher scores.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
from .base import BaseScorer
|
||||||
|
|
||||||
|
|
||||||
|
class RecencyScorer(BaseScorer):
|
||||||
|
"""
|
||||||
|
Scores context based on recency.
|
||||||
|
|
||||||
|
Uses exponential decay to score content based on age.
|
||||||
|
More recent content scores higher.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight: float = 1.0,
|
||||||
|
half_life_hours: float = 24.0,
|
||||||
|
type_half_lives: dict[ContextType, float] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize recency scorer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight: Scorer weight for composite scoring
|
||||||
|
half_life_hours: Default hours until score decays to 0.5
|
||||||
|
type_half_lives: Optional context-type-specific half lives
|
||||||
|
"""
|
||||||
|
super().__init__(weight)
|
||||||
|
self._half_life_hours = half_life_hours
|
||||||
|
self._type_half_lives = type_half_lives or {}
|
||||||
|
|
||||||
|
# Set sensible defaults for context types
|
||||||
|
if ContextType.CONVERSATION not in self._type_half_lives:
|
||||||
|
self._type_half_lives[ContextType.CONVERSATION] = 1.0 # 1 hour
|
||||||
|
if ContextType.TOOL not in self._type_half_lives:
|
||||||
|
self._type_half_lives[ContextType.TOOL] = 0.5 # 30 minutes
|
||||||
|
if ContextType.KNOWLEDGE not in self._type_half_lives:
|
||||||
|
self._type_half_lives[ContextType.KNOWLEDGE] = 168.0 # 1 week
|
||||||
|
if ContextType.SYSTEM not in self._type_half_lives:
|
||||||
|
self._type_half_lives[ContextType.SYSTEM] = 720.0 # 30 days
|
||||||
|
if ContextType.TASK not in self._type_half_lives:
|
||||||
|
self._type_half_lives[ContextType.TASK] = 24.0 # 1 day
|
||||||
|
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Score context based on recency.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query (not used for recency, kept for interface)
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
- reference_time: Time to measure recency from (default: now)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Recency score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
reference_time = kwargs.get("reference_time")
|
||||||
|
if reference_time is None:
|
||||||
|
reference_time = datetime.now(UTC)
|
||||||
|
elif reference_time.tzinfo is None:
|
||||||
|
reference_time = reference_time.replace(tzinfo=UTC)
|
||||||
|
|
||||||
|
# Ensure context timestamp is timezone-aware
|
||||||
|
context_time = context.timestamp
|
||||||
|
if context_time.tzinfo is None:
|
||||||
|
context_time = context_time.replace(tzinfo=UTC)
|
||||||
|
|
||||||
|
# Calculate age in hours
|
||||||
|
age = reference_time - context_time
|
||||||
|
age_hours = max(0, age.total_seconds() / 3600)
|
||||||
|
|
||||||
|
# Get half-life for this context type
|
||||||
|
context_type = context.get_type()
|
||||||
|
half_life = self._type_half_lives.get(context_type, self._half_life_hours)
|
||||||
|
|
||||||
|
# Exponential decay
|
||||||
|
decay_factor = math.exp(-math.log(2) * age_hours / half_life)
|
||||||
|
|
||||||
|
return self.normalize_score(decay_factor)
|
||||||
|
|
||||||
|
def get_half_life(self, context_type: ContextType) -> float:
|
||||||
|
"""
|
||||||
|
Get half-life for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to get half-life for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Half-life in hours
|
||||||
|
"""
|
||||||
|
return self._type_half_lives.get(context_type, self._half_life_hours)
|
||||||
|
|
||||||
|
def set_half_life(self, context_type: ContextType, hours: float) -> None:
|
||||||
|
"""
|
||||||
|
Set half-life for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to set half-life for
|
||||||
|
hours: Half-life in hours
|
||||||
|
"""
|
||||||
|
if hours <= 0:
|
||||||
|
raise ValueError("Half-life must be positive")
|
||||||
|
self._type_half_lives[context_type] = hours
|
||||||
|
|
||||||
|
async def score_batch(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[float]:
|
||||||
|
"""
|
||||||
|
Score multiple contexts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to score
|
||||||
|
query: Query (not used)
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of scores (same order as input)
|
||||||
|
"""
|
||||||
|
scores = []
|
||||||
|
for context in contexts:
|
||||||
|
score = await self.score(context, query, **kwargs)
|
||||||
|
scores.append(score)
|
||||||
|
return scores
|
||||||
220
backend/app/services/context/scoring/relevance.py
Normal file
220
backend/app/services/context/scoring/relevance.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""
|
||||||
|
Relevance Scorer for Context Management.
|
||||||
|
|
||||||
|
Scores context based on semantic similarity to the query.
|
||||||
|
Uses Knowledge Base embeddings when available.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..types import BaseContext, KnowledgeContext
|
||||||
|
from .base import BaseScorer
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.mcp.client_manager import MCPClientManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RelevanceScorer(BaseScorer):
|
||||||
|
"""
|
||||||
|
Scores context based on relevance to query.
|
||||||
|
|
||||||
|
Uses multiple strategies:
|
||||||
|
1. Pre-computed scores (from RAG results)
|
||||||
|
2. MCP-based semantic similarity (via Knowledge Base)
|
||||||
|
3. Keyword matching fallback
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
weight: float = 1.0,
|
||||||
|
keyword_fallback_weight: float | None = None,
|
||||||
|
semantic_max_chars: int | None = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize relevance scorer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP manager for Knowledge Base calls
|
||||||
|
weight: Scorer weight for composite scoring
|
||||||
|
keyword_fallback_weight: Max score for keyword-based fallback (overrides settings)
|
||||||
|
semantic_max_chars: Max content length for semantic similarity (overrides settings)
|
||||||
|
settings: Context settings (uses global if None)
|
||||||
|
"""
|
||||||
|
super().__init__(weight)
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
|
||||||
|
# Use provided values or fall back to settings
|
||||||
|
self._keyword_fallback_weight = (
|
||||||
|
keyword_fallback_weight
|
||||||
|
if keyword_fallback_weight is not None
|
||||||
|
else self._settings.relevance_keyword_fallback_weight
|
||||||
|
)
|
||||||
|
self._semantic_max_chars = (
|
||||||
|
semantic_max_chars
|
||||||
|
if semantic_max_chars is not None
|
||||||
|
else self._settings.relevance_semantic_max_chars
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
|
"""Set MCP manager for semantic scoring."""
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Score context relevance to query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to score against
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Relevance score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
# 1. Check for pre-computed relevance score
|
||||||
|
if (
|
||||||
|
isinstance(context, KnowledgeContext)
|
||||||
|
and context.relevance_score is not None
|
||||||
|
):
|
||||||
|
return self.normalize_score(context.relevance_score)
|
||||||
|
|
||||||
|
# 2. Check metadata for score
|
||||||
|
if "relevance_score" in context.metadata:
|
||||||
|
return self.normalize_score(context.metadata["relevance_score"])
|
||||||
|
|
||||||
|
if "score" in context.metadata:
|
||||||
|
return self.normalize_score(context.metadata["score"])
|
||||||
|
|
||||||
|
# 3. Try MCP-based semantic similarity (if compute_similarity tool is available)
|
||||||
|
# Note: This requires the knowledge-base MCP server to implement compute_similarity
|
||||||
|
if self._mcp is not None:
|
||||||
|
try:
|
||||||
|
score = await self._compute_semantic_similarity(context, query)
|
||||||
|
if score is not None:
|
||||||
|
return score
|
||||||
|
except Exception as e:
|
||||||
|
# Log at debug level since this is expected if compute_similarity
|
||||||
|
# tool is not implemented in the Knowledge Base server
|
||||||
|
logger.debug(
|
||||||
|
f"Semantic scoring unavailable, using keyword fallback: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Fall back to keyword matching
|
||||||
|
return self._compute_keyword_score(context, query)
|
||||||
|
|
||||||
|
async def _compute_semantic_similarity(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
) -> float | None:
|
||||||
|
"""
|
||||||
|
Compute semantic similarity using Knowledge Base embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to compare
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Similarity score or None if unavailable
|
||||||
|
"""
|
||||||
|
if self._mcp is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use Knowledge Base's search capability to compute similarity
|
||||||
|
result = await self._mcp.call_tool(
|
||||||
|
server="knowledge-base",
|
||||||
|
tool="compute_similarity",
|
||||||
|
args={
|
||||||
|
"text1": query,
|
||||||
|
"text2": context.content[
|
||||||
|
: self._semantic_max_chars
|
||||||
|
], # Limit content length
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.success and isinstance(result.data, dict):
|
||||||
|
similarity = result.data.get("similarity")
|
||||||
|
if similarity is not None:
|
||||||
|
return self.normalize_score(float(similarity))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Semantic similarity computation failed: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _compute_keyword_score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Compute relevance score based on keyword matching.
|
||||||
|
|
||||||
|
Simple but fast fallback when semantic search is unavailable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to match
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Keyword-based relevance score
|
||||||
|
"""
|
||||||
|
if not query or not context.content:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Extract keywords from query
|
||||||
|
query_lower = query.lower()
|
||||||
|
content_lower = context.content.lower()
|
||||||
|
|
||||||
|
# Simple word tokenization
|
||||||
|
query_words = set(re.findall(r"\b\w{3,}\b", query_lower))
|
||||||
|
content_words = set(re.findall(r"\b\w{3,}\b", content_lower))
|
||||||
|
|
||||||
|
if not query_words:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Calculate overlap
|
||||||
|
common_words = query_words & content_words
|
||||||
|
overlap_ratio = len(common_words) / len(query_words)
|
||||||
|
|
||||||
|
# Apply fallback weight ceiling
|
||||||
|
return self.normalize_score(overlap_ratio * self._keyword_fallback_weight)
|
||||||
|
|
||||||
|
async def score_batch(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[float]:
|
||||||
|
"""
|
||||||
|
Score multiple contexts in parallel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to score
|
||||||
|
query: Query to score against
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of scores (same order as input)
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
if not contexts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
tasks = [self.score(context, query, **kwargs) for context in contexts]
|
||||||
|
return await asyncio.gather(*tasks)
|
||||||
49
backend/app/services/context/types/__init__.py
Normal file
49
backend/app/services/context/types/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""
|
||||||
|
Context Types Module.
|
||||||
|
|
||||||
|
Provides all context types used in the Context Management Engine.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import (
|
||||||
|
AssembledContext,
|
||||||
|
BaseContext,
|
||||||
|
ContextPriority,
|
||||||
|
ContextType,
|
||||||
|
)
|
||||||
|
from .conversation import (
|
||||||
|
ConversationContext,
|
||||||
|
MessageRole,
|
||||||
|
)
|
||||||
|
from .knowledge import KnowledgeContext
|
||||||
|
from .memory import (
|
||||||
|
MemoryContext,
|
||||||
|
MemorySubtype,
|
||||||
|
)
|
||||||
|
from .system import SystemContext
|
||||||
|
from .task import (
|
||||||
|
TaskComplexity,
|
||||||
|
TaskContext,
|
||||||
|
TaskStatus,
|
||||||
|
)
|
||||||
|
from .tool import (
|
||||||
|
ToolContext,
|
||||||
|
ToolResultStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AssembledContext",
|
||||||
|
"BaseContext",
|
||||||
|
"ContextPriority",
|
||||||
|
"ContextType",
|
||||||
|
"ConversationContext",
|
||||||
|
"KnowledgeContext",
|
||||||
|
"MemoryContext",
|
||||||
|
"MemorySubtype",
|
||||||
|
"MessageRole",
|
||||||
|
"SystemContext",
|
||||||
|
"TaskComplexity",
|
||||||
|
"TaskContext",
|
||||||
|
"TaskStatus",
|
||||||
|
"ToolContext",
|
||||||
|
"ToolResultStatus",
|
||||||
|
]
|
||||||
348
backend/app/services/context/types/base.py
Normal file
348
backend/app/services/context/types/base.py
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
"""
|
||||||
|
Base Context Types and Enums.
|
||||||
|
|
||||||
|
Provides the foundation for all context types used in
|
||||||
|
the Context Management Engine.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
|
||||||
|
class ContextType(str, Enum):
|
||||||
|
"""
|
||||||
|
Types of context that can be assembled.
|
||||||
|
|
||||||
|
Each type has specific handling, formatting, and
|
||||||
|
budget allocation rules.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SYSTEM = "system"
|
||||||
|
TASK = "task"
|
||||||
|
KNOWLEDGE = "knowledge"
|
||||||
|
CONVERSATION = "conversation"
|
||||||
|
TOOL = "tool"
|
||||||
|
MEMORY = "memory" # Agent memory (working, episodic, semantic, procedural)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, value: str) -> "ContextType":
|
||||||
|
"""
|
||||||
|
Convert string to ContextType.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: String value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ContextType enum value
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If value is not a valid context type
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return cls(value.lower())
|
||||||
|
except ValueError:
|
||||||
|
valid = ", ".join(t.value for t in cls)
|
||||||
|
raise ValueError(f"Invalid context type '{value}'. Valid types: {valid}")
|
||||||
|
|
||||||
|
|
||||||
|
class ContextPriority(int, Enum):
|
||||||
|
"""
|
||||||
|
Priority levels for context ordering.
|
||||||
|
|
||||||
|
Higher values indicate higher priority.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LOWEST = 0
|
||||||
|
LOW = 25
|
||||||
|
NORMAL = 50
|
||||||
|
HIGH = 75
|
||||||
|
HIGHEST = 100
|
||||||
|
CRITICAL = 150 # Never omit
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_int(cls, value: int) -> "ContextPriority":
|
||||||
|
"""
|
||||||
|
Get closest priority level for an integer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Integer priority value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Closest ContextPriority enum value
|
||||||
|
"""
|
||||||
|
priorities = sorted(cls, key=lambda p: p.value)
|
||||||
|
for priority in reversed(priorities):
|
||||||
|
if value >= priority.value:
|
||||||
|
return priority
|
||||||
|
return cls.LOWEST
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class BaseContext(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for all context types.
|
||||||
|
|
||||||
|
Provides common fields and methods for context handling,
|
||||||
|
scoring, and serialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Required fields
|
||||||
|
content: str
|
||||||
|
source: str
|
||||||
|
|
||||||
|
# Optional fields with defaults
|
||||||
|
id: str = field(default_factory=lambda: str(uuid4()))
|
||||||
|
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
priority: int = field(default=ContextPriority.NORMAL.value)
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Computed/cached fields
|
||||||
|
_token_count: int | None = field(default=None, repr=False)
|
||||||
|
_score: float | None = field(default=None, repr=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def token_count(self) -> int | None:
|
||||||
|
"""Get cached token count (None if not counted yet)."""
|
||||||
|
return self._token_count
|
||||||
|
|
||||||
|
@token_count.setter
|
||||||
|
def token_count(self, value: int) -> None:
|
||||||
|
"""Set token count."""
|
||||||
|
self._token_count = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def score(self) -> float | None:
|
||||||
|
"""Get cached score (None if not scored yet)."""
|
||||||
|
return self._score
|
||||||
|
|
||||||
|
@score.setter
|
||||||
|
def score(self, value: float) -> None:
|
||||||
|
"""Set score (clamped to 0.0-1.0)."""
|
||||||
|
self._score = max(0.0, min(1.0, value))
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""
|
||||||
|
Get the type of this context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ContextType enum value
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_age_seconds(self) -> float:
|
||||||
|
"""
|
||||||
|
Get age of context in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Age in seconds since creation
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
delta = now - self.timestamp
|
||||||
|
return delta.total_seconds()
|
||||||
|
|
||||||
|
def get_age_hours(self) -> float:
|
||||||
|
"""
|
||||||
|
Get age of context in hours.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Age in hours since creation
|
||||||
|
"""
|
||||||
|
return self.get_age_seconds() / 3600
|
||||||
|
|
||||||
|
def is_stale(self, max_age_hours: float = 168.0) -> bool:
|
||||||
|
"""
|
||||||
|
Check if context is stale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_age_hours: Maximum age before considered stale (default 7 days)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if context is older than max_age_hours
|
||||||
|
"""
|
||||||
|
return self.get_age_hours() > max_age_hours
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert context to dictionary for serialization.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary representation
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"type": self.get_type().value,
|
||||||
|
"content": self.content,
|
||||||
|
"source": self.source,
|
||||||
|
"timestamp": self.timestamp.isoformat(),
|
||||||
|
"priority": self.priority,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
"token_count": self._token_count,
|
||||||
|
"score": self._score,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "BaseContext":
|
||||||
|
"""
|
||||||
|
Create context from dictionary.
|
||||||
|
|
||||||
|
Note: Subclasses should override this to return correct type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary with context data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Context instance
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Subclasses must implement from_dict")
|
||||||
|
|
||||||
|
def truncate(self, max_tokens: int, suffix: str = "... [truncated]") -> str:
|
||||||
|
"""
|
||||||
|
Truncate content to fit within token limit.
|
||||||
|
|
||||||
|
This is a rough estimation based on characters.
|
||||||
|
For accurate truncation, use the TokenCalculator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_tokens: Maximum tokens allowed
|
||||||
|
suffix: Suffix to append when truncated
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Truncated content
|
||||||
|
"""
|
||||||
|
if self._token_count is None or self._token_count <= max_tokens:
|
||||||
|
return self.content
|
||||||
|
|
||||||
|
# Rough estimation: 4 chars per token on average
|
||||||
|
estimated_chars = max_tokens * 4
|
||||||
|
suffix_chars = len(suffix)
|
||||||
|
|
||||||
|
if len(self.content) <= estimated_chars:
|
||||||
|
return self.content
|
||||||
|
|
||||||
|
truncated = self.content[: estimated_chars - suffix_chars]
|
||||||
|
# Try to break at word boundary
|
||||||
|
last_space = truncated.rfind(" ")
|
||||||
|
if last_space > estimated_chars * 0.8:
|
||||||
|
truncated = truncated[:last_space]
|
||||||
|
|
||||||
|
return truncated + suffix
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
"""Hash based on ID for set/dict usage."""
|
||||||
|
return hash(self.id)
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
"""Equality based on ID."""
|
||||||
|
if not isinstance(other, BaseContext):
|
||||||
|
return False
|
||||||
|
return self.id == other.id
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssembledContext:
|
||||||
|
"""
|
||||||
|
Result of context assembly.
|
||||||
|
|
||||||
|
Contains the final formatted context ready for LLM consumption,
|
||||||
|
along with metadata about the assembly process.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Main content
|
||||||
|
content: str
|
||||||
|
total_tokens: int
|
||||||
|
|
||||||
|
# Assembly metadata
|
||||||
|
context_count: int
|
||||||
|
excluded_count: int = 0
|
||||||
|
assembly_time_ms: float = 0.0
|
||||||
|
model: str = ""
|
||||||
|
|
||||||
|
# Included contexts (optional - for inspection)
|
||||||
|
contexts: list["BaseContext"] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Additional metadata from assembly
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Budget tracking
|
||||||
|
budget_total: int = 0
|
||||||
|
budget_used: int = 0
|
||||||
|
|
||||||
|
# Context breakdown
|
||||||
|
by_type: dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Cache info
|
||||||
|
cache_hit: bool = False
|
||||||
|
cache_key: str | None = None
|
||||||
|
|
||||||
|
# Aliases for backward compatibility
|
||||||
|
@property
|
||||||
|
def token_count(self) -> int:
|
||||||
|
"""Alias for total_tokens."""
|
||||||
|
return self.total_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def contexts_included(self) -> int:
|
||||||
|
"""Alias for context_count."""
|
||||||
|
return self.context_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def contexts_excluded(self) -> int:
|
||||||
|
"""Alias for excluded_count."""
|
||||||
|
return self.excluded_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def budget_utilization(self) -> float:
|
||||||
|
"""Get budget utilization percentage."""
|
||||||
|
if self.budget_total == 0:
|
||||||
|
return 0.0
|
||||||
|
return self.budget_used / self.budget_total
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"content": self.content,
|
||||||
|
"total_tokens": self.total_tokens,
|
||||||
|
"context_count": self.context_count,
|
||||||
|
"excluded_count": self.excluded_count,
|
||||||
|
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||||
|
"model": self.model,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
"budget_total": self.budget_total,
|
||||||
|
"budget_used": self.budget_used,
|
||||||
|
"budget_utilization": round(self.budget_utilization, 3),
|
||||||
|
"by_type": self.by_type,
|
||||||
|
"cache_hit": self.cache_hit,
|
||||||
|
"cache_key": self.cache_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_json(self) -> str:
|
||||||
|
"""Convert to JSON string."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
return json.dumps(self.to_dict())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, json_str: str) -> "AssembledContext":
|
||||||
|
"""Create from JSON string."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = json.loads(json_str)
|
||||||
|
return cls(
|
||||||
|
content=data["content"],
|
||||||
|
total_tokens=data["total_tokens"],
|
||||||
|
context_count=data["context_count"],
|
||||||
|
excluded_count=data.get("excluded_count", 0),
|
||||||
|
assembly_time_ms=data.get("assembly_time_ms", 0.0),
|
||||||
|
model=data.get("model", ""),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
budget_total=data.get("budget_total", 0),
|
||||||
|
budget_used=data.get("budget_used", 0),
|
||||||
|
by_type=data.get("by_type", {}),
|
||||||
|
cache_hit=data.get("cache_hit", False),
|
||||||
|
cache_key=data.get("cache_key"),
|
||||||
|
)
|
||||||
182
backend/app/services/context/types/conversation.py
Normal file
182
backend/app/services/context/types/conversation.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
"""
|
||||||
|
Conversation Context Type.
|
||||||
|
|
||||||
|
Represents conversation history for context continuity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base import BaseContext, ContextPriority, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRole(str, Enum):
|
||||||
|
"""Roles for conversation messages."""
|
||||||
|
|
||||||
|
USER = "user"
|
||||||
|
ASSISTANT = "assistant"
|
||||||
|
SYSTEM = "system"
|
||||||
|
TOOL = "tool"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, value: str) -> "MessageRole":
|
||||||
|
"""Convert string to MessageRole."""
|
||||||
|
try:
|
||||||
|
return cls(value.lower())
|
||||||
|
except ValueError:
|
||||||
|
# Default to user for unknown roles
|
||||||
|
return cls.USER
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class ConversationContext(BaseContext):
|
||||||
|
"""
|
||||||
|
Context from conversation history.
|
||||||
|
|
||||||
|
Represents a single turn in the conversation,
|
||||||
|
including user messages, assistant responses,
|
||||||
|
and tool results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Conversation-specific fields
|
||||||
|
role: MessageRole = field(default=MessageRole.USER)
|
||||||
|
turn_index: int = field(default=0)
|
||||||
|
session_id: str | None = field(default=None)
|
||||||
|
parent_message_id: str | None = field(default=None)
|
||||||
|
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""Return CONVERSATION context type."""
|
||||||
|
return ContextType.CONVERSATION
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary with conversation-specific fields."""
|
||||||
|
base = super().to_dict()
|
||||||
|
base.update(
|
||||||
|
{
|
||||||
|
"role": self.role.value,
|
||||||
|
"turn_index": self.turn_index,
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"parent_message_id": self.parent_message_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "ConversationContext":
|
||||||
|
"""Create ConversationContext from dictionary."""
|
||||||
|
role = data.get("role", "user")
|
||||||
|
if isinstance(role, str):
|
||||||
|
role = MessageRole.from_string(role)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
id=data.get("id", ""),
|
||||||
|
content=data["content"],
|
||||||
|
source=data.get("source", "conversation"),
|
||||||
|
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||||
|
if isinstance(data.get("timestamp"), str)
|
||||||
|
else data.get("timestamp", datetime.now(UTC)),
|
||||||
|
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
role=role,
|
||||||
|
turn_index=data.get("turn_index", 0),
|
||||||
|
session_id=data.get("session_id"),
|
||||||
|
parent_message_id=data.get("parent_message_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_message(
|
||||||
|
cls,
|
||||||
|
content: str,
|
||||||
|
role: str | MessageRole,
|
||||||
|
turn_index: int = 0,
|
||||||
|
session_id: str | None = None,
|
||||||
|
timestamp: datetime | None = None,
|
||||||
|
) -> "ConversationContext":
|
||||||
|
"""
|
||||||
|
Create ConversationContext from a message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Message content
|
||||||
|
role: Message role (user, assistant, system, tool)
|
||||||
|
turn_index: Position in conversation
|
||||||
|
session_id: Session identifier
|
||||||
|
timestamp: Message timestamp
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConversationContext instance
|
||||||
|
"""
|
||||||
|
if isinstance(role, str):
|
||||||
|
role = MessageRole.from_string(role)
|
||||||
|
|
||||||
|
# Recent messages have higher priority
|
||||||
|
priority = ContextPriority.NORMAL.value
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content=content,
|
||||||
|
source="conversation",
|
||||||
|
role=role,
|
||||||
|
turn_index=turn_index,
|
||||||
|
session_id=session_id,
|
||||||
|
timestamp=timestamp or datetime.now(UTC),
|
||||||
|
priority=priority,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_history(
|
||||||
|
cls,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> list["ConversationContext"]:
|
||||||
|
"""
|
||||||
|
Create multiple ConversationContexts from message history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dicts with 'role' and 'content'
|
||||||
|
session_id: Session identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ConversationContext instances
|
||||||
|
"""
|
||||||
|
contexts = []
|
||||||
|
for i, msg in enumerate(messages):
|
||||||
|
ctx = cls.from_message(
|
||||||
|
content=msg.get("content", ""),
|
||||||
|
role=msg.get("role", "user"),
|
||||||
|
turn_index=i,
|
||||||
|
session_id=session_id,
|
||||||
|
timestamp=datetime.fromisoformat(msg["timestamp"])
|
||||||
|
if "timestamp" in msg
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
contexts.append(ctx)
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
def is_user_message(self) -> bool:
|
||||||
|
"""Check if this is a user message."""
|
||||||
|
return self.role == MessageRole.USER
|
||||||
|
|
||||||
|
def is_assistant_message(self) -> bool:
|
||||||
|
"""Check if this is an assistant message."""
|
||||||
|
return self.role == MessageRole.ASSISTANT
|
||||||
|
|
||||||
|
def is_tool_result(self) -> bool:
|
||||||
|
"""Check if this is a tool result."""
|
||||||
|
return self.role == MessageRole.TOOL
|
||||||
|
|
||||||
|
def format_for_prompt(self) -> str:
|
||||||
|
"""
|
||||||
|
Format message for inclusion in prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted message string
|
||||||
|
"""
|
||||||
|
role_labels = {
|
||||||
|
MessageRole.USER: "User",
|
||||||
|
MessageRole.ASSISTANT: "Assistant",
|
||||||
|
MessageRole.SYSTEM: "System",
|
||||||
|
MessageRole.TOOL: "Tool Result",
|
||||||
|
}
|
||||||
|
label = role_labels.get(self.role, "Unknown")
|
||||||
|
return f"{label}: {self.content}"
|
||||||
152
backend/app/services/context/types/knowledge.py
Normal file
152
backend/app/services/context/types/knowledge.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
"""
|
||||||
|
Knowledge Context Type.
|
||||||
|
|
||||||
|
Represents RAG results from the Knowledge Base MCP server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base import BaseContext, ContextPriority, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class KnowledgeContext(BaseContext):
|
||||||
|
"""
|
||||||
|
Context from knowledge base / RAG retrieval.
|
||||||
|
|
||||||
|
Knowledge context represents chunks retrieved from the
|
||||||
|
Knowledge Base MCP server, including:
|
||||||
|
- Code snippets
|
||||||
|
- Documentation
|
||||||
|
- Previous conversations
|
||||||
|
- External knowledge
|
||||||
|
|
||||||
|
Each chunk includes relevance scoring from the search.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Knowledge-specific fields
|
||||||
|
collection: str = field(default="default")
|
||||||
|
file_type: str | None = field(default=None)
|
||||||
|
chunk_index: int = field(default=0)
|
||||||
|
relevance_score: float = field(default=0.0)
|
||||||
|
search_query: str = field(default="")
|
||||||
|
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""Return KNOWLEDGE context type."""
|
||||||
|
return ContextType.KNOWLEDGE
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary with knowledge-specific fields."""
|
||||||
|
base = super().to_dict()
|
||||||
|
base.update(
|
||||||
|
{
|
||||||
|
"collection": self.collection,
|
||||||
|
"file_type": self.file_type,
|
||||||
|
"chunk_index": self.chunk_index,
|
||||||
|
"relevance_score": self.relevance_score,
|
||||||
|
"search_query": self.search_query,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "KnowledgeContext":
|
||||||
|
"""Create KnowledgeContext from dictionary."""
|
||||||
|
return cls(
|
||||||
|
id=data.get("id", ""),
|
||||||
|
content=data["content"],
|
||||||
|
source=data["source"],
|
||||||
|
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||||
|
if isinstance(data.get("timestamp"), str)
|
||||||
|
else data.get("timestamp", datetime.now(UTC)),
|
||||||
|
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
collection=data.get("collection", "default"),
|
||||||
|
file_type=data.get("file_type"),
|
||||||
|
chunk_index=data.get("chunk_index", 0),
|
||||||
|
relevance_score=data.get("relevance_score", 0.0),
|
||||||
|
search_query=data.get("search_query", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_search_result(
|
||||||
|
cls,
|
||||||
|
result: dict[str, Any],
|
||||||
|
query: str,
|
||||||
|
) -> "KnowledgeContext":
|
||||||
|
"""
|
||||||
|
Create KnowledgeContext from a Knowledge Base search result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Search result from Knowledge Base MCP
|
||||||
|
query: Search query used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KnowledgeContext instance
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
content=result.get("content", ""),
|
||||||
|
source=result.get("source_path", "unknown"),
|
||||||
|
collection=result.get("collection", "default"),
|
||||||
|
file_type=result.get("file_type"),
|
||||||
|
chunk_index=result.get("chunk_index", 0),
|
||||||
|
relevance_score=result.get("score", 0.0),
|
||||||
|
search_query=query,
|
||||||
|
metadata={
|
||||||
|
"chunk_id": result.get("id"),
|
||||||
|
"content_hash": result.get("content_hash"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_search_results(
|
||||||
|
cls,
|
||||||
|
results: list[dict[str, Any]],
|
||||||
|
query: str,
|
||||||
|
) -> list["KnowledgeContext"]:
|
||||||
|
"""
|
||||||
|
Create multiple KnowledgeContexts from search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of search results
|
||||||
|
query: Search query used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of KnowledgeContext instances
|
||||||
|
"""
|
||||||
|
return [cls.from_search_result(r, query) for r in results]
|
||||||
|
|
||||||
|
def is_code(self) -> bool:
|
||||||
|
"""Check if this is code content."""
|
||||||
|
code_types = {
|
||||||
|
"python",
|
||||||
|
"javascript",
|
||||||
|
"typescript",
|
||||||
|
"go",
|
||||||
|
"rust",
|
||||||
|
"java",
|
||||||
|
"c",
|
||||||
|
"cpp",
|
||||||
|
}
|
||||||
|
return self.file_type is not None and self.file_type.lower() in code_types
|
||||||
|
|
||||||
|
def is_documentation(self) -> bool:
|
||||||
|
"""Check if this is documentation content."""
|
||||||
|
doc_types = {"markdown", "rst", "txt", "md"}
|
||||||
|
return self.file_type is not None and self.file_type.lower() in doc_types
|
||||||
|
|
||||||
|
def get_formatted_source(self) -> str:
|
||||||
|
"""
|
||||||
|
Get a formatted source string for display.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted source string
|
||||||
|
"""
|
||||||
|
parts = [self.source]
|
||||||
|
if self.file_type:
|
||||||
|
parts.append(f"({self.file_type})")
|
||||||
|
if self.collection != "default":
|
||||||
|
parts.insert(0, f"[{self.collection}]")
|
||||||
|
return " ".join(parts)
|
||||||
282
backend/app/services/context/types/memory.py
Normal file
282
backend/app/services/context/types/memory.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
"""
|
||||||
|
Memory Context Type.
|
||||||
|
|
||||||
|
Represents agent memory as context for LLM requests.
|
||||||
|
Includes working, episodic, semantic, and procedural memories.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base import BaseContext, ContextPriority, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
class MemorySubtype(str, Enum):
|
||||||
|
"""Types of agent memory."""
|
||||||
|
|
||||||
|
WORKING = "working" # Session-scoped temporary data
|
||||||
|
EPISODIC = "episodic" # Task history and outcomes
|
||||||
|
SEMANTIC = "semantic" # Facts and knowledge
|
||||||
|
PROCEDURAL = "procedural" # Learned procedures
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class MemoryContext(BaseContext):
|
||||||
|
"""
|
||||||
|
Context from agent memory system.
|
||||||
|
|
||||||
|
Memory context represents data retrieved from the agent
|
||||||
|
memory system, including:
|
||||||
|
- Working memory: Current session state
|
||||||
|
- Episodic memory: Past task experiences
|
||||||
|
- Semantic memory: Learned facts and knowledge
|
||||||
|
- Procedural memory: Known procedures and workflows
|
||||||
|
|
||||||
|
Each memory item includes relevance scoring from search.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Memory-specific fields
|
||||||
|
memory_subtype: MemorySubtype = field(default=MemorySubtype.EPISODIC)
|
||||||
|
memory_id: str | None = field(default=None)
|
||||||
|
relevance_score: float = field(default=0.0)
|
||||||
|
importance: float = field(default=0.5)
|
||||||
|
search_query: str = field(default="")
|
||||||
|
|
||||||
|
# Type-specific fields (populated based on memory_subtype)
|
||||||
|
key: str | None = field(default=None) # For working memory
|
||||||
|
task_type: str | None = field(default=None) # For episodic
|
||||||
|
outcome: str | None = field(default=None) # For episodic
|
||||||
|
subject: str | None = field(default=None) # For semantic
|
||||||
|
predicate: str | None = field(default=None) # For semantic
|
||||||
|
object_value: str | None = field(default=None) # For semantic
|
||||||
|
trigger: str | None = field(default=None) # For procedural
|
||||||
|
success_rate: float | None = field(default=None) # For procedural
|
||||||
|
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""Return MEMORY context type."""
|
||||||
|
return ContextType.MEMORY
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary with memory-specific fields."""
|
||||||
|
base = super().to_dict()
|
||||||
|
base.update(
|
||||||
|
{
|
||||||
|
"memory_subtype": self.memory_subtype.value,
|
||||||
|
"memory_id": self.memory_id,
|
||||||
|
"relevance_score": self.relevance_score,
|
||||||
|
"importance": self.importance,
|
||||||
|
"search_query": self.search_query,
|
||||||
|
"key": self.key,
|
||||||
|
"task_type": self.task_type,
|
||||||
|
"outcome": self.outcome,
|
||||||
|
"subject": self.subject,
|
||||||
|
"predicate": self.predicate,
|
||||||
|
"object_value": self.object_value,
|
||||||
|
"trigger": self.trigger,
|
||||||
|
"success_rate": self.success_rate,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "MemoryContext":
|
||||||
|
"""Create MemoryContext from dictionary."""
|
||||||
|
return cls(
|
||||||
|
id=data.get("id", ""),
|
||||||
|
content=data["content"],
|
||||||
|
source=data["source"],
|
||||||
|
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||||
|
if isinstance(data.get("timestamp"), str)
|
||||||
|
else data.get("timestamp", datetime.now(UTC)),
|
||||||
|
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
memory_subtype=MemorySubtype(data.get("memory_subtype", "episodic")),
|
||||||
|
memory_id=data.get("memory_id"),
|
||||||
|
relevance_score=data.get("relevance_score", 0.0),
|
||||||
|
importance=data.get("importance", 0.5),
|
||||||
|
search_query=data.get("search_query", ""),
|
||||||
|
key=data.get("key"),
|
||||||
|
task_type=data.get("task_type"),
|
||||||
|
outcome=data.get("outcome"),
|
||||||
|
subject=data.get("subject"),
|
||||||
|
predicate=data.get("predicate"),
|
||||||
|
object_value=data.get("object_value"),
|
||||||
|
trigger=data.get("trigger"),
|
||||||
|
success_rate=data.get("success_rate"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_working_memory(
|
||||||
|
cls,
|
||||||
|
key: str,
|
||||||
|
value: Any,
|
||||||
|
source: str = "working_memory",
|
||||||
|
query: str = "",
|
||||||
|
) -> "MemoryContext":
|
||||||
|
"""
|
||||||
|
Create MemoryContext from working memory entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Working memory key
|
||||||
|
value: Value stored at key
|
||||||
|
source: Source identifier
|
||||||
|
query: Search query used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MemoryContext instance
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
content=str(value),
|
||||||
|
source=source,
|
||||||
|
memory_subtype=MemorySubtype.WORKING,
|
||||||
|
key=key,
|
||||||
|
relevance_score=1.0, # Working memory is always relevant
|
||||||
|
importance=0.8, # Higher importance for current session state
|
||||||
|
search_query=query,
|
||||||
|
priority=ContextPriority.HIGH.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_episodic_memory(
|
||||||
|
cls,
|
||||||
|
episode: Any,
|
||||||
|
query: str = "",
|
||||||
|
) -> "MemoryContext":
|
||||||
|
"""
|
||||||
|
Create MemoryContext from episodic memory episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode: Episode object from episodic memory
|
||||||
|
query: Search query used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MemoryContext instance
|
||||||
|
"""
|
||||||
|
outcome_val = None
|
||||||
|
if hasattr(episode, "outcome") and episode.outcome:
|
||||||
|
outcome_val = (
|
||||||
|
episode.outcome.value
|
||||||
|
if hasattr(episode.outcome, "value")
|
||||||
|
else str(episode.outcome)
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content=episode.task_description,
|
||||||
|
source=f"episodic:{episode.id}",
|
||||||
|
memory_subtype=MemorySubtype.EPISODIC,
|
||||||
|
memory_id=str(episode.id),
|
||||||
|
relevance_score=getattr(episode, "importance_score", 0.5),
|
||||||
|
importance=getattr(episode, "importance_score", 0.5),
|
||||||
|
search_query=query,
|
||||||
|
task_type=getattr(episode, "task_type", None),
|
||||||
|
outcome=outcome_val,
|
||||||
|
metadata={
|
||||||
|
"session_id": getattr(episode, "session_id", None),
|
||||||
|
"occurred_at": episode.occurred_at.isoformat()
|
||||||
|
if hasattr(episode, "occurred_at") and episode.occurred_at
|
||||||
|
else None,
|
||||||
|
"lessons_learned": getattr(episode, "lessons_learned", []),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_semantic_memory(
|
||||||
|
cls,
|
||||||
|
fact: Any,
|
||||||
|
query: str = "",
|
||||||
|
) -> "MemoryContext":
|
||||||
|
"""
|
||||||
|
Create MemoryContext from semantic memory fact.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fact: Fact object from semantic memory
|
||||||
|
query: Search query used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MemoryContext instance
|
||||||
|
"""
|
||||||
|
triple = f"{fact.subject} {fact.predicate} {fact.object}"
|
||||||
|
return cls(
|
||||||
|
content=triple,
|
||||||
|
source=f"semantic:{fact.id}",
|
||||||
|
memory_subtype=MemorySubtype.SEMANTIC,
|
||||||
|
memory_id=str(fact.id),
|
||||||
|
relevance_score=getattr(fact, "confidence", 0.5),
|
||||||
|
importance=getattr(fact, "confidence", 0.5),
|
||||||
|
search_query=query,
|
||||||
|
subject=fact.subject,
|
||||||
|
predicate=fact.predicate,
|
||||||
|
object_value=fact.object,
|
||||||
|
priority=ContextPriority.NORMAL.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_procedural_memory(
|
||||||
|
cls,
|
||||||
|
procedure: Any,
|
||||||
|
query: str = "",
|
||||||
|
) -> "MemoryContext":
|
||||||
|
"""
|
||||||
|
Create MemoryContext from procedural memory procedure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
procedure: Procedure object from procedural memory
|
||||||
|
query: Search query used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MemoryContext instance
|
||||||
|
"""
|
||||||
|
# Format steps as content
|
||||||
|
steps = getattr(procedure, "steps", [])
|
||||||
|
steps_content = "\n".join(
|
||||||
|
f" {i + 1}. {step.get('action', step) if isinstance(step, dict) else step}"
|
||||||
|
for i, step in enumerate(steps)
|
||||||
|
)
|
||||||
|
content = f"Procedure: {procedure.name}\nTrigger: {procedure.trigger_pattern}\nSteps:\n{steps_content}"
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content=content,
|
||||||
|
source=f"procedural:{procedure.id}",
|
||||||
|
memory_subtype=MemorySubtype.PROCEDURAL,
|
||||||
|
memory_id=str(procedure.id),
|
||||||
|
relevance_score=getattr(procedure, "success_rate", 0.5),
|
||||||
|
importance=0.7, # Procedures are moderately important
|
||||||
|
search_query=query,
|
||||||
|
trigger=procedure.trigger_pattern,
|
||||||
|
success_rate=getattr(procedure, "success_rate", None),
|
||||||
|
metadata={
|
||||||
|
"steps_count": len(steps),
|
||||||
|
"execution_count": getattr(procedure, "success_count", 0)
|
||||||
|
+ getattr(procedure, "failure_count", 0),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_working_memory(self) -> bool:
|
||||||
|
"""Check if this is working memory."""
|
||||||
|
return self.memory_subtype == MemorySubtype.WORKING
|
||||||
|
|
||||||
|
def is_episodic_memory(self) -> bool:
|
||||||
|
"""Check if this is episodic memory."""
|
||||||
|
return self.memory_subtype == MemorySubtype.EPISODIC
|
||||||
|
|
||||||
|
def is_semantic_memory(self) -> bool:
|
||||||
|
"""Check if this is semantic memory."""
|
||||||
|
return self.memory_subtype == MemorySubtype.SEMANTIC
|
||||||
|
|
||||||
|
def is_procedural_memory(self) -> bool:
|
||||||
|
"""Check if this is procedural memory."""
|
||||||
|
return self.memory_subtype == MemorySubtype.PROCEDURAL
|
||||||
|
|
||||||
|
def get_formatted_source(self) -> str:
|
||||||
|
"""
|
||||||
|
Get a formatted source string for display.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted source string
|
||||||
|
"""
|
||||||
|
parts = [f"[{self.memory_subtype.value}]", self.source]
|
||||||
|
if self.memory_id:
|
||||||
|
parts.append(f"({self.memory_id[:8]}...)")
|
||||||
|
return " ".join(parts)
|
||||||
138
backend/app/services/context/types/system.py
Normal file
138
backend/app/services/context/types/system.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""
|
||||||
|
System Context Type.
|
||||||
|
|
||||||
|
Represents system prompts, instructions, and agent personas.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base import BaseContext, ContextPriority, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class SystemContext(BaseContext):
|
||||||
|
"""
|
||||||
|
Context for system prompts and instructions.
|
||||||
|
|
||||||
|
System context typically includes:
|
||||||
|
- Agent persona and role definitions
|
||||||
|
- Behavioral instructions
|
||||||
|
- Safety guidelines
|
||||||
|
- Output format requirements
|
||||||
|
|
||||||
|
System context is usually high priority and should
|
||||||
|
rarely be truncated or omitted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# System context specific fields
|
||||||
|
role: str = field(default="assistant")
|
||||||
|
instructions_type: str = field(default="general")
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Set high priority for system context."""
|
||||||
|
# System context defaults to high priority
|
||||||
|
if self.priority == ContextPriority.NORMAL.value:
|
||||||
|
self.priority = ContextPriority.HIGH.value
|
||||||
|
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""Return SYSTEM context type."""
|
||||||
|
return ContextType.SYSTEM
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary with system-specific fields."""
|
||||||
|
base = super().to_dict()
|
||||||
|
base.update(
|
||||||
|
{
|
||||||
|
"role": self.role,
|
||||||
|
"instructions_type": self.instructions_type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "SystemContext":
|
||||||
|
"""Create SystemContext from dictionary."""
|
||||||
|
return cls(
|
||||||
|
id=data.get("id", ""),
|
||||||
|
content=data["content"],
|
||||||
|
source=data["source"],
|
||||||
|
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||||
|
if isinstance(data.get("timestamp"), str)
|
||||||
|
else data.get("timestamp", datetime.now(UTC)),
|
||||||
|
priority=data.get("priority", ContextPriority.HIGH.value),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
role=data.get("role", "assistant"),
|
||||||
|
instructions_type=data.get("instructions_type", "general"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_persona(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
capabilities: list[str] | None = None,
|
||||||
|
constraints: list[str] | None = None,
|
||||||
|
) -> "SystemContext":
|
||||||
|
"""
|
||||||
|
Create a persona system context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Agent name/role
|
||||||
|
description: Role description
|
||||||
|
capabilities: List of things the agent can do
|
||||||
|
constraints: List of limitations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SystemContext with formatted persona
|
||||||
|
"""
|
||||||
|
parts = [f"You are {name}.", "", description]
|
||||||
|
|
||||||
|
if capabilities:
|
||||||
|
parts.append("")
|
||||||
|
parts.append("You can:")
|
||||||
|
for cap in capabilities:
|
||||||
|
parts.append(f"- {cap}")
|
||||||
|
|
||||||
|
if constraints:
|
||||||
|
parts.append("")
|
||||||
|
parts.append("You must not:")
|
||||||
|
for constraint in constraints:
|
||||||
|
parts.append(f"- {constraint}")
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content="\n".join(parts),
|
||||||
|
source="persona_builder",
|
||||||
|
role=name.lower().replace(" ", "_"),
|
||||||
|
instructions_type="persona",
|
||||||
|
priority=ContextPriority.HIGHEST.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_instructions(
|
||||||
|
cls,
|
||||||
|
instructions: str | list[str],
|
||||||
|
source: str = "instructions",
|
||||||
|
) -> "SystemContext":
|
||||||
|
"""
|
||||||
|
Create an instructions system context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instructions: Instructions string or list of instruction strings
|
||||||
|
source: Source identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SystemContext with instructions
|
||||||
|
"""
|
||||||
|
if isinstance(instructions, list):
|
||||||
|
content = "\n".join(f"- {inst}" for inst in instructions)
|
||||||
|
else:
|
||||||
|
content = instructions
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content=content,
|
||||||
|
source=source,
|
||||||
|
instructions_type="instructions",
|
||||||
|
priority=ContextPriority.HIGH.value,
|
||||||
|
)
|
||||||
193
backend/app/services/context/types/task.py
Normal file
193
backend/app/services/context/types/task.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""
|
||||||
|
Task Context Type.
|
||||||
|
|
||||||
|
Represents the current task or objective for the agent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base import BaseContext, ContextPriority, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(str, Enum):
|
||||||
|
"""Status of a task."""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
IN_PROGRESS = "in_progress"
|
||||||
|
BLOCKED = "blocked"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskComplexity(str, Enum):
|
||||||
|
"""Complexity level of a task."""
|
||||||
|
|
||||||
|
TRIVIAL = "trivial"
|
||||||
|
SIMPLE = "simple"
|
||||||
|
MODERATE = "moderate"
|
||||||
|
COMPLEX = "complex"
|
||||||
|
VERY_COMPLEX = "very_complex"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class TaskContext(BaseContext):
|
||||||
|
"""
|
||||||
|
Context for the current task or objective.
|
||||||
|
|
||||||
|
Task context provides information about what the agent
|
||||||
|
should accomplish, including:
|
||||||
|
- Task description and goals
|
||||||
|
- Acceptance criteria
|
||||||
|
- Constraints and requirements
|
||||||
|
- Related issue/ticket information
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Task-specific fields
|
||||||
|
title: str = field(default="")
|
||||||
|
status: TaskStatus = field(default=TaskStatus.PENDING)
|
||||||
|
complexity: TaskComplexity = field(default=TaskComplexity.MODERATE)
|
||||||
|
issue_id: str | None = field(default=None)
|
||||||
|
project_id: str | None = field(default=None)
|
||||||
|
acceptance_criteria: list[str] = field(default_factory=list)
|
||||||
|
constraints: list[str] = field(default_factory=list)
|
||||||
|
parent_task_id: str | None = field(default=None)
|
||||||
|
|
||||||
|
# Note: TaskContext should typically have HIGH priority,
|
||||||
|
# but we don't auto-promote to allow explicit priority setting.
|
||||||
|
# Use TaskContext.create() for default HIGH priority behavior.
|
||||||
|
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""Return TASK context type."""
|
||||||
|
return ContextType.TASK
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary with task-specific fields."""
|
||||||
|
base = super().to_dict()
|
||||||
|
base.update(
|
||||||
|
{
|
||||||
|
"title": self.title,
|
||||||
|
"status": self.status.value,
|
||||||
|
"complexity": self.complexity.value,
|
||||||
|
"issue_id": self.issue_id,
|
||||||
|
"project_id": self.project_id,
|
||||||
|
"acceptance_criteria": self.acceptance_criteria,
|
||||||
|
"constraints": self.constraints,
|
||||||
|
"parent_task_id": self.parent_task_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "TaskContext":
|
||||||
|
"""Create TaskContext from dictionary."""
|
||||||
|
status = data.get("status", "pending")
|
||||||
|
if isinstance(status, str):
|
||||||
|
status = TaskStatus(status)
|
||||||
|
|
||||||
|
complexity = data.get("complexity", "moderate")
|
||||||
|
if isinstance(complexity, str):
|
||||||
|
complexity = TaskComplexity(complexity)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
id=data.get("id", ""),
|
||||||
|
content=data["content"],
|
||||||
|
source=data.get("source", "task"),
|
||||||
|
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||||
|
if isinstance(data.get("timestamp"), str)
|
||||||
|
else data.get("timestamp", datetime.now(UTC)),
|
||||||
|
priority=data.get("priority", ContextPriority.HIGH.value),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
title=data.get("title", ""),
|
||||||
|
status=status,
|
||||||
|
complexity=complexity,
|
||||||
|
issue_id=data.get("issue_id"),
|
||||||
|
project_id=data.get("project_id"),
|
||||||
|
acceptance_criteria=data.get("acceptance_criteria", []),
|
||||||
|
constraints=data.get("constraints", []),
|
||||||
|
parent_task_id=data.get("parent_task_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
title: str,
|
||||||
|
description: str,
|
||||||
|
acceptance_criteria: list[str] | None = None,
|
||||||
|
constraints: list[str] | None = None,
|
||||||
|
issue_id: str | None = None,
|
||||||
|
project_id: str | None = None,
|
||||||
|
complexity: TaskComplexity | str = TaskComplexity.MODERATE,
|
||||||
|
) -> "TaskContext":
|
||||||
|
"""
|
||||||
|
Create a task context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
title: Task title
|
||||||
|
description: Task description
|
||||||
|
acceptance_criteria: List of acceptance criteria
|
||||||
|
constraints: List of constraints
|
||||||
|
issue_id: Related issue ID
|
||||||
|
project_id: Project ID
|
||||||
|
complexity: Task complexity
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TaskContext instance
|
||||||
|
"""
|
||||||
|
if isinstance(complexity, str):
|
||||||
|
complexity = TaskComplexity(complexity)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content=description,
|
||||||
|
source=f"task:{issue_id}" if issue_id else "task",
|
||||||
|
title=title,
|
||||||
|
status=TaskStatus.IN_PROGRESS,
|
||||||
|
complexity=complexity,
|
||||||
|
issue_id=issue_id,
|
||||||
|
project_id=project_id,
|
||||||
|
acceptance_criteria=acceptance_criteria or [],
|
||||||
|
constraints=constraints or [],
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_for_prompt(self) -> str:
|
||||||
|
"""
|
||||||
|
Format task for inclusion in prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted task string
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
if self.title:
|
||||||
|
parts.append(f"Task: {self.title}")
|
||||||
|
parts.append("")
|
||||||
|
|
||||||
|
parts.append(self.content)
|
||||||
|
|
||||||
|
if self.acceptance_criteria:
|
||||||
|
parts.append("")
|
||||||
|
parts.append("Acceptance Criteria:")
|
||||||
|
for criterion in self.acceptance_criteria:
|
||||||
|
parts.append(f"- {criterion}")
|
||||||
|
|
||||||
|
if self.constraints:
|
||||||
|
parts.append("")
|
||||||
|
parts.append("Constraints:")
|
||||||
|
for constraint in self.constraints:
|
||||||
|
parts.append(f"- {constraint}")
|
||||||
|
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def is_active(self) -> bool:
|
||||||
|
"""Check if task is currently active."""
|
||||||
|
return self.status in (TaskStatus.PENDING, TaskStatus.IN_PROGRESS)
|
||||||
|
|
||||||
|
def is_complete(self) -> bool:
|
||||||
|
"""Check if task is complete."""
|
||||||
|
return self.status == TaskStatus.COMPLETED
|
||||||
|
|
||||||
|
def is_blocked(self) -> bool:
|
||||||
|
"""Check if task is blocked."""
|
||||||
|
return self.status == TaskStatus.BLOCKED
|
||||||
211
backend/app/services/context/types/tool.py
Normal file
211
backend/app/services/context/types/tool.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Tool Context Type.
|
||||||
|
|
||||||
|
Represents available tools and recent tool execution results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base import BaseContext, ContextPriority, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
class ToolResultStatus(str, Enum):
|
||||||
|
"""Status of a tool execution result."""
|
||||||
|
|
||||||
|
SUCCESS = "success"
|
||||||
|
ERROR = "error"
|
||||||
|
TIMEOUT = "timeout"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class ToolContext(BaseContext):
|
||||||
|
"""
|
||||||
|
Context for tools and tool execution results.
|
||||||
|
|
||||||
|
Tool context includes:
|
||||||
|
- Tool descriptions and parameters
|
||||||
|
- Recent tool execution results
|
||||||
|
- Tool availability information
|
||||||
|
|
||||||
|
This helps the LLM understand what tools are available
|
||||||
|
and what results previous tool calls produced.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Tool-specific fields
|
||||||
|
tool_name: str = field(default="")
|
||||||
|
tool_description: str = field(default="")
|
||||||
|
is_result: bool = field(default=False)
|
||||||
|
result_status: ToolResultStatus | None = field(default=None)
|
||||||
|
execution_time_ms: float | None = field(default=None)
|
||||||
|
parameters: dict[str, Any] = field(default_factory=dict)
|
||||||
|
server_name: str | None = field(default=None)
|
||||||
|
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""Return TOOL context type."""
|
||||||
|
return ContextType.TOOL
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary with tool-specific fields."""
|
||||||
|
base = super().to_dict()
|
||||||
|
base.update(
|
||||||
|
{
|
||||||
|
"tool_name": self.tool_name,
|
||||||
|
"tool_description": self.tool_description,
|
||||||
|
"is_result": self.is_result,
|
||||||
|
"result_status": self.result_status.value
|
||||||
|
if self.result_status
|
||||||
|
else None,
|
||||||
|
"execution_time_ms": self.execution_time_ms,
|
||||||
|
"parameters": self.parameters,
|
||||||
|
"server_name": self.server_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "ToolContext":
|
||||||
|
"""Create ToolContext from dictionary."""
|
||||||
|
result_status = data.get("result_status")
|
||||||
|
if isinstance(result_status, str):
|
||||||
|
result_status = ToolResultStatus(result_status)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
id=data.get("id", ""),
|
||||||
|
content=data["content"],
|
||||||
|
source=data.get("source", "tool"),
|
||||||
|
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||||
|
if isinstance(data.get("timestamp"), str)
|
||||||
|
else data.get("timestamp", datetime.now(UTC)),
|
||||||
|
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
tool_name=data.get("tool_name", ""),
|
||||||
|
tool_description=data.get("tool_description", ""),
|
||||||
|
is_result=data.get("is_result", False),
|
||||||
|
result_status=result_status,
|
||||||
|
execution_time_ms=data.get("execution_time_ms"),
|
||||||
|
parameters=data.get("parameters", {}),
|
||||||
|
server_name=data.get("server_name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tool_definition(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
parameters: dict[str, Any] | None = None,
|
||||||
|
server_name: str | None = None,
|
||||||
|
) -> "ToolContext":
|
||||||
|
"""
|
||||||
|
Create a ToolContext from a tool definition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Tool name
|
||||||
|
description: Tool description
|
||||||
|
parameters: Tool parameter schema
|
||||||
|
server_name: MCP server name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ToolContext instance
|
||||||
|
"""
|
||||||
|
# Format content as tool documentation
|
||||||
|
content_parts = [f"Tool: {name}", "", description]
|
||||||
|
|
||||||
|
if parameters:
|
||||||
|
content_parts.append("")
|
||||||
|
content_parts.append("Parameters:")
|
||||||
|
for param_name, param_info in parameters.items():
|
||||||
|
param_type = param_info.get("type", "any")
|
||||||
|
param_desc = param_info.get("description", "")
|
||||||
|
required = param_info.get("required", False)
|
||||||
|
req_marker = " (required)" if required else ""
|
||||||
|
content_parts.append(f" - {param_name}: {param_type}{req_marker}")
|
||||||
|
if param_desc:
|
||||||
|
content_parts.append(f" {param_desc}")
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content="\n".join(content_parts),
|
||||||
|
source=f"tool:{server_name}:{name}" if server_name else f"tool:{name}",
|
||||||
|
tool_name=name,
|
||||||
|
tool_description=description,
|
||||||
|
is_result=False,
|
||||||
|
parameters=parameters or {},
|
||||||
|
server_name=server_name,
|
||||||
|
priority=ContextPriority.LOW.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tool_result(
|
||||||
|
cls,
|
||||||
|
tool_name: str,
|
||||||
|
result: Any,
|
||||||
|
status: ToolResultStatus = ToolResultStatus.SUCCESS,
|
||||||
|
execution_time_ms: float | None = None,
|
||||||
|
parameters: dict[str, Any] | None = None,
|
||||||
|
server_name: str | None = None,
|
||||||
|
) -> "ToolContext":
|
||||||
|
"""
|
||||||
|
Create a ToolContext from a tool execution result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool that was executed
|
||||||
|
result: Result content (will be converted to string)
|
||||||
|
status: Execution status
|
||||||
|
execution_time_ms: Execution time in milliseconds
|
||||||
|
parameters: Parameters that were passed to the tool
|
||||||
|
server_name: MCP server name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ToolContext instance
|
||||||
|
"""
|
||||||
|
# Convert result to string content
|
||||||
|
if isinstance(result, str):
|
||||||
|
content = result
|
||||||
|
elif isinstance(result, dict):
|
||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = json.dumps(result, indent=2)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
content = str(result)
|
||||||
|
else:
|
||||||
|
content = str(result)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content=content,
|
||||||
|
source=f"tool_result:{server_name}:{tool_name}"
|
||||||
|
if server_name
|
||||||
|
else f"tool_result:{tool_name}",
|
||||||
|
tool_name=tool_name,
|
||||||
|
is_result=True,
|
||||||
|
result_status=status,
|
||||||
|
execution_time_ms=execution_time_ms,
|
||||||
|
parameters=parameters or {},
|
||||||
|
server_name=server_name,
|
||||||
|
priority=ContextPriority.HIGH.value, # Recent results are high priority
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_successful(self) -> bool:
|
||||||
|
"""Check if this is a successful tool result."""
|
||||||
|
return self.is_result and self.result_status == ToolResultStatus.SUCCESS
|
||||||
|
|
||||||
|
def is_error(self) -> bool:
|
||||||
|
"""Check if this is an error result."""
|
||||||
|
return self.is_result and self.result_status == ToolResultStatus.ERROR
|
||||||
|
|
||||||
|
def format_for_prompt(self) -> str:
|
||||||
|
"""
|
||||||
|
Format tool context for inclusion in prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted tool string
|
||||||
|
"""
|
||||||
|
if self.is_result:
|
||||||
|
status_str = self.result_status.value if self.result_status else "unknown"
|
||||||
|
header = f"Tool Result ({self.tool_name}, {status_str}):"
|
||||||
|
return f"{header}\n{self.content}"
|
||||||
|
else:
|
||||||
|
return self.content
|
||||||
@@ -54,22 +54,18 @@ class EventBusError(Exception):
|
|||||||
"""Base exception for EventBus errors."""
|
"""Base exception for EventBus errors."""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class EventBusConnectionError(EventBusError):
|
class EventBusConnectionError(EventBusError):
|
||||||
"""Raised when connection to Redis fails."""
|
"""Raised when connection to Redis fails."""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class EventBusPublishError(EventBusError):
|
class EventBusPublishError(EventBusError):
|
||||||
"""Raised when publishing an event fails."""
|
"""Raised when publishing an event fails."""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class EventBusSubscriptionError(EventBusError):
|
class EventBusSubscriptionError(EventBusError):
|
||||||
"""Raised when subscribing to channels fails."""
|
"""Raised when subscribing to channels fails."""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class EventBus:
|
class EventBus:
|
||||||
"""
|
"""
|
||||||
EventBus for Redis Pub/Sub communication.
|
EventBus for Redis Pub/Sub communication.
|
||||||
|
|||||||
85
backend/app/services/mcp/__init__.py
Normal file
85
backend/app/services/mcp/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""
|
||||||
|
MCP Client Service Package
|
||||||
|
|
||||||
|
Provides infrastructure for communicating with MCP (Model Context Protocol)
|
||||||
|
servers. This is the foundation for AI agent tool integration.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from app.services.mcp import get_mcp_client, MCPClientManager
|
||||||
|
|
||||||
|
# In FastAPI route
|
||||||
|
async def my_route(mcp: MCPClientManager = Depends(get_mcp_client)):
|
||||||
|
result = await mcp.call_tool("llm-gateway", "chat", {"prompt": "Hello"})
|
||||||
|
|
||||||
|
# Direct usage
|
||||||
|
manager = MCPClientManager()
|
||||||
|
await manager.initialize()
|
||||||
|
result = await manager.call_tool("issues", "create_issue", {...})
|
||||||
|
await manager.shutdown()
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .client_manager import (
|
||||||
|
MCPClientManager,
|
||||||
|
ServerHealth,
|
||||||
|
get_mcp_client,
|
||||||
|
reset_mcp_client,
|
||||||
|
shutdown_mcp_client,
|
||||||
|
)
|
||||||
|
from .config import (
|
||||||
|
MCPConfig,
|
||||||
|
MCPServerConfig,
|
||||||
|
TransportType,
|
||||||
|
create_default_config,
|
||||||
|
load_mcp_config,
|
||||||
|
)
|
||||||
|
from .connection import ConnectionPool, ConnectionState, MCPConnection
|
||||||
|
from .exceptions import (
|
||||||
|
MCPCircuitOpenError,
|
||||||
|
MCPConnectionError,
|
||||||
|
MCPError,
|
||||||
|
MCPServerNotFoundError,
|
||||||
|
MCPTimeoutError,
|
||||||
|
MCPToolError,
|
||||||
|
MCPToolNotFoundError,
|
||||||
|
MCPValidationError,
|
||||||
|
)
|
||||||
|
from .registry import MCPServerRegistry, ServerCapabilities, get_registry
|
||||||
|
from .routing import AsyncCircuitBreaker, CircuitState, ToolInfo, ToolResult, ToolRouter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Main facade
|
||||||
|
"MCPClientManager",
|
||||||
|
"get_mcp_client",
|
||||||
|
"shutdown_mcp_client",
|
||||||
|
"reset_mcp_client",
|
||||||
|
"ServerHealth",
|
||||||
|
# Configuration
|
||||||
|
"MCPConfig",
|
||||||
|
"MCPServerConfig",
|
||||||
|
"TransportType",
|
||||||
|
"load_mcp_config",
|
||||||
|
"create_default_config",
|
||||||
|
# Registry
|
||||||
|
"MCPServerRegistry",
|
||||||
|
"ServerCapabilities",
|
||||||
|
"get_registry",
|
||||||
|
# Connection
|
||||||
|
"ConnectionPool",
|
||||||
|
"ConnectionState",
|
||||||
|
"MCPConnection",
|
||||||
|
# Routing
|
||||||
|
"ToolRouter",
|
||||||
|
"ToolInfo",
|
||||||
|
"ToolResult",
|
||||||
|
"AsyncCircuitBreaker",
|
||||||
|
"CircuitState",
|
||||||
|
# Exceptions
|
||||||
|
"MCPError",
|
||||||
|
"MCPConnectionError",
|
||||||
|
"MCPTimeoutError",
|
||||||
|
"MCPToolError",
|
||||||
|
"MCPServerNotFoundError",
|
||||||
|
"MCPToolNotFoundError",
|
||||||
|
"MCPCircuitOpenError",
|
||||||
|
"MCPValidationError",
|
||||||
|
]
|
||||||
438
backend/app/services/mcp/client_manager.py
Normal file
438
backend/app/services/mcp/client_manager.py
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
"""
|
||||||
|
MCP Client Manager
|
||||||
|
|
||||||
|
Main facade for all MCP operations. Manages server connections,
|
||||||
|
tool discovery, and provides a unified interface for tool calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
||||||
|
from .connection import ConnectionPool, ConnectionState
|
||||||
|
from .exceptions import MCPServerNotFoundError
|
||||||
|
from .registry import MCPServerRegistry, get_registry
|
||||||
|
from .routing import ToolInfo, ToolResult, ToolRouter
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ServerHealth:
|
||||||
|
"""Health status for an MCP server."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
healthy: bool
|
||||||
|
state: str
|
||||||
|
url: str
|
||||||
|
error: str | None = None
|
||||||
|
tools_count: int = 0
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"healthy": self.healthy,
|
||||||
|
"state": self.state,
|
||||||
|
"url": self.url,
|
||||||
|
"error": self.error,
|
||||||
|
"tools_count": self.tools_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClientManager:
|
||||||
|
"""
|
||||||
|
Central manager for all MCP client operations.
|
||||||
|
|
||||||
|
Provides a unified interface for:
|
||||||
|
- Connecting to MCP servers
|
||||||
|
- Discovering and calling tools
|
||||||
|
- Health monitoring
|
||||||
|
- Connection lifecycle management
|
||||||
|
|
||||||
|
This is the main entry point for MCP operations in the application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: MCPConfig | None = None,
|
||||||
|
registry: MCPServerRegistry | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the MCP client manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Optional MCP configuration. If None, loads from default.
|
||||||
|
registry: Optional registry instance. If None, uses singleton.
|
||||||
|
"""
|
||||||
|
self._registry = registry or get_registry()
|
||||||
|
self._pool = ConnectionPool()
|
||||||
|
self._router: ToolRouter | None = None
|
||||||
|
self._initialized = False
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# Load configuration if provided
|
||||||
|
if config is not None:
|
||||||
|
self._registry.load_config(config)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_initialized(self) -> bool:
|
||||||
|
"""Check if the manager is initialized."""
|
||||||
|
return self._initialized
|
||||||
|
|
||||||
|
async def initialize(self, config: MCPConfig | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the MCP client manager.
|
||||||
|
|
||||||
|
Loads configuration, creates connections, and discovers tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Optional configuration to load
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if self._initialized:
|
||||||
|
logger.warning("MCPClientManager already initialized")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Initializing MCP Client Manager")
|
||||||
|
|
||||||
|
# Load configuration
|
||||||
|
if config is not None:
|
||||||
|
self._registry.load_config(config)
|
||||||
|
elif len(self._registry.list_servers()) == 0:
|
||||||
|
# Try to load from default location
|
||||||
|
self._registry.load_config(load_mcp_config())
|
||||||
|
|
||||||
|
# Create router
|
||||||
|
self._router = ToolRouter(self._registry, self._pool)
|
||||||
|
|
||||||
|
# Connect to all enabled servers
|
||||||
|
await self._connect_all_servers()
|
||||||
|
|
||||||
|
# Discover tools from all servers
|
||||||
|
if self._router:
|
||||||
|
await self._router.discover_tools()
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
logger.info(
|
||||||
|
"MCP Client Manager initialized with %d servers",
|
||||||
|
len(self._registry.list_enabled_servers()),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _connect_all_servers(self) -> None:
|
||||||
|
"""Connect to all enabled MCP servers concurrently."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
enabled_servers = self._registry.get_enabled_configs()
|
||||||
|
|
||||||
|
async def connect_server(name: str, config: "MCPServerConfig") -> None:
|
||||||
|
try:
|
||||||
|
await self._pool.get_connection(name, config)
|
||||||
|
logger.info("Connected to MCP server: %s", name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to connect to MCP server %s: %s", name, e)
|
||||||
|
|
||||||
|
# Connect to all servers concurrently for faster startup
|
||||||
|
await asyncio.gather(
|
||||||
|
*(connect_server(name, config) for name, config in enabled_servers.items()),
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
"""
|
||||||
|
Shutdown the MCP client manager.
|
||||||
|
|
||||||
|
Closes all connections and cleans up resources.
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if not self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Shutting down MCP Client Manager")
|
||||||
|
|
||||||
|
await self._pool.close_all()
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
logger.info("MCP Client Manager shutdown complete")
|
||||||
|
|
||||||
|
async def connect(self, server_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Connect to a specific MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server to connect to
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPServerNotFoundError: If server is not registered
|
||||||
|
"""
|
||||||
|
config = self._registry.get(server_name)
|
||||||
|
await self._pool.get_connection(server_name, config)
|
||||||
|
logger.info("Connected to MCP server: %s", server_name)
|
||||||
|
|
||||||
|
async def disconnect(self, server_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Disconnect from a specific MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server to disconnect from
|
||||||
|
"""
|
||||||
|
await self._pool.close_connection(server_name)
|
||||||
|
logger.info("Disconnected from MCP server: %s", server_name)
|
||||||
|
|
||||||
|
async def disconnect_all(self) -> None:
|
||||||
|
"""Disconnect from all MCP servers."""
|
||||||
|
await self._pool.close_all()
|
||||||
|
|
||||||
|
async def call_tool(
|
||||||
|
self,
|
||||||
|
server: str,
|
||||||
|
tool: str,
|
||||||
|
args: dict[str, Any] | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> ToolResult:
|
||||||
|
"""
|
||||||
|
Call a tool on a specific MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: Name of the MCP server
|
||||||
|
tool: Name of the tool to call
|
||||||
|
args: Tool arguments
|
||||||
|
timeout: Optional timeout override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool execution result
|
||||||
|
"""
|
||||||
|
if not self._initialized or self._router is None:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
assert self._router is not None # Guaranteed after initialize()
|
||||||
|
return await self._router.call_tool(
|
||||||
|
server_name=server,
|
||||||
|
tool_name=tool,
|
||||||
|
arguments=args,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def route_tool(
|
||||||
|
self,
|
||||||
|
tool: str,
|
||||||
|
args: dict[str, Any] | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> ToolResult:
|
||||||
|
"""
|
||||||
|
Route a tool call to the appropriate server automatically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool: Name of the tool to call
|
||||||
|
args: Tool arguments
|
||||||
|
timeout: Optional timeout override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool execution result
|
||||||
|
"""
|
||||||
|
if not self._initialized or self._router is None:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
assert self._router is not None # Guaranteed after initialize()
|
||||||
|
return await self._router.route_tool(
|
||||||
|
tool_name=tool,
|
||||||
|
arguments=args,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_tools(self, server: str) -> list[ToolInfo]:
|
||||||
|
"""
|
||||||
|
List all tools available on a specific server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: Name of the MCP server
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool information
|
||||||
|
"""
|
||||||
|
capabilities = await self._registry.get_capabilities(server)
|
||||||
|
return [
|
||||||
|
ToolInfo(
|
||||||
|
name=t.get("name", ""),
|
||||||
|
description=t.get("description"),
|
||||||
|
server_name=server,
|
||||||
|
input_schema=t.get("input_schema"),
|
||||||
|
)
|
||||||
|
for t in capabilities.tools
|
||||||
|
]
|
||||||
|
|
||||||
|
async def list_all_tools(self) -> list[ToolInfo]:
|
||||||
|
"""
|
||||||
|
List all tools from all servers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool information
|
||||||
|
"""
|
||||||
|
if not self._initialized or self._router is None:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
assert self._router is not None # Guaranteed after initialize()
|
||||||
|
return await self._router.list_all_tools()
|
||||||
|
|
||||||
|
async def health_check(self) -> dict[str, ServerHealth]:
|
||||||
|
"""
|
||||||
|
Perform health check on all MCP servers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping server names to health status
|
||||||
|
"""
|
||||||
|
results: dict[str, ServerHealth] = {}
|
||||||
|
pool_status = self._pool.get_status()
|
||||||
|
pool_health = await self._pool.health_check_all()
|
||||||
|
|
||||||
|
for server_name in self._registry.list_servers():
|
||||||
|
try:
|
||||||
|
config = self._registry.get(server_name)
|
||||||
|
status = pool_status.get(server_name, {})
|
||||||
|
healthy = pool_health.get(server_name, False)
|
||||||
|
|
||||||
|
capabilities = self._registry.get_cached_capabilities(server_name)
|
||||||
|
|
||||||
|
results[server_name] = ServerHealth(
|
||||||
|
name=server_name,
|
||||||
|
healthy=healthy,
|
||||||
|
state=status.get("state", ConnectionState.DISCONNECTED.value),
|
||||||
|
url=config.url,
|
||||||
|
tools_count=len(capabilities.tools),
|
||||||
|
)
|
||||||
|
except MCPServerNotFoundError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
results[server_name] = ServerHealth(
|
||||||
|
name=server_name,
|
||||||
|
healthy=False,
|
||||||
|
state=ConnectionState.ERROR.value,
|
||||||
|
url="unknown",
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def list_servers(self) -> list[str]:
|
||||||
|
"""Get list of all registered server names."""
|
||||||
|
return self._registry.list_servers()
|
||||||
|
|
||||||
|
def list_enabled_servers(self) -> list[str]:
|
||||||
|
"""Get list of enabled server names."""
|
||||||
|
return self._registry.list_enabled_servers()
|
||||||
|
|
||||||
|
def get_server_config(self, server_name: str) -> MCPServerConfig:
|
||||||
|
"""
|
||||||
|
Get configuration for a specific server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPServerNotFoundError: If server is not registered
|
||||||
|
"""
|
||||||
|
return self._registry.get(server_name)
|
||||||
|
|
||||||
|
def register_server(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Register a new MCP server at runtime.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Unique server name
|
||||||
|
config: Server configuration
|
||||||
|
"""
|
||||||
|
self._registry.register(name, config)
|
||||||
|
|
||||||
|
def unregister_server(self, name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Unregister an MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name to unregister
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if server was found and removed
|
||||||
|
"""
|
||||||
|
return self._registry.unregister(name)
|
||||||
|
|
||||||
|
def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Get status of all circuit breakers."""
|
||||||
|
if self._router is None:
|
||||||
|
return {}
|
||||||
|
return self._router.get_circuit_breaker_status()
|
||||||
|
|
||||||
|
async def reset_circuit_breaker(self, server_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Reset a circuit breaker for a server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if circuit breaker was reset
|
||||||
|
"""
|
||||||
|
if self._router is None:
|
||||||
|
return False
|
||||||
|
return await self._router.reset_circuit_breaker(server_name)
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
_manager_instance: MCPClientManager | None = None
|
||||||
|
_manager_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_mcp_client() -> MCPClientManager:
|
||||||
|
"""
|
||||||
|
Get the global MCP client manager instance.
|
||||||
|
|
||||||
|
This is the main dependency injection point for FastAPI.
|
||||||
|
Uses proper locking to avoid race conditions in async contexts.
|
||||||
|
"""
|
||||||
|
global _manager_instance
|
||||||
|
|
||||||
|
# Use lock for the entire check-and-create operation to avoid race conditions
|
||||||
|
async with _manager_lock:
|
||||||
|
if _manager_instance is None:
|
||||||
|
_manager_instance = MCPClientManager()
|
||||||
|
await _manager_instance.initialize()
|
||||||
|
|
||||||
|
return _manager_instance
|
||||||
|
|
||||||
|
|
||||||
|
async def shutdown_mcp_client() -> None:
|
||||||
|
"""Shutdown the global MCP client manager."""
|
||||||
|
global _manager_instance
|
||||||
|
|
||||||
|
# Use lock to prevent race with get_mcp_client()
|
||||||
|
async with _manager_lock:
|
||||||
|
if _manager_instance is not None:
|
||||||
|
await _manager_instance.shutdown()
|
||||||
|
_manager_instance = None
|
||||||
|
|
||||||
|
|
||||||
|
async def reset_mcp_client() -> None:
|
||||||
|
"""
|
||||||
|
Reset the global MCP client manager (for testing).
|
||||||
|
|
||||||
|
This is an async function to properly acquire the manager lock
|
||||||
|
and avoid race conditions with get_mcp_client().
|
||||||
|
"""
|
||||||
|
global _manager_instance
|
||||||
|
|
||||||
|
async with _manager_lock:
|
||||||
|
if _manager_instance is not None:
|
||||||
|
# Shutdown gracefully before resetting
|
||||||
|
try:
|
||||||
|
await _manager_instance.shutdown()
|
||||||
|
except Exception: # noqa: S110
|
||||||
|
pass # Ignore errors during test cleanup
|
||||||
|
_manager_instance = None
|
||||||
245
backend/app/services/mcp/config.py
Normal file
245
backend/app/services/mcp/config.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
"""
|
||||||
|
MCP Configuration System
|
||||||
|
|
||||||
|
Pydantic models for MCP server configuration with YAML file loading
|
||||||
|
and environment variable overrides.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class TransportType(str, Enum):
|
||||||
|
"""Supported MCP transport types."""
|
||||||
|
|
||||||
|
HTTP = "http"
|
||||||
|
STDIO = "stdio"
|
||||||
|
SSE = "sse"
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerConfig(BaseModel):
|
||||||
|
"""Configuration for a single MCP server."""
|
||||||
|
|
||||||
|
url: str = Field(..., description="Server URL (supports ${ENV_VAR} syntax)")
|
||||||
|
transport: TransportType = Field(
|
||||||
|
default=TransportType.HTTP,
|
||||||
|
description="Transport protocol to use",
|
||||||
|
)
|
||||||
|
timeout: int = Field(
|
||||||
|
default=30,
|
||||||
|
ge=1,
|
||||||
|
le=600,
|
||||||
|
description="Request timeout in seconds",
|
||||||
|
)
|
||||||
|
retry_attempts: int = Field(
|
||||||
|
default=3,
|
||||||
|
ge=0,
|
||||||
|
le=10,
|
||||||
|
description="Number of retry attempts on failure",
|
||||||
|
)
|
||||||
|
retry_delay: float = Field(
|
||||||
|
default=1.0,
|
||||||
|
ge=0.1,
|
||||||
|
le=60.0,
|
||||||
|
description="Initial delay between retries in seconds",
|
||||||
|
)
|
||||||
|
retry_max_delay: float = Field(
|
||||||
|
default=30.0,
|
||||||
|
ge=1.0,
|
||||||
|
le=300.0,
|
||||||
|
description="Maximum delay between retries in seconds",
|
||||||
|
)
|
||||||
|
circuit_breaker_threshold: int = Field(
|
||||||
|
default=5,
|
||||||
|
ge=1,
|
||||||
|
le=50,
|
||||||
|
description="Number of failures before opening circuit",
|
||||||
|
)
|
||||||
|
circuit_breaker_timeout: float = Field(
|
||||||
|
default=30.0,
|
||||||
|
ge=5.0,
|
||||||
|
le=300.0,
|
||||||
|
description="Seconds to wait before attempting to close circuit",
|
||||||
|
)
|
||||||
|
enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Whether this server is enabled",
|
||||||
|
)
|
||||||
|
description: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Human-readable description of the server",
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("url", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def expand_env_vars(cls, v: str) -> str:
|
||||||
|
"""Expand environment variables in URL using ${VAR:-default} syntax."""
|
||||||
|
if not isinstance(v, str):
|
||||||
|
return v
|
||||||
|
|
||||||
|
result = v
|
||||||
|
# Find all ${VAR} or ${VAR:-default} patterns
|
||||||
|
import re
|
||||||
|
|
||||||
|
pattern = r"\$\{([^}]+)\}"
|
||||||
|
matches = re.findall(pattern, v)
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
if ":-" in match:
|
||||||
|
var_name, default = match.split(":-", 1)
|
||||||
|
else:
|
||||||
|
var_name, default = match, ""
|
||||||
|
|
||||||
|
env_value = os.environ.get(var_name.strip(), default)
|
||||||
|
result = result.replace(f"${{{match}}}", env_value)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConfig(BaseModel):
|
||||||
|
"""Root configuration for all MCP servers."""
|
||||||
|
|
||||||
|
mcp_servers: dict[str, MCPServerConfig] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Map of server names to their configurations",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Global defaults
|
||||||
|
default_timeout: int = Field(
|
||||||
|
default=30,
|
||||||
|
description="Default timeout for all servers",
|
||||||
|
)
|
||||||
|
default_retry_attempts: int = Field(
|
||||||
|
default=3,
|
||||||
|
description="Default retry attempts for all servers",
|
||||||
|
)
|
||||||
|
connection_pool_size: int = Field(
|
||||||
|
default=10,
|
||||||
|
ge=1,
|
||||||
|
le=100,
|
||||||
|
description="Maximum connections per server",
|
||||||
|
)
|
||||||
|
health_check_interval: int = Field(
|
||||||
|
default=30,
|
||||||
|
ge=5,
|
||||||
|
le=300,
|
||||||
|
description="Seconds between health checks",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, path: str | Path) -> "MCPConfig":
|
||||||
|
"""Load configuration from a YAML file."""
|
||||||
|
path = Path(path)
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"MCP config file not found: {path}")
|
||||||
|
|
||||||
|
with path.open("r") as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
data = {}
|
||||||
|
|
||||||
|
return cls.model_validate(data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "MCPConfig":
|
||||||
|
"""Load configuration from a dictionary."""
|
||||||
|
return cls.model_validate(data)
|
||||||
|
|
||||||
|
def get_server(self, name: str) -> MCPServerConfig | None:
|
||||||
|
"""Get a server configuration by name."""
|
||||||
|
return self.mcp_servers.get(name)
|
||||||
|
|
||||||
|
def get_enabled_servers(self) -> dict[str, MCPServerConfig]:
|
||||||
|
"""Get all enabled server configurations."""
|
||||||
|
return {
|
||||||
|
name: config for name, config in self.mcp_servers.items() if config.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
def list_server_names(self) -> list[str]:
|
||||||
|
"""Get list of all configured server names."""
|
||||||
|
return list(self.mcp_servers.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# Default configuration path
|
||||||
|
DEFAULT_CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "mcp_servers.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
|
||||||
|
"""
|
||||||
|
Load MCP configuration from file or environment.
|
||||||
|
|
||||||
|
Priority:
|
||||||
|
1. Explicit path parameter
|
||||||
|
2. MCP_CONFIG_PATH environment variable
|
||||||
|
3. Default path (backend/mcp_servers.yaml)
|
||||||
|
4. Empty config if no file exists
|
||||||
|
|
||||||
|
In test mode (IS_TEST=True), retry settings are reduced for faster tests.
|
||||||
|
"""
|
||||||
|
if path is None:
|
||||||
|
path = os.environ.get("MCP_CONFIG_PATH", str(DEFAULT_CONFIG_PATH))
|
||||||
|
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
if not path.exists():
|
||||||
|
# Return empty config if no file exists (allows runtime registration)
|
||||||
|
return MCPConfig()
|
||||||
|
|
||||||
|
config = MCPConfig.from_yaml(path)
|
||||||
|
|
||||||
|
# In test mode, reduce retry settings to speed up tests
|
||||||
|
is_test = os.environ.get("IS_TEST", "").lower() in ("true", "1", "yes")
|
||||||
|
if is_test:
|
||||||
|
for server_config in config.mcp_servers.values():
|
||||||
|
server_config.retry_attempts = 1 # Single attempt
|
||||||
|
server_config.retry_delay = 0.1 # 100ms instead of 1s
|
||||||
|
server_config.retry_max_delay = 0.5 # 500ms max
|
||||||
|
server_config.timeout = 2 # 2s timeout instead of 30-120s
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def create_default_config() -> MCPConfig:
|
||||||
|
"""
|
||||||
|
Create a default MCP configuration with standard servers.
|
||||||
|
|
||||||
|
This is useful for development and as a template.
|
||||||
|
"""
|
||||||
|
return MCPConfig(
|
||||||
|
mcp_servers={
|
||||||
|
"llm-gateway": MCPServerConfig(
|
||||||
|
url="${LLM_GATEWAY_URL:-http://localhost:8001}",
|
||||||
|
transport=TransportType.HTTP,
|
||||||
|
timeout=60,
|
||||||
|
description="LLM Gateway for multi-provider AI interactions",
|
||||||
|
),
|
||||||
|
"knowledge-base": MCPServerConfig(
|
||||||
|
url="${KNOWLEDGE_BASE_URL:-http://localhost:8002}",
|
||||||
|
transport=TransportType.HTTP,
|
||||||
|
timeout=30,
|
||||||
|
description="Knowledge Base for RAG and document retrieval",
|
||||||
|
),
|
||||||
|
"git-ops": MCPServerConfig(
|
||||||
|
url="${GIT_OPS_URL:-http://localhost:8003}",
|
||||||
|
transport=TransportType.HTTP,
|
||||||
|
timeout=120,
|
||||||
|
description="Git Operations for repository management",
|
||||||
|
),
|
||||||
|
"issues": MCPServerConfig(
|
||||||
|
url="${ISSUES_URL:-http://localhost:8004}",
|
||||||
|
transport=TransportType.HTTP,
|
||||||
|
timeout=30,
|
||||||
|
description="Issue Tracker for Gitea/GitHub/GitLab",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
default_timeout=30,
|
||||||
|
default_retry_attempts=3,
|
||||||
|
connection_pool_size=10,
|
||||||
|
health_check_interval=30,
|
||||||
|
)
|
||||||
473
backend/app/services/mcp/connection.py
Normal file
473
backend/app/services/mcp/connection.py
Normal file
@@ -0,0 +1,473 @@
|
|||||||
|
"""
|
||||||
|
MCP Connection Management
|
||||||
|
|
||||||
|
Handles connection lifecycle, pooling, and automatic reconnection
|
||||||
|
for MCP servers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .config import MCPServerConfig, TransportType
|
||||||
|
from .exceptions import MCPConnectionError, MCPTimeoutError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionState(str, Enum):
|
||||||
|
"""Connection state enumeration."""
|
||||||
|
|
||||||
|
DISCONNECTED = "disconnected"
|
||||||
|
CONNECTING = "connecting"
|
||||||
|
CONNECTED = "connected"
|
||||||
|
RECONNECTING = "reconnecting"
|
||||||
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConnection:
|
||||||
|
"""
|
||||||
|
Manages a single connection to an MCP server.
|
||||||
|
|
||||||
|
Handles connection lifecycle, health checking, and automatic reconnection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the MCP server
|
||||||
|
config: Server configuration
|
||||||
|
"""
|
||||||
|
self.server_name = server_name
|
||||||
|
self.config = config
|
||||||
|
self._state = ConnectionState.DISCONNECTED
|
||||||
|
self._client: httpx.AsyncClient | None = None
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._last_activity: float | None = None
|
||||||
|
self._connection_attempts = 0
|
||||||
|
self._last_error: Exception | None = None
|
||||||
|
|
||||||
|
# Reconnection settings
|
||||||
|
self._base_delay = config.retry_delay
|
||||||
|
self._max_delay = config.retry_max_delay
|
||||||
|
self._max_attempts = config.retry_attempts
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> ConnectionState:
|
||||||
|
"""Get current connection state."""
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""Check if connection is established."""
|
||||||
|
return self._state == ConnectionState.CONNECTED
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_error(self) -> Exception | None:
|
||||||
|
"""Get the last error that occurred."""
|
||||||
|
return self._last_error
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
"""
|
||||||
|
Establish connection to the MCP server.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPConnectionError: If connection fails after all retries
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if self._state == ConnectionState.CONNECTED:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._state = ConnectionState.CONNECTING
|
||||||
|
self._connection_attempts = 0
|
||||||
|
self._last_error = None
|
||||||
|
|
||||||
|
while self._connection_attempts < self._max_attempts:
|
||||||
|
try:
|
||||||
|
await self._do_connect()
|
||||||
|
self._state = ConnectionState.CONNECTED
|
||||||
|
self._last_activity = time.time()
|
||||||
|
logger.info(
|
||||||
|
"Connected to MCP server: %s at %s",
|
||||||
|
self.server_name,
|
||||||
|
self.config.url,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
self._connection_attempts += 1
|
||||||
|
self._last_error = e
|
||||||
|
logger.warning(
|
||||||
|
"Connection attempt %d/%d failed for %s: %s",
|
||||||
|
self._connection_attempts,
|
||||||
|
self._max_attempts,
|
||||||
|
self.server_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._connection_attempts < self._max_attempts:
|
||||||
|
delay = self._calculate_backoff_delay()
|
||||||
|
logger.debug(
|
||||||
|
"Retrying connection to %s in %.1fs",
|
||||||
|
self.server_name,
|
||||||
|
delay,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
# All attempts failed
|
||||||
|
self._state = ConnectionState.ERROR
|
||||||
|
raise MCPConnectionError(
|
||||||
|
f"Failed to connect after {self._max_attempts} attempts",
|
||||||
|
server_name=self.server_name,
|
||||||
|
url=self.config.url,
|
||||||
|
cause=self._last_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _do_connect(self) -> None:
|
||||||
|
"""Perform the actual connection (transport-specific)."""
|
||||||
|
if self.config.transport == TransportType.HTTP:
|
||||||
|
self._client = httpx.AsyncClient(
|
||||||
|
base_url=self.config.url,
|
||||||
|
timeout=httpx.Timeout(self.config.timeout),
|
||||||
|
headers={
|
||||||
|
"User-Agent": "Syndarix-MCP-Client/1.0",
|
||||||
|
"Accept": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Verify connectivity with a simple request
|
||||||
|
try:
|
||||||
|
# Try to hit the MCP capabilities endpoint
|
||||||
|
response = await self._client.get("/mcp/capabilities")
|
||||||
|
if response.status_code not in (200, 404):
|
||||||
|
# 404 is acceptable - server might not have capabilities endpoint
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if e.response.status_code != 404:
|
||||||
|
raise
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
raise MCPConnectionError(
|
||||||
|
"Failed to connect to server",
|
||||||
|
server_name=self.server_name,
|
||||||
|
url=self.config.url,
|
||||||
|
cause=e,
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
# For STDIO and SSE transports, we'll implement later
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Transport {self.config.transport} not yet implemented"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _calculate_backoff_delay(self) -> float:
|
||||||
|
"""Calculate exponential backoff delay with jitter."""
|
||||||
|
import random
|
||||||
|
|
||||||
|
delay = self._base_delay * (2 ** (self._connection_attempts - 1))
|
||||||
|
delay = min(delay, self._max_delay)
|
||||||
|
# Add jitter (±25%)
|
||||||
|
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||||
|
return delay + jitter
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""Disconnect from the MCP server."""
|
||||||
|
async with self._lock:
|
||||||
|
if self._client is not None:
|
||||||
|
try:
|
||||||
|
await self._client.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Error closing connection to %s: %s",
|
||||||
|
self.server_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
self._state = ConnectionState.DISCONNECTED
|
||||||
|
logger.info("Disconnected from MCP server: %s", self.server_name)
|
||||||
|
|
||||||
|
async def reconnect(self) -> None:
|
||||||
|
"""Reconnect to the MCP server."""
|
||||||
|
async with self._lock:
|
||||||
|
self._state = ConnectionState.RECONNECTING
|
||||||
|
await self.disconnect()
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
async def health_check(self) -> bool:
|
||||||
|
"""
|
||||||
|
Perform a health check on the connection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection is healthy
|
||||||
|
"""
|
||||||
|
if not self.is_connected or self._client is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.config.transport == TransportType.HTTP:
|
||||||
|
response = await self._client.get(
|
||||||
|
"/health",
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
return response.status_code == 200
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Health check failed for %s: %s",
|
||||||
|
self.server_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def execute_request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Execute an HTTP request to the MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP method (GET, POST, etc.)
|
||||||
|
path: Request path
|
||||||
|
data: Optional request body
|
||||||
|
timeout: Optional timeout override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPConnectionError: If not connected
|
||||||
|
MCPTimeoutError: If request times out
|
||||||
|
"""
|
||||||
|
if not self.is_connected or self._client is None:
|
||||||
|
raise MCPConnectionError(
|
||||||
|
"Not connected to server",
|
||||||
|
server_name=self.server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
effective_timeout = timeout or self.config.timeout
|
||||||
|
|
||||||
|
try:
|
||||||
|
if method.upper() == "GET":
|
||||||
|
response = await self._client.get(
|
||||||
|
path,
|
||||||
|
timeout=effective_timeout,
|
||||||
|
)
|
||||||
|
elif method.upper() == "POST":
|
||||||
|
response = await self._client.post(
|
||||||
|
path,
|
||||||
|
json=data,
|
||||||
|
timeout=effective_timeout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await self._client.request(
|
||||||
|
method.upper(),
|
||||||
|
path,
|
||||||
|
json=data,
|
||||||
|
timeout=effective_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._last_activity = time.time()
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
except httpx.TimeoutException as e:
|
||||||
|
raise MCPTimeoutError(
|
||||||
|
"Request timed out",
|
||||||
|
server_name=self.server_name,
|
||||||
|
timeout_seconds=effective_timeout,
|
||||||
|
operation=f"{method} {path}",
|
||||||
|
) from e
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise MCPConnectionError(
|
||||||
|
f"HTTP error: {e.response.status_code}",
|
||||||
|
server_name=self.server_name,
|
||||||
|
url=f"{self.config.url}{path}",
|
||||||
|
cause=e,
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
raise MCPConnectionError(
|
||||||
|
f"Request failed: {e}",
|
||||||
|
server_name=self.server_name,
|
||||||
|
cause=e,
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionPool:
|
||||||
|
"""
|
||||||
|
Pool of connections to MCP servers.
|
||||||
|
|
||||||
|
Manages connection lifecycle and provides connection reuse.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_connections_per_server: int = 10) -> None:
|
||||||
|
"""
|
||||||
|
Initialize connection pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_connections_per_server: Maximum connections per server
|
||||||
|
"""
|
||||||
|
self._connections: dict[str, MCPConnection] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._per_server_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
self._max_per_server = max_connections_per_server
|
||||||
|
|
||||||
|
def _get_server_lock(self, server_name: str) -> asyncio.Lock:
|
||||||
|
"""Get or create a lock for a specific server.
|
||||||
|
|
||||||
|
Uses setdefault for atomic dict access to prevent race conditions
|
||||||
|
where two coroutines could create different locks for the same server.
|
||||||
|
"""
|
||||||
|
# setdefault is atomic - if key exists, returns existing value
|
||||||
|
# if key doesn't exist, inserts new value and returns it
|
||||||
|
return self._per_server_locks.setdefault(server_name, asyncio.Lock())
|
||||||
|
|
||||||
|
async def get_connection(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
) -> MCPConnection:
|
||||||
|
"""
|
||||||
|
Get or create a connection to a server.
|
||||||
|
|
||||||
|
Uses per-server locking to avoid blocking all connections
|
||||||
|
when establishing a new connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server
|
||||||
|
config: Server configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Active connection
|
||||||
|
"""
|
||||||
|
# Quick check without lock - if connection exists and is connected, return it
|
||||||
|
if server_name in self._connections:
|
||||||
|
connection = self._connections[server_name]
|
||||||
|
if connection.is_connected:
|
||||||
|
return connection
|
||||||
|
|
||||||
|
# Need to create or reconnect - use per-server lock to avoid blocking others
|
||||||
|
async with self._lock:
|
||||||
|
server_lock = self._get_server_lock(server_name)
|
||||||
|
|
||||||
|
async with server_lock:
|
||||||
|
# Double-check after acquiring per-server lock
|
||||||
|
if server_name in self._connections:
|
||||||
|
connection = self._connections[server_name]
|
||||||
|
if connection.is_connected:
|
||||||
|
return connection
|
||||||
|
# Connection exists but not connected - reconnect
|
||||||
|
await connection.connect()
|
||||||
|
return connection
|
||||||
|
|
||||||
|
# Create new connection (outside global lock, under per-server lock)
|
||||||
|
connection = MCPConnection(server_name, config)
|
||||||
|
await connection.connect()
|
||||||
|
|
||||||
|
# Store connection under global lock
|
||||||
|
async with self._lock:
|
||||||
|
self._connections[server_name] = connection
|
||||||
|
|
||||||
|
return connection
|
||||||
|
|
||||||
|
async def release_connection(self, server_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Release a connection (currently just tracks usage).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server
|
||||||
|
"""
|
||||||
|
# For now, we keep connections alive
|
||||||
|
# Future: implement connection reaping for idle connections
|
||||||
|
|
||||||
|
async def close_connection(self, server_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Close and remove a connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if server_name in self._connections:
|
||||||
|
await self._connections[server_name].disconnect()
|
||||||
|
del self._connections[server_name]
|
||||||
|
# Clean up per-server lock
|
||||||
|
if server_name in self._per_server_locks:
|
||||||
|
del self._per_server_locks[server_name]
|
||||||
|
|
||||||
|
async def close_all(self) -> None:
|
||||||
|
"""Close all connections in the pool."""
|
||||||
|
async with self._lock:
|
||||||
|
for connection in self._connections.values():
|
||||||
|
try:
|
||||||
|
await connection.disconnect()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Error closing connection: %s", e)
|
||||||
|
|
||||||
|
self._connections.clear()
|
||||||
|
self._per_server_locks.clear()
|
||||||
|
logger.info("Closed all MCP connections")
|
||||||
|
|
||||||
|
async def health_check_all(self) -> dict[str, bool]:
|
||||||
|
"""
|
||||||
|
Perform health check on all connections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping server names to health status
|
||||||
|
"""
|
||||||
|
# Copy connections under lock to prevent modification during iteration
|
||||||
|
async with self._lock:
|
||||||
|
connections_snapshot = dict(self._connections)
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
for name, connection in connections_snapshot.items():
|
||||||
|
results[name] = await connection.health_check()
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_status(self) -> dict[str, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get status of all connections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping server names to status info
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
name: {
|
||||||
|
"state": conn.state.value,
|
||||||
|
"is_connected": conn.is_connected,
|
||||||
|
"url": conn.config.url,
|
||||||
|
}
|
||||||
|
for name, conn in self._connections.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def connection(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
) -> AsyncGenerator[MCPConnection, None]:
|
||||||
|
"""
|
||||||
|
Context manager for getting a connection.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
async with pool.connection("server", config) as conn:
|
||||||
|
result = await conn.execute_request("POST", "/tool", data)
|
||||||
|
"""
|
||||||
|
conn = await self.get_connection(server_name, config)
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
await self.release_connection(server_name)
|
||||||
201
backend/app/services/mcp/exceptions.py
Normal file
201
backend/app/services/mcp/exceptions.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""
|
||||||
|
MCP Exception Classes
|
||||||
|
|
||||||
|
Custom exceptions for MCP client operations with detailed error context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class MCPError(Exception):
|
||||||
|
"""Base exception for all MCP-related errors."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
server_name: str | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
self.server_name = server_name
|
||||||
|
self.details = details or {}
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
parts = [self.message]
|
||||||
|
if self.server_name:
|
||||||
|
parts.append(f"server={self.server_name}")
|
||||||
|
if self.details:
|
||||||
|
parts.append(f"details={self.details}")
|
||||||
|
return " | ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConnectionError(MCPError):
|
||||||
|
"""Raised when connection to an MCP server fails."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
server_name: str | None = None,
|
||||||
|
url: str | None = None,
|
||||||
|
cause: Exception | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, server_name=server_name, details=details)
|
||||||
|
self.url = url
|
||||||
|
self.cause = cause
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.url:
|
||||||
|
base = f"{base} | url={self.url}"
|
||||||
|
if self.cause:
|
||||||
|
base = f"{base} | cause={type(self.cause).__name__}: {self.cause}"
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPTimeoutError(MCPError):
|
||||||
|
"""Raised when an MCP operation times out."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
server_name: str | None = None,
|
||||||
|
timeout_seconds: float | None = None,
|
||||||
|
operation: str | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, server_name=server_name, details=details)
|
||||||
|
self.timeout_seconds = timeout_seconds
|
||||||
|
self.operation = operation
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.timeout_seconds is not None:
|
||||||
|
base = f"{base} | timeout={self.timeout_seconds}s"
|
||||||
|
if self.operation:
|
||||||
|
base = f"{base} | operation={self.operation}"
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolError(MCPError):
|
||||||
|
"""Raised when a tool execution fails."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
server_name: str | None = None,
|
||||||
|
tool_name: str | None = None,
|
||||||
|
tool_args: dict[str, Any] | None = None,
|
||||||
|
error_code: str | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, server_name=server_name, details=details)
|
||||||
|
self.tool_name = tool_name
|
||||||
|
self.tool_args = tool_args
|
||||||
|
self.error_code = error_code
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.tool_name:
|
||||||
|
base = f"{base} | tool={self.tool_name}"
|
||||||
|
if self.error_code:
|
||||||
|
base = f"{base} | error_code={self.error_code}"
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerNotFoundError(MCPError):
|
||||||
|
"""Raised when a requested MCP server is not registered."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
*,
|
||||||
|
available_servers: list[str] | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
message = f"MCP server not found: {server_name}"
|
||||||
|
super().__init__(message, server_name=server_name, details=details)
|
||||||
|
self.available_servers = available_servers or []
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.available_servers:
|
||||||
|
base = f"{base} | available={self.available_servers}"
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolNotFoundError(MCPError):
|
||||||
|
"""Raised when a requested tool is not found on any server."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
*,
|
||||||
|
server_name: str | None = None,
|
||||||
|
available_tools: list[str] | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
message = f"Tool not found: {tool_name}"
|
||||||
|
super().__init__(message, server_name=server_name, details=details)
|
||||||
|
self.tool_name = tool_name
|
||||||
|
self.available_tools = available_tools or []
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.available_tools:
|
||||||
|
base = f"{base} | available_tools={self.available_tools[:5]}..."
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPCircuitOpenError(MCPError):
|
||||||
|
"""Raised when a circuit breaker is open (server temporarily unavailable)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
*,
|
||||||
|
failure_count: int | None = None,
|
||||||
|
reset_timeout: float | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
message = f"Circuit breaker open for server: {server_name}"
|
||||||
|
super().__init__(message, server_name=server_name, details=details)
|
||||||
|
self.failure_count = failure_count
|
||||||
|
self.reset_timeout = reset_timeout
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.failure_count is not None:
|
||||||
|
base = f"{base} | failures={self.failure_count}"
|
||||||
|
if self.reset_timeout is not None:
|
||||||
|
base = f"{base} | reset_in={self.reset_timeout}s"
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class MCPValidationError(MCPError):
|
||||||
|
"""Raised when tool arguments fail validation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
tool_name: str | None = None,
|
||||||
|
field_errors: dict[str, str] | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, details=details)
|
||||||
|
self.tool_name = tool_name
|
||||||
|
self.field_errors = field_errors or {}
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
base = super().__str__()
|
||||||
|
if self.tool_name:
|
||||||
|
base = f"{base} | tool={self.tool_name}"
|
||||||
|
if self.field_errors:
|
||||||
|
base = f"{base} | fields={list(self.field_errors.keys())}"
|
||||||
|
return base
|
||||||
305
backend/app/services/mcp/registry.py
Normal file
305
backend/app/services/mcp/registry.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
"""
|
||||||
|
MCP Server Registry
|
||||||
|
|
||||||
|
Thread-safe singleton registry for managing MCP server configurations
|
||||||
|
and their capabilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
||||||
|
from .exceptions import MCPServerNotFoundError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ServerCapabilities:
|
||||||
|
"""Cached capabilities for an MCP server."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
resources: list[dict[str, Any]] | None = None,
|
||||||
|
prompts: list[dict[str, Any]] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.tools = tools or []
|
||||||
|
self.resources = resources or []
|
||||||
|
self.prompts = prompts or []
|
||||||
|
self._loaded = False
|
||||||
|
self._load_time: float | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_loaded(self) -> bool:
|
||||||
|
"""Check if capabilities have been loaded."""
|
||||||
|
return self._loaded
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tool_names(self) -> list[str]:
|
||||||
|
"""Get list of tool names."""
|
||||||
|
return [t.get("name", "") for t in self.tools if t.get("name")]
|
||||||
|
|
||||||
|
def mark_loaded(self) -> None:
|
||||||
|
"""Mark capabilities as loaded."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
self._loaded = True
|
||||||
|
self._load_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerRegistry:
|
||||||
|
"""
|
||||||
|
Thread-safe singleton registry for MCP servers.
|
||||||
|
|
||||||
|
Manages server configurations and caches their capabilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance: "MCPServerRegistry | None" = None
|
||||||
|
_lock = Lock()
|
||||||
|
|
||||||
|
def __new__(cls) -> "MCPServerRegistry":
|
||||||
|
"""Ensure singleton pattern."""
|
||||||
|
with cls._lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._initialized = False
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize registry (only runs once due to singleton)."""
|
||||||
|
if getattr(self, "_initialized", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
self._config: MCPConfig = MCPConfig()
|
||||||
|
self._capabilities: dict[str, ServerCapabilities] = {}
|
||||||
|
self._capabilities_lock = asyncio.Lock()
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
logger.info("MCP Server Registry initialized")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls) -> "MCPServerRegistry":
|
||||||
|
"""Get the singleton registry instance."""
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reset_instance(cls) -> None:
|
||||||
|
"""Reset the singleton (for testing)."""
|
||||||
|
with cls._lock:
|
||||||
|
cls._instance = None
|
||||||
|
|
||||||
|
def load_config(self, config: MCPConfig | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Load configuration into the registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Optional config to load. If None, loads from default path.
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = load_mcp_config()
|
||||||
|
|
||||||
|
self._config = config
|
||||||
|
self._capabilities.clear()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Loaded MCP configuration with %d servers",
|
||||||
|
len(config.mcp_servers),
|
||||||
|
)
|
||||||
|
for name in config.list_server_names():
|
||||||
|
logger.debug("Registered MCP server: %s", name)
|
||||||
|
|
||||||
|
def register(self, name: str, config: MCPServerConfig) -> None:
|
||||||
|
"""
|
||||||
|
Register a new MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Unique server name
|
||||||
|
config: Server configuration
|
||||||
|
"""
|
||||||
|
self._config.mcp_servers[name] = config
|
||||||
|
self._capabilities.pop(name, None) # Clear any cached capabilities
|
||||||
|
|
||||||
|
logger.info("Registered MCP server: %s at %s", name, config.url)
|
||||||
|
|
||||||
|
def unregister(self, name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Unregister an MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name to unregister
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if server was found and removed
|
||||||
|
"""
|
||||||
|
if name in self._config.mcp_servers:
|
||||||
|
del self._config.mcp_servers[name]
|
||||||
|
self._capabilities.pop(name, None)
|
||||||
|
logger.info("Unregistered MCP server: %s", name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get(self, name: str) -> MCPServerConfig:
|
||||||
|
"""
|
||||||
|
Get a server configuration by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPServerNotFoundError: If server is not registered
|
||||||
|
"""
|
||||||
|
config = self._config.get_server(name)
|
||||||
|
if config is None:
|
||||||
|
raise MCPServerNotFoundError(
|
||||||
|
server_name=name,
|
||||||
|
available_servers=self.list_servers(),
|
||||||
|
)
|
||||||
|
return config
|
||||||
|
|
||||||
|
def get_or_none(self, name: str) -> MCPServerConfig | None:
|
||||||
|
"""
|
||||||
|
Get a server configuration by name, or None if not found.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server configuration or None
|
||||||
|
"""
|
||||||
|
return self._config.get_server(name)
|
||||||
|
|
||||||
|
def list_servers(self) -> list[str]:
|
||||||
|
"""Get list of all registered server names."""
|
||||||
|
return self._config.list_server_names()
|
||||||
|
|
||||||
|
def list_enabled_servers(self) -> list[str]:
|
||||||
|
"""Get list of enabled server names."""
|
||||||
|
return list(self._config.get_enabled_servers().keys())
|
||||||
|
|
||||||
|
def get_all_configs(self) -> dict[str, MCPServerConfig]:
|
||||||
|
"""Get all server configurations."""
|
||||||
|
return dict(self._config.mcp_servers)
|
||||||
|
|
||||||
|
def get_enabled_configs(self) -> dict[str, MCPServerConfig]:
|
||||||
|
"""Get all enabled server configurations."""
|
||||||
|
return self._config.get_enabled_servers()
|
||||||
|
|
||||||
|
async def get_capabilities(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
force_refresh: bool = False,
|
||||||
|
) -> ServerCapabilities:
|
||||||
|
"""
|
||||||
|
Get capabilities for a server (lazy-loaded and cached).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name
|
||||||
|
force_refresh: If True, refresh cached capabilities
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server capabilities
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPServerNotFoundError: If server is not registered
|
||||||
|
"""
|
||||||
|
# Verify server exists
|
||||||
|
self.get(name)
|
||||||
|
|
||||||
|
async with self._capabilities_lock:
|
||||||
|
if name not in self._capabilities or force_refresh:
|
||||||
|
# Will be populated by connection manager when connecting
|
||||||
|
self._capabilities[name] = ServerCapabilities()
|
||||||
|
|
||||||
|
return self._capabilities[name]
|
||||||
|
|
||||||
|
def set_capabilities(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
resources: list[dict[str, Any]] | None = None,
|
||||||
|
prompts: list[dict[str, Any]] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Set capabilities for a server (called by connection manager).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name
|
||||||
|
tools: List of tool definitions
|
||||||
|
resources: List of resource definitions
|
||||||
|
prompts: List of prompt definitions
|
||||||
|
"""
|
||||||
|
capabilities = ServerCapabilities(
|
||||||
|
tools=tools,
|
||||||
|
resources=resources,
|
||||||
|
prompts=prompts,
|
||||||
|
)
|
||||||
|
capabilities.mark_loaded()
|
||||||
|
self._capabilities[name] = capabilities
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Updated capabilities for %s: %d tools, %d resources, %d prompts",
|
||||||
|
name,
|
||||||
|
len(capabilities.tools),
|
||||||
|
len(capabilities.resources),
|
||||||
|
len(capabilities.prompts),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_cached_capabilities(self, name: str) -> ServerCapabilities:
|
||||||
|
"""
|
||||||
|
Get cached capabilities without async loading.
|
||||||
|
|
||||||
|
Use this for synchronous access when you only need
|
||||||
|
cached values (e.g., for health check responses).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Server name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached capabilities or empty ServerCapabilities
|
||||||
|
"""
|
||||||
|
return self._capabilities.get(name, ServerCapabilities())
|
||||||
|
|
||||||
|
def find_server_for_tool(self, tool_name: str) -> str | None:
|
||||||
|
"""
|
||||||
|
Find which server provides a specific tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool to find
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server name or None if not found
|
||||||
|
"""
|
||||||
|
for name, caps in self._capabilities.items():
|
||||||
|
if tool_name in caps.tool_names:
|
||||||
|
return name
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_all_tools(self) -> dict[str, list[dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Get all tools from all servers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping server name to list of tool definitions
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
name: caps.tools
|
||||||
|
for name, caps in self._capabilities.items()
|
||||||
|
if caps.is_loaded
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_config(self) -> MCPConfig:
|
||||||
|
"""Get the global MCP configuration."""
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level convenience function
|
||||||
|
def get_registry() -> MCPServerRegistry:
|
||||||
|
"""Get the global MCP server registry instance."""
|
||||||
|
return MCPServerRegistry.get_instance()
|
||||||
619
backend/app/services/mcp/routing.py
Normal file
619
backend/app/services/mcp/routing.py
Normal file
@@ -0,0 +1,619 @@
|
|||||||
|
"""
|
||||||
|
MCP Tool Call Routing
|
||||||
|
|
||||||
|
Routes tool calls to appropriate servers with retry logic,
|
||||||
|
circuit breakers, and request/response serialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .config import MCPServerConfig
|
||||||
|
from .connection import ConnectionPool, MCPConnection
|
||||||
|
from .exceptions import (
|
||||||
|
MCPCircuitOpenError,
|
||||||
|
MCPError,
|
||||||
|
MCPTimeoutError,
|
||||||
|
MCPToolError,
|
||||||
|
MCPToolNotFoundError,
|
||||||
|
)
|
||||||
|
from .registry import MCPServerRegistry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitState(Enum):
|
||||||
|
"""Circuit breaker states."""
|
||||||
|
|
||||||
|
CLOSED = "closed"
|
||||||
|
OPEN = "open"
|
||||||
|
HALF_OPEN = "half-open"
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCircuitBreaker:
|
||||||
|
"""
|
||||||
|
Async-compatible circuit breaker implementation.
|
||||||
|
|
||||||
|
Unlike pybreaker which wraps sync functions, this implementation
|
||||||
|
provides explicit success/failure tracking for async code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fail_max: int = 5,
|
||||||
|
reset_timeout: float = 30.0,
|
||||||
|
name: str = "",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize circuit breaker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fail_max: Maximum failures before opening circuit
|
||||||
|
reset_timeout: Seconds to wait before trying again
|
||||||
|
name: Name for logging
|
||||||
|
"""
|
||||||
|
self.fail_max = fail_max
|
||||||
|
self.reset_timeout = reset_timeout
|
||||||
|
self.name = name
|
||||||
|
self._state = CircuitState.CLOSED
|
||||||
|
self._fail_counter = 0
|
||||||
|
self._last_failure_time: float | None = None
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_state(self) -> str:
|
||||||
|
"""Get current state as string."""
|
||||||
|
# Check if we should transition from OPEN to HALF_OPEN
|
||||||
|
if self._state == CircuitState.OPEN:
|
||||||
|
if self._should_try_reset():
|
||||||
|
return CircuitState.HALF_OPEN.value
|
||||||
|
return self._state.value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fail_counter(self) -> int:
|
||||||
|
"""Get current failure count."""
|
||||||
|
return self._fail_counter
|
||||||
|
|
||||||
|
def _should_try_reset(self) -> bool:
|
||||||
|
"""Check if enough time has passed to try resetting."""
|
||||||
|
if self._last_failure_time is None:
|
||||||
|
return True
|
||||||
|
return (time.time() - self._last_failure_time) >= self.reset_timeout
|
||||||
|
|
||||||
|
async def success(self) -> None:
|
||||||
|
"""Record a successful call."""
|
||||||
|
async with self._lock:
|
||||||
|
self._fail_counter = 0
|
||||||
|
self._state = CircuitState.CLOSED
|
||||||
|
self._last_failure_time = None
|
||||||
|
|
||||||
|
async def failure(self) -> None:
|
||||||
|
"""Record a failed call."""
|
||||||
|
async with self._lock:
|
||||||
|
self._fail_counter += 1
|
||||||
|
self._last_failure_time = time.time()
|
||||||
|
|
||||||
|
if self._fail_counter >= self.fail_max:
|
||||||
|
self._state = CircuitState.OPEN
|
||||||
|
logger.warning(
|
||||||
|
"Circuit breaker %s opened after %d failures",
|
||||||
|
self.name,
|
||||||
|
self._fail_counter,
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_open(self) -> bool:
|
||||||
|
"""Check if circuit is open (not allowing calls)."""
|
||||||
|
if self._state == CircuitState.OPEN:
|
||||||
|
return not self._should_try_reset()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def reset(self) -> None:
|
||||||
|
"""Manually reset the circuit breaker."""
|
||||||
|
async with self._lock:
|
||||||
|
self._state = CircuitState.CLOSED
|
||||||
|
self._fail_counter = 0
|
||||||
|
self._last_failure_time = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolInfo:
|
||||||
|
"""Information about an available tool."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
server_name: str | None = None
|
||||||
|
input_schema: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"server_name": self.server_name,
|
||||||
|
"input_schema": self.input_schema,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolResult:
|
||||||
|
"""Result of a tool execution."""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
data: Any = None
|
||||||
|
error: str | None = None
|
||||||
|
error_code: str | None = None
|
||||||
|
tool_name: str | None = None
|
||||||
|
server_name: str | None = None
|
||||||
|
execution_time_ms: float = 0.0
|
||||||
|
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"success": self.success,
|
||||||
|
"data": self.data,
|
||||||
|
"error": self.error,
|
||||||
|
"error_code": self.error_code,
|
||||||
|
"tool_name": self.tool_name,
|
||||||
|
"server_name": self.server_name,
|
||||||
|
"execution_time_ms": self.execution_time_ms,
|
||||||
|
"request_id": self.request_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRouter:
|
||||||
|
"""
|
||||||
|
Routes tool calls to the appropriate MCP server.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Tool name to server mapping
|
||||||
|
- Retry logic with exponential backoff
|
||||||
|
- Circuit breaker pattern for fault tolerance
|
||||||
|
- Request/response serialization
|
||||||
|
- Execution timing and metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
registry: MCPServerRegistry,
|
||||||
|
connection_pool: ConnectionPool,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the tool router.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: MCP server registry
|
||||||
|
connection_pool: Connection pool for servers
|
||||||
|
"""
|
||||||
|
self._registry = registry
|
||||||
|
self._pool = connection_pool
|
||||||
|
self._circuit_breakers: dict[str, AsyncCircuitBreaker] = {}
|
||||||
|
self._tool_to_server: dict[str, str] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def _get_circuit_breaker(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
) -> AsyncCircuitBreaker:
|
||||||
|
"""Get or create a circuit breaker for a server."""
|
||||||
|
if server_name not in self._circuit_breakers:
|
||||||
|
self._circuit_breakers[server_name] = AsyncCircuitBreaker(
|
||||||
|
fail_max=config.circuit_breaker_threshold,
|
||||||
|
reset_timeout=config.circuit_breaker_timeout,
|
||||||
|
name=f"mcp-{server_name}",
|
||||||
|
)
|
||||||
|
return self._circuit_breakers[server_name]
|
||||||
|
|
||||||
|
async def register_tool_mapping(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
server_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Register a mapping from tool name to server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool
|
||||||
|
server_name: Name of the server providing the tool
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
self._tool_to_server[tool_name] = server_name
|
||||||
|
logger.debug("Registered tool %s -> server %s", tool_name, server_name)
|
||||||
|
|
||||||
|
async def discover_tools(self) -> None:
|
||||||
|
"""
|
||||||
|
Discover all tools from registered servers and build mappings.
|
||||||
|
"""
|
||||||
|
for server_name in self._registry.list_enabled_servers():
|
||||||
|
try:
|
||||||
|
config = self._registry.get(server_name)
|
||||||
|
connection = await self._pool.get_connection(server_name, config)
|
||||||
|
|
||||||
|
# Fetch tools from server
|
||||||
|
tools = await self._fetch_tools_from_server(connection)
|
||||||
|
|
||||||
|
# Update registry with capabilities
|
||||||
|
self._registry.set_capabilities(
|
||||||
|
server_name,
|
||||||
|
tools=[t.to_dict() for t in tools],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update tool mappings
|
||||||
|
for tool in tools:
|
||||||
|
await self.register_tool_mapping(tool.name, server_name)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Discovered %d tools from server %s",
|
||||||
|
len(tools),
|
||||||
|
server_name,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to discover tools from %s: %s",
|
||||||
|
server_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _fetch_tools_from_server(
|
||||||
|
self,
|
||||||
|
connection: MCPConnection,
|
||||||
|
) -> list[ToolInfo]:
|
||||||
|
"""Fetch available tools from an MCP server."""
|
||||||
|
try:
|
||||||
|
response = await connection.execute_request(
|
||||||
|
"GET",
|
||||||
|
"/mcp/tools",
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = []
|
||||||
|
for tool_data in response.get("tools", []):
|
||||||
|
tools.append(
|
||||||
|
ToolInfo(
|
||||||
|
name=tool_data.get("name", ""),
|
||||||
|
description=tool_data.get("description"),
|
||||||
|
server_name=connection.server_name,
|
||||||
|
input_schema=tool_data.get("inputSchema"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Error fetching tools from %s: %s",
|
||||||
|
connection.server_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def find_server_for_tool(self, tool_name: str) -> str | None:
|
||||||
|
"""
|
||||||
|
Find which server provides a specific tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server name or None if not found
|
||||||
|
"""
|
||||||
|
return self._tool_to_server.get(tool_name)
|
||||||
|
|
||||||
|
async def call_tool(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any] | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> ToolResult:
|
||||||
|
"""
|
||||||
|
Call a tool on a specific server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the MCP server
|
||||||
|
tool_name: Name of the tool to call
|
||||||
|
arguments: Tool arguments
|
||||||
|
timeout: Optional timeout override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool execution result
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Tool call [%s]: %s.%s with args %s",
|
||||||
|
request_id,
|
||||||
|
server_name,
|
||||||
|
tool_name,
|
||||||
|
arguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = self._registry.get(server_name)
|
||||||
|
circuit_breaker = self._get_circuit_breaker(server_name, config)
|
||||||
|
|
||||||
|
# Check circuit breaker state
|
||||||
|
if circuit_breaker.is_open():
|
||||||
|
raise MCPCircuitOpenError(
|
||||||
|
server_name=server_name,
|
||||||
|
failure_count=circuit_breaker.fail_counter,
|
||||||
|
reset_timeout=config.circuit_breaker_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute with retry logic
|
||||||
|
result = await self._execute_with_retry(
|
||||||
|
server_name=server_name,
|
||||||
|
config=config,
|
||||||
|
tool_name=tool_name,
|
||||||
|
arguments=arguments or {},
|
||||||
|
timeout=timeout,
|
||||||
|
circuit_breaker=circuit_breaker,
|
||||||
|
)
|
||||||
|
|
||||||
|
execution_time = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
data=result,
|
||||||
|
tool_name=tool_name,
|
||||||
|
server_name=server_name,
|
||||||
|
execution_time_ms=execution_time,
|
||||||
|
request_id=request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except MCPCircuitOpenError:
|
||||||
|
raise
|
||||||
|
except MCPError as e:
|
||||||
|
execution_time = (time.time() - start_time) * 1000
|
||||||
|
logger.error(
|
||||||
|
"Tool call failed [%s]: %s.%s - %s",
|
||||||
|
request_id,
|
||||||
|
server_name,
|
||||||
|
tool_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
error_code=type(e).__name__,
|
||||||
|
tool_name=tool_name,
|
||||||
|
server_name=server_name,
|
||||||
|
execution_time_ms=execution_time,
|
||||||
|
request_id=request_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
execution_time = (time.time() - start_time) * 1000
|
||||||
|
logger.exception(
|
||||||
|
"Unexpected error in tool call [%s]: %s.%s",
|
||||||
|
request_id,
|
||||||
|
server_name,
|
||||||
|
tool_name,
|
||||||
|
)
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
error=str(e),
|
||||||
|
error_code="UnexpectedError",
|
||||||
|
tool_name=tool_name,
|
||||||
|
server_name=server_name,
|
||||||
|
execution_time_ms=execution_time,
|
||||||
|
request_id=request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _execute_with_retry(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any],
|
||||||
|
timeout: float | None,
|
||||||
|
circuit_breaker: AsyncCircuitBreaker,
|
||||||
|
) -> Any:
|
||||||
|
"""Execute tool call with retry logic."""
|
||||||
|
last_error: Exception | None = None
|
||||||
|
attempts = 0
|
||||||
|
max_attempts = config.retry_attempts + 1 # +1 for initial attempt
|
||||||
|
|
||||||
|
while attempts < max_attempts:
|
||||||
|
attempts += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use circuit breaker to track failures
|
||||||
|
result = await self._execute_tool_call(
|
||||||
|
server_name=server_name,
|
||||||
|
config=config,
|
||||||
|
tool_name=tool_name,
|
||||||
|
arguments=arguments,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Success - record it
|
||||||
|
await circuit_breaker.success()
|
||||||
|
return result
|
||||||
|
|
||||||
|
except MCPCircuitOpenError:
|
||||||
|
raise
|
||||||
|
except MCPTimeoutError:
|
||||||
|
# Timeout - don't retry
|
||||||
|
await circuit_breaker.failure()
|
||||||
|
raise
|
||||||
|
except MCPToolError:
|
||||||
|
# Tool-level error - don't retry (user error)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
await circuit_breaker.failure()
|
||||||
|
|
||||||
|
if attempts < max_attempts:
|
||||||
|
delay = self._calculate_retry_delay(attempts, config)
|
||||||
|
logger.warning(
|
||||||
|
"Tool call attempt %d/%d failed for %s.%s: %s. "
|
||||||
|
"Retrying in %.1fs",
|
||||||
|
attempts,
|
||||||
|
max_attempts,
|
||||||
|
server_name,
|
||||||
|
tool_name,
|
||||||
|
e,
|
||||||
|
delay,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
# All attempts failed
|
||||||
|
raise MCPToolError(
|
||||||
|
f"Tool call failed after {max_attempts} attempts",
|
||||||
|
server_name=server_name,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_args=arguments,
|
||||||
|
details={"last_error": str(last_error)},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _calculate_retry_delay(
|
||||||
|
self,
|
||||||
|
attempt: int,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
) -> float:
|
||||||
|
"""Calculate exponential backoff delay with jitter."""
|
||||||
|
import random
|
||||||
|
|
||||||
|
delay = config.retry_delay * (2 ** (attempt - 1))
|
||||||
|
delay = min(delay, config.retry_max_delay)
|
||||||
|
# Add jitter (±25%)
|
||||||
|
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||||
|
return max(0.1, delay + jitter)
|
||||||
|
|
||||||
|
async def _execute_tool_call(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
config: MCPServerConfig,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any],
|
||||||
|
timeout: float | None,
|
||||||
|
) -> Any:
|
||||||
|
"""Execute a single tool call."""
|
||||||
|
connection = await self._pool.get_connection(server_name, config)
|
||||||
|
|
||||||
|
# Build MCP tool call request
|
||||||
|
request_body = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": tool_name,
|
||||||
|
"arguments": arguments,
|
||||||
|
},
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await connection.execute_request(
|
||||||
|
method="POST",
|
||||||
|
path="/mcp",
|
||||||
|
data=request_body,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle JSON-RPC response
|
||||||
|
if "error" in response:
|
||||||
|
error = response["error"]
|
||||||
|
raise MCPToolError(
|
||||||
|
error.get("message", "Tool execution failed"),
|
||||||
|
server_name=server_name,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_args=arguments,
|
||||||
|
error_code=str(error.get("code", "UNKNOWN")),
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.get("result")
|
||||||
|
|
||||||
|
async def route_tool(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any] | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> ToolResult:
|
||||||
|
"""
|
||||||
|
Route a tool call to the appropriate server.
|
||||||
|
|
||||||
|
Automatically discovers which server provides the tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool to call
|
||||||
|
arguments: Tool arguments
|
||||||
|
timeout: Optional timeout override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool execution result
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPToolNotFoundError: If no server provides the tool
|
||||||
|
"""
|
||||||
|
server_name = self.find_server_for_tool(tool_name)
|
||||||
|
|
||||||
|
if server_name is None:
|
||||||
|
# Try to find from registry
|
||||||
|
server_name = self._registry.find_server_for_tool(tool_name)
|
||||||
|
|
||||||
|
if server_name is None:
|
||||||
|
raise MCPToolNotFoundError(
|
||||||
|
tool_name=tool_name,
|
||||||
|
available_tools=list(self._tool_to_server.keys()),
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self.call_tool(
|
||||||
|
server_name=server_name,
|
||||||
|
tool_name=tool_name,
|
||||||
|
arguments=arguments,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_all_tools(self) -> list[ToolInfo]:
|
||||||
|
"""
|
||||||
|
Get all available tools from all servers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool information
|
||||||
|
"""
|
||||||
|
tools = []
|
||||||
|
all_server_tools = self._registry.get_all_tools()
|
||||||
|
|
||||||
|
for server_name, server_tools in all_server_tools.items():
|
||||||
|
for tool_data in server_tools:
|
||||||
|
tools.append(
|
||||||
|
ToolInfo(
|
||||||
|
name=tool_data.get("name", ""),
|
||||||
|
description=tool_data.get("description"),
|
||||||
|
server_name=server_name,
|
||||||
|
input_schema=tool_data.get("input_schema"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Get status of all circuit breakers."""
|
||||||
|
return {
|
||||||
|
name: {
|
||||||
|
"state": cb.current_state,
|
||||||
|
"failure_count": cb.fail_counter,
|
||||||
|
}
|
||||||
|
for name, cb in self._circuit_breakers.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
async def reset_circuit_breaker(self, server_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Manually reset a circuit breaker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_name: Name of the server
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if circuit breaker was reset
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if server_name in self._circuit_breakers:
|
||||||
|
# Reset by removing (will be recreated on next call)
|
||||||
|
del self._circuit_breakers[server_name]
|
||||||
|
logger.info("Reset circuit breaker for %s", server_name)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
141
backend/app/services/memory/__init__.py
Normal file
141
backend/app/services/memory/__init__.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""
|
||||||
|
Agent Memory System
|
||||||
|
|
||||||
|
Multi-tier cognitive memory for AI agents, providing:
|
||||||
|
- Working Memory: Session-scoped ephemeral state (Redis/In-memory)
|
||||||
|
- Episodic Memory: Experiential records of past tasks (PostgreSQL)
|
||||||
|
- Semantic Memory: Learned facts and knowledge (PostgreSQL + pgvector)
|
||||||
|
- Procedural Memory: Learned skills and procedures (PostgreSQL)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from app.services.memory import (
|
||||||
|
MemoryManager,
|
||||||
|
MemorySettings,
|
||||||
|
get_memory_settings,
|
||||||
|
MemoryType,
|
||||||
|
ScopeLevel,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a manager for a session
|
||||||
|
manager = MemoryManager.for_session(
|
||||||
|
session_id="sess-123",
|
||||||
|
project_id=uuid,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with manager:
|
||||||
|
# Working memory
|
||||||
|
await manager.set_working("key", {"data": "value"})
|
||||||
|
value = await manager.get_working("key")
|
||||||
|
|
||||||
|
# Episodic memory
|
||||||
|
episode = await manager.record_episode(episode_data)
|
||||||
|
similar = await manager.search_episodes("query")
|
||||||
|
|
||||||
|
# Semantic memory
|
||||||
|
fact = await manager.store_fact(fact_data)
|
||||||
|
facts = await manager.search_facts("query")
|
||||||
|
|
||||||
|
# Procedural memory
|
||||||
|
procedure = await manager.record_procedure(procedure_data)
|
||||||
|
procedures = await manager.find_procedures("context")
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
from .config import (
|
||||||
|
MemorySettings,
|
||||||
|
get_default_settings,
|
||||||
|
get_memory_settings,
|
||||||
|
reset_memory_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exceptions
|
||||||
|
from .exceptions import (
|
||||||
|
CheckpointError,
|
||||||
|
EmbeddingError,
|
||||||
|
MemoryCapacityError,
|
||||||
|
MemoryConflictError,
|
||||||
|
MemoryConsolidationError,
|
||||||
|
MemoryError,
|
||||||
|
MemoryExpiredError,
|
||||||
|
MemoryNotFoundError,
|
||||||
|
MemoryRetrievalError,
|
||||||
|
MemoryScopeError,
|
||||||
|
MemorySerializationError,
|
||||||
|
MemoryStorageError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Manager
|
||||||
|
from .manager import MemoryManager
|
||||||
|
|
||||||
|
# Types
|
||||||
|
from .types import (
|
||||||
|
ConsolidationStatus,
|
||||||
|
ConsolidationType,
|
||||||
|
Episode,
|
||||||
|
EpisodeCreate,
|
||||||
|
Fact,
|
||||||
|
FactCreate,
|
||||||
|
MemoryItem,
|
||||||
|
MemoryStats,
|
||||||
|
MemoryStore,
|
||||||
|
MemoryType,
|
||||||
|
Outcome,
|
||||||
|
Procedure,
|
||||||
|
ProcedureCreate,
|
||||||
|
RetrievalResult,
|
||||||
|
ScopeContext,
|
||||||
|
ScopeLevel,
|
||||||
|
Step,
|
||||||
|
TaskState,
|
||||||
|
WorkingMemoryItem,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reflection (lazy import available)
|
||||||
|
# Import directly: from app.services.memory.reflection import MemoryReflection
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CheckpointError",
|
||||||
|
"ConsolidationStatus",
|
||||||
|
"ConsolidationType",
|
||||||
|
"EmbeddingError",
|
||||||
|
"Episode",
|
||||||
|
"EpisodeCreate",
|
||||||
|
"Fact",
|
||||||
|
"FactCreate",
|
||||||
|
"MemoryCapacityError",
|
||||||
|
"MemoryConflictError",
|
||||||
|
"MemoryConsolidationError",
|
||||||
|
# Exceptions
|
||||||
|
"MemoryError",
|
||||||
|
"MemoryExpiredError",
|
||||||
|
"MemoryItem",
|
||||||
|
# Manager
|
||||||
|
"MemoryManager",
|
||||||
|
"MemoryNotFoundError",
|
||||||
|
"MemoryRetrievalError",
|
||||||
|
"MemoryScopeError",
|
||||||
|
"MemorySerializationError",
|
||||||
|
# Configuration
|
||||||
|
"MemorySettings",
|
||||||
|
"MemoryStats",
|
||||||
|
"MemoryStorageError",
|
||||||
|
# Types - Abstract
|
||||||
|
"MemoryStore",
|
||||||
|
# Types - Enums
|
||||||
|
"MemoryType",
|
||||||
|
"Outcome",
|
||||||
|
"Procedure",
|
||||||
|
"ProcedureCreate",
|
||||||
|
"RetrievalResult",
|
||||||
|
# Types - Data Classes
|
||||||
|
"ScopeContext",
|
||||||
|
"ScopeLevel",
|
||||||
|
"Step",
|
||||||
|
"TaskState",
|
||||||
|
"WorkingMemoryItem",
|
||||||
|
"get_default_settings",
|
||||||
|
"get_memory_settings",
|
||||||
|
"reset_memory_settings",
|
||||||
|
# MCP Tools - lazy import to avoid circular dependencies
|
||||||
|
# Import directly: from app.services.memory.mcp import MemoryToolService
|
||||||
|
]
|
||||||
21
backend/app/services/memory/cache/__init__.py
vendored
Normal file
21
backend/app/services/memory/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# app/services/memory/cache/__init__.py
|
||||||
|
"""
|
||||||
|
Memory Caching Layer.
|
||||||
|
|
||||||
|
Provides caching for memory operations:
|
||||||
|
- Hot Memory Cache: LRU cache for frequently accessed memories
|
||||||
|
- Embedding Cache: Cache embeddings by content hash
|
||||||
|
- Cache Manager: Unified cache management with invalidation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .cache_manager import CacheManager, CacheStats, get_cache_manager
|
||||||
|
from .embedding_cache import EmbeddingCache
|
||||||
|
from .hot_cache import HotMemoryCache
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CacheManager",
|
||||||
|
"CacheStats",
|
||||||
|
"EmbeddingCache",
|
||||||
|
"HotMemoryCache",
|
||||||
|
"get_cache_manager",
|
||||||
|
]
|
||||||
505
backend/app/services/memory/cache/cache_manager.py
vendored
Normal file
505
backend/app/services/memory/cache/cache_manager.py
vendored
Normal file
@@ -0,0 +1,505 @@
|
|||||||
|
# app/services/memory/cache/cache_manager.py
|
||||||
|
"""
|
||||||
|
Cache Manager.
|
||||||
|
|
||||||
|
Unified cache management for memory operations.
|
||||||
|
Coordinates hot cache, embedding cache, and retrieval cache.
|
||||||
|
Provides centralized invalidation and statistics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from app.services.memory.config import get_memory_settings
|
||||||
|
|
||||||
|
from .embedding_cache import EmbeddingCache, create_embedding_cache
|
||||||
|
from .hot_cache import CacheKey, HotMemoryCache, create_hot_cache
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from app.services.memory.indexing.retrieval import RetrievalCache
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _utcnow() -> datetime:
|
||||||
|
"""Get current UTC time as timezone-aware datetime."""
|
||||||
|
return datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheStats:
|
||||||
|
"""Aggregated cache statistics."""
|
||||||
|
|
||||||
|
hot_cache: dict[str, Any] = field(default_factory=dict)
|
||||||
|
embedding_cache: dict[str, Any] = field(default_factory=dict)
|
||||||
|
retrieval_cache: dict[str, Any] = field(default_factory=dict)
|
||||||
|
overall_hit_rate: float = 0.0
|
||||||
|
last_cleanup: datetime | None = None
|
||||||
|
cleanup_count: int = 0
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"hot_cache": self.hot_cache,
|
||||||
|
"embedding_cache": self.embedding_cache,
|
||||||
|
"retrieval_cache": self.retrieval_cache,
|
||||||
|
"overall_hit_rate": self.overall_hit_rate,
|
||||||
|
"last_cleanup": self.last_cleanup.isoformat()
|
||||||
|
if self.last_cleanup
|
||||||
|
else None,
|
||||||
|
"cleanup_count": self.cleanup_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CacheManager:
|
||||||
|
"""
|
||||||
|
Unified cache manager for memory operations.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- Centralized cache configuration
|
||||||
|
- Coordinated invalidation across caches
|
||||||
|
- Aggregated statistics
|
||||||
|
- Automatic cleanup scheduling
|
||||||
|
|
||||||
|
Performance targets:
|
||||||
|
- Overall cache hit rate > 80%
|
||||||
|
- Cache operations < 1ms (memory), < 5ms (Redis)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hot_cache: HotMemoryCache[Any] | None = None,
|
||||||
|
embedding_cache: EmbeddingCache | None = None,
|
||||||
|
retrieval_cache: "RetrievalCache | None" = None,
|
||||||
|
redis: "Redis | None" = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the cache manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hot_cache: Optional pre-configured hot cache
|
||||||
|
embedding_cache: Optional pre-configured embedding cache
|
||||||
|
retrieval_cache: Optional pre-configured retrieval cache
|
||||||
|
redis: Optional Redis connection for persistence
|
||||||
|
"""
|
||||||
|
self._settings = get_memory_settings()
|
||||||
|
self._redis = redis
|
||||||
|
self._enabled = self._settings.cache_enabled
|
||||||
|
|
||||||
|
# Initialize caches
|
||||||
|
if hot_cache:
|
||||||
|
self._hot_cache = hot_cache
|
||||||
|
else:
|
||||||
|
self._hot_cache = create_hot_cache(
|
||||||
|
max_size=self._settings.cache_max_items,
|
||||||
|
default_ttl_seconds=self._settings.cache_ttl_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
if embedding_cache:
|
||||||
|
self._embedding_cache = embedding_cache
|
||||||
|
else:
|
||||||
|
self._embedding_cache = create_embedding_cache(
|
||||||
|
max_size=self._settings.cache_max_items,
|
||||||
|
default_ttl_seconds=self._settings.cache_ttl_seconds
|
||||||
|
* 12, # 1hr for embeddings
|
||||||
|
redis=redis,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._retrieval_cache = retrieval_cache
|
||||||
|
|
||||||
|
# Stats tracking
|
||||||
|
self._last_cleanup: datetime | None = None
|
||||||
|
self._cleanup_count = 0
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Initialized CacheManager: enabled={self._enabled}, "
|
||||||
|
f"redis={'connected' if redis else 'disabled'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_redis(self, redis: "Redis") -> None:
|
||||||
|
"""Set Redis connection for all caches."""
|
||||||
|
self._redis = redis
|
||||||
|
self._embedding_cache.set_redis(redis)
|
||||||
|
|
||||||
|
def set_retrieval_cache(self, cache: "RetrievalCache") -> None:
|
||||||
|
"""Set retrieval cache instance."""
|
||||||
|
self._retrieval_cache = cache
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_enabled(self) -> bool:
|
||||||
|
"""Check if caching is enabled."""
|
||||||
|
return self._enabled
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hot_cache(self) -> HotMemoryCache[Any]:
|
||||||
|
"""Get the hot memory cache."""
|
||||||
|
return self._hot_cache
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embedding_cache(self) -> EmbeddingCache:
|
||||||
|
"""Get the embedding cache."""
|
||||||
|
return self._embedding_cache
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retrieval_cache(self) -> "RetrievalCache | None":
|
||||||
|
"""Get the retrieval cache."""
|
||||||
|
return self._retrieval_cache
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Hot Memory Cache Operations
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def get_memory(
|
||||||
|
self,
|
||||||
|
memory_type: str,
|
||||||
|
memory_id: UUID | str,
|
||||||
|
scope: str | None = None,
|
||||||
|
) -> Any | None:
|
||||||
|
"""
|
||||||
|
Get a memory from hot cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_type: Type of memory
|
||||||
|
memory_id: Memory ID
|
||||||
|
scope: Optional scope
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached memory or None
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return None
|
||||||
|
return self._hot_cache.get_by_id(memory_type, memory_id, scope)
|
||||||
|
|
||||||
|
def cache_memory(
|
||||||
|
self,
|
||||||
|
memory_type: str,
|
||||||
|
memory_id: UUID | str,
|
||||||
|
memory: Any,
|
||||||
|
scope: str | None = None,
|
||||||
|
ttl_seconds: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Cache a memory in hot cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_type: Type of memory
|
||||||
|
memory_id: Memory ID
|
||||||
|
memory: Memory object
|
||||||
|
scope: Optional scope
|
||||||
|
ttl_seconds: Optional TTL override
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
self._hot_cache.put_by_id(memory_type, memory_id, memory, scope, ttl_seconds)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Embedding Cache Operations
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def get_embedding(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
model: str = "default",
|
||||||
|
) -> list[float] | None:
|
||||||
|
"""
|
||||||
|
Get a cached embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content text
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached embedding or None
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return None
|
||||||
|
return await self._embedding_cache.get(content, model)
|
||||||
|
|
||||||
|
async def cache_embedding(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
embedding: list[float],
|
||||||
|
model: str = "default",
|
||||||
|
ttl_seconds: float | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Cache an embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content text
|
||||||
|
embedding: Embedding vector
|
||||||
|
model: Model name
|
||||||
|
ttl_seconds: Optional TTL override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Content hash
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return EmbeddingCache.hash_content(content)
|
||||||
|
return await self._embedding_cache.put(content, embedding, model, ttl_seconds)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Invalidation
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def invalidate_memory(
|
||||||
|
self,
|
||||||
|
memory_type: str,
|
||||||
|
memory_id: UUID | str,
|
||||||
|
scope: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate a memory across all caches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_type: Type of memory
|
||||||
|
memory_id: Memory ID
|
||||||
|
scope: Optional scope
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries invalidated
|
||||||
|
"""
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
# Invalidate hot cache
|
||||||
|
if self._hot_cache.invalidate_by_id(memory_type, memory_id, scope):
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
# Invalidate retrieval cache
|
||||||
|
if self._retrieval_cache:
|
||||||
|
uuid_id = (
|
||||||
|
UUID(str(memory_id)) if not isinstance(memory_id, UUID) else memory_id
|
||||||
|
)
|
||||||
|
count += self._retrieval_cache.invalidate_by_memory(uuid_id)
|
||||||
|
|
||||||
|
logger.debug(f"Invalidated {count} cache entries for {memory_type}:{memory_id}")
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def invalidate_by_type(self, memory_type: str) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate all entries of a memory type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_type: Type of memory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries invalidated
|
||||||
|
"""
|
||||||
|
count = self._hot_cache.invalidate_by_type(memory_type)
|
||||||
|
|
||||||
|
if self._retrieval_cache:
|
||||||
|
count += self._retrieval_cache.clear()
|
||||||
|
|
||||||
|
logger.info(f"Invalidated {count} cache entries for type {memory_type}")
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def invalidate_by_scope(self, scope: str) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate all entries in a scope.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope: Scope to invalidate (e.g., project_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries invalidated
|
||||||
|
"""
|
||||||
|
count = self._hot_cache.invalidate_by_scope(scope)
|
||||||
|
|
||||||
|
# Retrieval cache doesn't support scope-based invalidation
|
||||||
|
# so we clear it entirely for safety
|
||||||
|
if self._retrieval_cache:
|
||||||
|
count += self._retrieval_cache.clear()
|
||||||
|
|
||||||
|
logger.info(f"Invalidated {count} cache entries for scope {scope}")
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def invalidate_embedding(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
model: str = "default",
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Invalidate a cached embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content text
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if entry was found and removed
|
||||||
|
"""
|
||||||
|
return await self._embedding_cache.invalidate(content, model)
|
||||||
|
|
||||||
|
async def clear_all(self) -> int:
|
||||||
|
"""
|
||||||
|
Clear all caches.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total number of entries cleared
|
||||||
|
"""
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
count += self._hot_cache.clear()
|
||||||
|
count += await self._embedding_cache.clear()
|
||||||
|
|
||||||
|
if self._retrieval_cache:
|
||||||
|
count += self._retrieval_cache.clear()
|
||||||
|
|
||||||
|
logger.info(f"Cleared {count} entries from all caches")
|
||||||
|
return count
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Cleanup
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def cleanup_expired(self) -> int:
|
||||||
|
"""
|
||||||
|
Clean up expired entries from all caches.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries cleaned up
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
count += self._hot_cache.cleanup_expired()
|
||||||
|
count += self._embedding_cache.cleanup_expired()
|
||||||
|
|
||||||
|
# Retrieval cache doesn't have a cleanup method,
|
||||||
|
# but entries expire on access
|
||||||
|
|
||||||
|
self._last_cleanup = _utcnow()
|
||||||
|
self._cleanup_count += 1
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
logger.info(f"Cleaned up {count} expired cache entries")
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Statistics
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def get_stats(self) -> CacheStats:
|
||||||
|
"""
|
||||||
|
Get aggregated cache statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CacheStats with all cache metrics
|
||||||
|
"""
|
||||||
|
hot_stats = self._hot_cache.get_stats().to_dict()
|
||||||
|
emb_stats = self._embedding_cache.get_stats().to_dict()
|
||||||
|
|
||||||
|
retrieval_stats: dict[str, Any] = {}
|
||||||
|
if self._retrieval_cache:
|
||||||
|
retrieval_stats = self._retrieval_cache.get_stats()
|
||||||
|
|
||||||
|
# Calculate overall hit rate
|
||||||
|
total_hits = hot_stats.get("hits", 0) + emb_stats.get("hits", 0)
|
||||||
|
total_misses = hot_stats.get("misses", 0) + emb_stats.get("misses", 0)
|
||||||
|
|
||||||
|
if retrieval_stats:
|
||||||
|
# Retrieval cache doesn't track hits/misses the same way
|
||||||
|
pass
|
||||||
|
|
||||||
|
total_requests = total_hits + total_misses
|
||||||
|
overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0
|
||||||
|
|
||||||
|
return CacheStats(
|
||||||
|
hot_cache=hot_stats,
|
||||||
|
embedding_cache=emb_stats,
|
||||||
|
retrieval_cache=retrieval_stats,
|
||||||
|
overall_hit_rate=overall_hit_rate,
|
||||||
|
last_cleanup=self._last_cleanup,
|
||||||
|
cleanup_count=self._cleanup_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
|
||||||
|
"""
|
||||||
|
Get the most frequently accessed memories.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit: Maximum number to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (key, access_count) tuples
|
||||||
|
"""
|
||||||
|
return self._hot_cache.get_hot_memories(limit)
|
||||||
|
|
||||||
|
def reset_stats(self) -> None:
|
||||||
|
"""Reset all cache statistics."""
|
||||||
|
self._hot_cache.reset_stats()
|
||||||
|
self._embedding_cache.reset_stats()
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Warmup
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def warmup(
|
||||||
|
self,
|
||||||
|
memories: list[tuple[str, UUID | str, Any]],
|
||||||
|
scope: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Warm up the hot cache with memories.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memories: List of (memory_type, memory_id, memory) tuples
|
||||||
|
scope: Optional scope for all memories
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of memories cached
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
for memory_type, memory_id, memory in memories:
|
||||||
|
self._hot_cache.put_by_id(memory_type, memory_id, memory, scope)
|
||||||
|
|
||||||
|
logger.info(f"Warmed up cache with {len(memories)} memories")
|
||||||
|
return len(memories)
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
_cache_manager: CacheManager | None = None
|
||||||
|
_cache_manager_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_manager(
|
||||||
|
redis: "Redis | None" = None,
|
||||||
|
reset: bool = False,
|
||||||
|
) -> CacheManager:
|
||||||
|
"""
|
||||||
|
Get the global CacheManager instance.
|
||||||
|
|
||||||
|
Thread-safe with double-checked locking pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis: Optional Redis connection
|
||||||
|
reset: Force create a new instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CacheManager instance
|
||||||
|
"""
|
||||||
|
global _cache_manager
|
||||||
|
|
||||||
|
if reset or _cache_manager is None:
|
||||||
|
with _cache_manager_lock:
|
||||||
|
if reset or _cache_manager is None:
|
||||||
|
_cache_manager = CacheManager(redis=redis)
|
||||||
|
|
||||||
|
return _cache_manager
|
||||||
|
|
||||||
|
|
||||||
|
def reset_cache_manager() -> None:
|
||||||
|
"""Reset the global cache manager instance."""
|
||||||
|
global _cache_manager
|
||||||
|
with _cache_manager_lock:
|
||||||
|
_cache_manager = None
|
||||||
623
backend/app/services/memory/cache/embedding_cache.py
vendored
Normal file
623
backend/app/services/memory/cache/embedding_cache.py
vendored
Normal file
@@ -0,0 +1,623 @@
|
|||||||
|
# app/services/memory/cache/embedding_cache.py
|
||||||
|
"""
|
||||||
|
Embedding Cache.
|
||||||
|
|
||||||
|
Caches embeddings by content hash to avoid recomputing.
|
||||||
|
Provides significant performance improvement for repeated content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _utcnow() -> datetime:
|
||||||
|
"""Get current UTC time as timezone-aware datetime."""
|
||||||
|
return datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmbeddingEntry:
|
||||||
|
"""A cached embedding entry."""
|
||||||
|
|
||||||
|
embedding: list[float]
|
||||||
|
content_hash: str
|
||||||
|
model: str
|
||||||
|
created_at: datetime
|
||||||
|
ttl_seconds: float = 3600.0 # 1 hour default
|
||||||
|
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""Check if this entry has expired."""
|
||||||
|
age = (_utcnow() - self.created_at).total_seconds()
|
||||||
|
return age > self.ttl_seconds
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmbeddingCacheStats:
|
||||||
|
"""Statistics for the embedding cache."""
|
||||||
|
|
||||||
|
hits: int = 0
|
||||||
|
misses: int = 0
|
||||||
|
evictions: int = 0
|
||||||
|
expirations: int = 0
|
||||||
|
current_size: int = 0
|
||||||
|
max_size: int = 0
|
||||||
|
bytes_saved: int = 0 # Estimated bytes saved by caching
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hit_rate(self) -> float:
|
||||||
|
"""Calculate cache hit rate."""
|
||||||
|
total = self.hits + self.misses
|
||||||
|
if total == 0:
|
||||||
|
return 0.0
|
||||||
|
return self.hits / total
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"hits": self.hits,
|
||||||
|
"misses": self.misses,
|
||||||
|
"evictions": self.evictions,
|
||||||
|
"expirations": self.expirations,
|
||||||
|
"current_size": self.current_size,
|
||||||
|
"max_size": self.max_size,
|
||||||
|
"hit_rate": self.hit_rate,
|
||||||
|
"bytes_saved": self.bytes_saved,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingCache:
|
||||||
|
"""
|
||||||
|
Cache for embeddings by content hash.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Content-hash based deduplication
|
||||||
|
- LRU eviction
|
||||||
|
- TTL-based expiration
|
||||||
|
- Optional Redis backing for persistence
|
||||||
|
- Thread-safe operations
|
||||||
|
|
||||||
|
Performance targets:
|
||||||
|
- Cache hit rate > 90% for repeated content
|
||||||
|
- Get/put operations < 1ms (memory), < 5ms (Redis)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_size: int = 50000,
|
||||||
|
default_ttl_seconds: float = 3600.0,
|
||||||
|
redis: "Redis | None" = None,
|
||||||
|
redis_prefix: str = "mem:emb",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the embedding cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_size: Maximum number of entries in memory cache
|
||||||
|
default_ttl_seconds: Default TTL for entries (1 hour)
|
||||||
|
redis: Optional Redis connection for persistence
|
||||||
|
redis_prefix: Prefix for Redis keys
|
||||||
|
"""
|
||||||
|
self._max_size = max_size
|
||||||
|
self._default_ttl = default_ttl_seconds
|
||||||
|
self._cache: OrderedDict[str, EmbeddingEntry] = OrderedDict()
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
self._stats = EmbeddingCacheStats(max_size=max_size)
|
||||||
|
self._redis = redis
|
||||||
|
self._redis_prefix = redis_prefix
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Initialized EmbeddingCache with max_size={max_size}, "
|
||||||
|
f"ttl={default_ttl_seconds}s, redis={'enabled' if redis else 'disabled'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_redis(self, redis: "Redis") -> None:
|
||||||
|
"""Set Redis connection for persistence."""
|
||||||
|
self._redis = redis
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hash_content(content: str) -> str:
|
||||||
|
"""
|
||||||
|
Compute hash of content for cache key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content to hash
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
32-character hex hash
|
||||||
|
"""
|
||||||
|
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||||
|
|
||||||
|
def _cache_key(self, content_hash: str, model: str) -> str:
|
||||||
|
"""Build cache key from content hash and model."""
|
||||||
|
return f"{content_hash}:{model}"
|
||||||
|
|
||||||
|
def _redis_key(self, content_hash: str, model: str) -> str:
|
||||||
|
"""Build Redis key from content hash and model."""
|
||||||
|
return f"{self._redis_prefix}:{content_hash}:{model}"
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
model: str = "default",
|
||||||
|
) -> list[float] | None:
|
||||||
|
"""
|
||||||
|
Get a cached embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content text
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached embedding or None if not found/expired
|
||||||
|
"""
|
||||||
|
content_hash = self.hash_content(content)
|
||||||
|
cache_key = self._cache_key(content_hash, model)
|
||||||
|
|
||||||
|
# Check memory cache first
|
||||||
|
with self._lock:
|
||||||
|
if cache_key in self._cache:
|
||||||
|
entry = self._cache[cache_key]
|
||||||
|
if entry.is_expired():
|
||||||
|
del self._cache[cache_key]
|
||||||
|
self._stats.expirations += 1
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
else:
|
||||||
|
# Move to end (most recently used)
|
||||||
|
self._cache.move_to_end(cache_key)
|
||||||
|
self._stats.hits += 1
|
||||||
|
return entry.embedding
|
||||||
|
|
||||||
|
# Check Redis if available
|
||||||
|
if self._redis:
|
||||||
|
try:
|
||||||
|
redis_key = self._redis_key(content_hash, model)
|
||||||
|
data = await self._redis.get(redis_key)
|
||||||
|
if data:
|
||||||
|
import json
|
||||||
|
|
||||||
|
embedding = json.loads(data)
|
||||||
|
# Store in memory cache for faster access
|
||||||
|
self._put_memory(content_hash, model, embedding)
|
||||||
|
self._stats.hits += 1
|
||||||
|
return embedding
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Redis get error: {e}")
|
||||||
|
|
||||||
|
self._stats.misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_by_hash(
|
||||||
|
self,
|
||||||
|
content_hash: str,
|
||||||
|
model: str = "default",
|
||||||
|
) -> list[float] | None:
|
||||||
|
"""
|
||||||
|
Get a cached embedding by hash.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content_hash: Content hash
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached embedding or None if not found/expired
|
||||||
|
"""
|
||||||
|
cache_key = self._cache_key(content_hash, model)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if cache_key in self._cache:
|
||||||
|
entry = self._cache[cache_key]
|
||||||
|
if entry.is_expired():
|
||||||
|
del self._cache[cache_key]
|
||||||
|
self._stats.expirations += 1
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
else:
|
||||||
|
self._cache.move_to_end(cache_key)
|
||||||
|
self._stats.hits += 1
|
||||||
|
return entry.embedding
|
||||||
|
|
||||||
|
# Check Redis
|
||||||
|
if self._redis:
|
||||||
|
try:
|
||||||
|
redis_key = self._redis_key(content_hash, model)
|
||||||
|
data = await self._redis.get(redis_key)
|
||||||
|
if data:
|
||||||
|
import json
|
||||||
|
|
||||||
|
embedding = json.loads(data)
|
||||||
|
self._put_memory(content_hash, model, embedding)
|
||||||
|
self._stats.hits += 1
|
||||||
|
return embedding
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Redis get error: {e}")
|
||||||
|
|
||||||
|
self._stats.misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def put(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
embedding: list[float],
|
||||||
|
model: str = "default",
|
||||||
|
ttl_seconds: float | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Cache an embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content text
|
||||||
|
embedding: Embedding vector
|
||||||
|
model: Model name
|
||||||
|
ttl_seconds: Optional TTL override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Content hash
|
||||||
|
"""
|
||||||
|
content_hash = self.hash_content(content)
|
||||||
|
ttl = ttl_seconds or self._default_ttl
|
||||||
|
|
||||||
|
# Store in memory
|
||||||
|
self._put_memory(content_hash, model, embedding, ttl)
|
||||||
|
|
||||||
|
# Store in Redis if available
|
||||||
|
if self._redis:
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
|
||||||
|
redis_key = self._redis_key(content_hash, model)
|
||||||
|
await self._redis.setex(
|
||||||
|
redis_key,
|
||||||
|
int(ttl),
|
||||||
|
json.dumps(embedding),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Redis put error: {e}")
|
||||||
|
|
||||||
|
return content_hash
|
||||||
|
|
||||||
|
def _put_memory(
|
||||||
|
self,
|
||||||
|
content_hash: str,
|
||||||
|
model: str,
|
||||||
|
embedding: list[float],
|
||||||
|
ttl_seconds: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Store in memory cache."""
|
||||||
|
with self._lock:
|
||||||
|
# Evict if at capacity
|
||||||
|
self._evict_if_needed()
|
||||||
|
|
||||||
|
cache_key = self._cache_key(content_hash, model)
|
||||||
|
entry = EmbeddingEntry(
|
||||||
|
embedding=embedding,
|
||||||
|
content_hash=content_hash,
|
||||||
|
model=model,
|
||||||
|
created_at=_utcnow(),
|
||||||
|
ttl_seconds=ttl_seconds or self._default_ttl,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cache[cache_key] = entry
|
||||||
|
self._cache.move_to_end(cache_key)
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
|
||||||
|
def _evict_if_needed(self) -> None:
|
||||||
|
"""Evict entries if cache is at capacity."""
|
||||||
|
while len(self._cache) >= self._max_size:
|
||||||
|
if self._cache:
|
||||||
|
self._cache.popitem(last=False)
|
||||||
|
self._stats.evictions += 1
|
||||||
|
|
||||||
|
async def put_batch(
|
||||||
|
self,
|
||||||
|
items: list[tuple[str, list[float]]],
|
||||||
|
model: str = "default",
|
||||||
|
ttl_seconds: float | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""
|
||||||
|
Cache multiple embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
items: List of (content, embedding) tuples
|
||||||
|
model: Model name
|
||||||
|
ttl_seconds: Optional TTL override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of content hashes
|
||||||
|
"""
|
||||||
|
hashes = []
|
||||||
|
for content, embedding in items:
|
||||||
|
content_hash = await self.put(content, embedding, model, ttl_seconds)
|
||||||
|
hashes.append(content_hash)
|
||||||
|
return hashes
|
||||||
|
|
||||||
|
async def invalidate(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
model: str = "default",
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Invalidate a cached embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content text
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if entry was found and removed
|
||||||
|
"""
|
||||||
|
content_hash = self.hash_content(content)
|
||||||
|
return await self.invalidate_by_hash(content_hash, model)
|
||||||
|
|
||||||
|
async def invalidate_by_hash(
|
||||||
|
self,
|
||||||
|
content_hash: str,
|
||||||
|
model: str = "default",
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Invalidate a cached embedding by hash.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content_hash: Content hash
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if entry was found and removed
|
||||||
|
"""
|
||||||
|
cache_key = self._cache_key(content_hash, model)
|
||||||
|
removed = False
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if cache_key in self._cache:
|
||||||
|
del self._cache[cache_key]
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
removed = True
|
||||||
|
|
||||||
|
# Remove from Redis
|
||||||
|
if self._redis:
|
||||||
|
try:
|
||||||
|
redis_key = self._redis_key(content_hash, model)
|
||||||
|
await self._redis.delete(redis_key)
|
||||||
|
removed = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Redis delete error: {e}")
|
||||||
|
|
||||||
|
return removed
|
||||||
|
|
||||||
|
async def invalidate_by_model(self, model: str) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate all embeddings for a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries invalidated
|
||||||
|
"""
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
keys_to_remove = [k for k, v in self._cache.items() if v.model == model]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
del self._cache[key]
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
|
||||||
|
# Note: Redis pattern deletion would require SCAN which is expensive
|
||||||
|
# For now, we only clear memory cache for model-based invalidation
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def clear(self) -> int:
|
||||||
|
"""
|
||||||
|
Clear all cache entries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries cleared
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
count = len(self._cache)
|
||||||
|
self._cache.clear()
|
||||||
|
self._stats.current_size = 0
|
||||||
|
|
||||||
|
# Clear Redis entries
|
||||||
|
if self._redis:
|
||||||
|
try:
|
||||||
|
pattern = f"{self._redis_prefix}:*"
|
||||||
|
deleted = 0
|
||||||
|
async for key in self._redis.scan_iter(match=pattern):
|
||||||
|
await self._redis.delete(key)
|
||||||
|
deleted += 1
|
||||||
|
count = max(count, deleted)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Redis clear error: {e}")
|
||||||
|
|
||||||
|
logger.info(f"Cleared {count} entries from embedding cache")
|
||||||
|
return count
|
||||||
|
|
||||||
|
def cleanup_expired(self) -> int:
|
||||||
|
"""
|
||||||
|
Remove all expired entries from memory cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries removed
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
keys_to_remove = [k for k, v in self._cache.items() if v.is_expired()]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
del self._cache[key]
|
||||||
|
self._stats.expirations += 1
|
||||||
|
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
|
||||||
|
if keys_to_remove:
|
||||||
|
logger.debug(f"Cleaned up {len(keys_to_remove)} expired embeddings")
|
||||||
|
|
||||||
|
return len(keys_to_remove)
|
||||||
|
|
||||||
|
def get_stats(self) -> EmbeddingCacheStats:
|
||||||
|
"""Get cache statistics."""
|
||||||
|
with self._lock:
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
return self._stats
|
||||||
|
|
||||||
|
def reset_stats(self) -> None:
|
||||||
|
"""Reset cache statistics."""
|
||||||
|
with self._lock:
|
||||||
|
self._stats = EmbeddingCacheStats(
|
||||||
|
max_size=self._max_size,
|
||||||
|
current_size=len(self._cache),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self) -> int:
|
||||||
|
"""Get current cache size."""
|
||||||
|
return len(self._cache)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_size(self) -> int:
|
||||||
|
"""Get maximum cache size."""
|
||||||
|
return self._max_size
|
||||||
|
|
||||||
|
|
||||||
|
class CachedEmbeddingGenerator:
|
||||||
|
"""
|
||||||
|
Wrapper for embedding generators with caching.
|
||||||
|
|
||||||
|
Wraps an embedding generator to cache results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
generator: Any,
|
||||||
|
cache: EmbeddingCache,
|
||||||
|
model: str = "default",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the cached embedding generator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generator: Underlying embedding generator
|
||||||
|
cache: Embedding cache
|
||||||
|
model: Model name for cache keys
|
||||||
|
"""
|
||||||
|
self._generator = generator
|
||||||
|
self._cache = cache
|
||||||
|
self._model = model
|
||||||
|
self._call_count = 0
|
||||||
|
self._cache_hit_count = 0
|
||||||
|
|
||||||
|
async def generate(self, text: str) -> list[float]:
|
||||||
|
"""
|
||||||
|
Generate embedding with caching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to embed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embedding vector
|
||||||
|
"""
|
||||||
|
self._call_count += 1
|
||||||
|
|
||||||
|
# Check cache first
|
||||||
|
cached = await self._cache.get(text, self._model)
|
||||||
|
if cached is not None:
|
||||||
|
self._cache_hit_count += 1
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# Generate and cache
|
||||||
|
embedding = await self._generator.generate(text)
|
||||||
|
await self._cache.put(text, embedding, self._model)
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
async def generate_batch(
|
||||||
|
self,
|
||||||
|
texts: list[str],
|
||||||
|
) -> list[list[float]]:
|
||||||
|
"""
|
||||||
|
Generate embeddings for multiple texts with caching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Texts to embed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embedding vectors
|
||||||
|
"""
|
||||||
|
results: list[list[float] | None] = [None] * len(texts)
|
||||||
|
to_generate: list[tuple[int, str]] = []
|
||||||
|
|
||||||
|
# Check cache for each text
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
cached = await self._cache.get(text, self._model)
|
||||||
|
if cached is not None:
|
||||||
|
results[i] = cached
|
||||||
|
self._cache_hit_count += 1
|
||||||
|
else:
|
||||||
|
to_generate.append((i, text))
|
||||||
|
|
||||||
|
self._call_count += len(texts)
|
||||||
|
|
||||||
|
# Generate missing embeddings
|
||||||
|
if to_generate:
|
||||||
|
if hasattr(self._generator, "generate_batch"):
|
||||||
|
texts_to_gen = [t for _, t in to_generate]
|
||||||
|
embeddings = await self._generator.generate_batch(texts_to_gen)
|
||||||
|
|
||||||
|
for (idx, text), embedding in zip(to_generate, embeddings, strict=True):
|
||||||
|
results[idx] = embedding
|
||||||
|
await self._cache.put(text, embedding, self._model)
|
||||||
|
else:
|
||||||
|
# Fallback to individual generation
|
||||||
|
for idx, text in to_generate:
|
||||||
|
embedding = await self._generator.generate(text)
|
||||||
|
results[idx] = embedding
|
||||||
|
await self._cache.put(text, embedding, self._model)
|
||||||
|
|
||||||
|
return results # type: ignore[return-value]
|
||||||
|
|
||||||
|
def get_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get generator statistics."""
|
||||||
|
return {
|
||||||
|
"call_count": self._call_count,
|
||||||
|
"cache_hit_count": self._cache_hit_count,
|
||||||
|
"cache_hit_rate": (
|
||||||
|
self._cache_hit_count / self._call_count
|
||||||
|
if self._call_count > 0
|
||||||
|
else 0.0
|
||||||
|
),
|
||||||
|
"cache_stats": self._cache.get_stats().to_dict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Factory function
|
||||||
|
def create_embedding_cache(
|
||||||
|
max_size: int = 50000,
|
||||||
|
default_ttl_seconds: float = 3600.0,
|
||||||
|
redis: "Redis | None" = None,
|
||||||
|
) -> EmbeddingCache:
|
||||||
|
"""
|
||||||
|
Create an embedding cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_size: Maximum number of entries
|
||||||
|
default_ttl_seconds: Default TTL for entries
|
||||||
|
redis: Optional Redis connection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured EmbeddingCache instance
|
||||||
|
"""
|
||||||
|
return EmbeddingCache(
|
||||||
|
max_size=max_size,
|
||||||
|
default_ttl_seconds=default_ttl_seconds,
|
||||||
|
redis=redis,
|
||||||
|
)
|
||||||
461
backend/app/services/memory/cache/hot_cache.py
vendored
Normal file
461
backend/app/services/memory/cache/hot_cache.py
vendored
Normal file
@@ -0,0 +1,461 @@
|
|||||||
|
# app/services/memory/cache/hot_cache.py
|
||||||
|
"""
|
||||||
|
Hot Memory Cache.
|
||||||
|
|
||||||
|
LRU cache for frequently accessed memories.
|
||||||
|
Provides fast access to recently used memories without database queries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _utcnow() -> datetime:
|
||||||
|
"""Get current UTC time as timezone-aware datetime."""
|
||||||
|
return datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheEntry[T]:
|
||||||
|
"""A cached memory entry with metadata."""
|
||||||
|
|
||||||
|
value: T
|
||||||
|
created_at: datetime
|
||||||
|
last_accessed_at: datetime
|
||||||
|
access_count: int = 1
|
||||||
|
ttl_seconds: float = 300.0
|
||||||
|
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""Check if this entry has expired."""
|
||||||
|
age = (_utcnow() - self.created_at).total_seconds()
|
||||||
|
return age > self.ttl_seconds
|
||||||
|
|
||||||
|
def touch(self) -> None:
|
||||||
|
"""Update access time and count."""
|
||||||
|
self.last_accessed_at = _utcnow()
|
||||||
|
self.access_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheKey:
|
||||||
|
"""A structured cache key with components."""
|
||||||
|
|
||||||
|
memory_type: str
|
||||||
|
memory_id: str
|
||||||
|
scope: str | None = None
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash((self.memory_type, self.memory_id, self.scope))
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if not isinstance(other, CacheKey):
|
||||||
|
return False
|
||||||
|
return (
|
||||||
|
self.memory_type == other.memory_type
|
||||||
|
and self.memory_id == other.memory_id
|
||||||
|
and self.scope == other.scope
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
if self.scope:
|
||||||
|
return f"{self.memory_type}:{self.scope}:{self.memory_id}"
|
||||||
|
return f"{self.memory_type}:{self.memory_id}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HotCacheStats:
|
||||||
|
"""Statistics for the hot memory cache."""
|
||||||
|
|
||||||
|
hits: int = 0
|
||||||
|
misses: int = 0
|
||||||
|
evictions: int = 0
|
||||||
|
expirations: int = 0
|
||||||
|
current_size: int = 0
|
||||||
|
max_size: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hit_rate(self) -> float:
|
||||||
|
"""Calculate cache hit rate."""
|
||||||
|
total = self.hits + self.misses
|
||||||
|
if total == 0:
|
||||||
|
return 0.0
|
||||||
|
return self.hits / total
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"hits": self.hits,
|
||||||
|
"misses": self.misses,
|
||||||
|
"evictions": self.evictions,
|
||||||
|
"expirations": self.expirations,
|
||||||
|
"current_size": self.current_size,
|
||||||
|
"max_size": self.max_size,
|
||||||
|
"hit_rate": self.hit_rate,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class HotMemoryCache[T]:
|
||||||
|
"""
|
||||||
|
LRU cache for frequently accessed memories.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- LRU eviction when capacity is reached
|
||||||
|
- TTL-based expiration
|
||||||
|
- Access count tracking for hot memory identification
|
||||||
|
- Thread-safe operations
|
||||||
|
- Scoped invalidation
|
||||||
|
|
||||||
|
Performance targets:
|
||||||
|
- Cache hit rate > 80% for hot memories
|
||||||
|
- Get/put operations < 1ms
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_size: int = 10000,
|
||||||
|
default_ttl_seconds: float = 300.0,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the hot memory cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_size: Maximum number of entries
|
||||||
|
default_ttl_seconds: Default TTL for entries (5 minutes)
|
||||||
|
"""
|
||||||
|
self._max_size = max_size
|
||||||
|
self._default_ttl = default_ttl_seconds
|
||||||
|
self._cache: OrderedDict[CacheKey, CacheEntry[T]] = OrderedDict()
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
self._stats = HotCacheStats(max_size=max_size)
|
||||||
|
logger.info(
|
||||||
|
f"Initialized HotMemoryCache with max_size={max_size}, "
|
||||||
|
f"ttl={default_ttl_seconds}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get(self, key: CacheKey) -> T | None:
|
||||||
|
"""
|
||||||
|
Get a memory from cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached value or None if not found/expired
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if key not in self._cache:
|
||||||
|
self._stats.misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
entry = self._cache[key]
|
||||||
|
|
||||||
|
# Check expiration
|
||||||
|
if entry.is_expired():
|
||||||
|
del self._cache[key]
|
||||||
|
self._stats.expirations += 1
|
||||||
|
self._stats.misses += 1
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Move to end (most recently used)
|
||||||
|
self._cache.move_to_end(key)
|
||||||
|
entry.touch()
|
||||||
|
|
||||||
|
self._stats.hits += 1
|
||||||
|
return entry.value
|
||||||
|
|
||||||
|
def get_by_id(
|
||||||
|
self,
|
||||||
|
memory_type: str,
|
||||||
|
memory_id: UUID | str,
|
||||||
|
scope: str | None = None,
|
||||||
|
) -> T | None:
|
||||||
|
"""
|
||||||
|
Get a memory by type and ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_type: Type of memory (episodic, semantic, procedural)
|
||||||
|
memory_id: Memory ID
|
||||||
|
scope: Optional scope (project_id, agent_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached value or None if not found/expired
|
||||||
|
"""
|
||||||
|
key = CacheKey(
|
||||||
|
memory_type=memory_type,
|
||||||
|
memory_id=str(memory_id),
|
||||||
|
scope=scope,
|
||||||
|
)
|
||||||
|
return self.get(key)
|
||||||
|
|
||||||
|
def put(
|
||||||
|
self,
|
||||||
|
key: CacheKey,
|
||||||
|
value: T,
|
||||||
|
ttl_seconds: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Put a memory into cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
value: Value to cache
|
||||||
|
ttl_seconds: Optional TTL override
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
# Evict if at capacity
|
||||||
|
self._evict_if_needed()
|
||||||
|
|
||||||
|
now = _utcnow()
|
||||||
|
entry = CacheEntry(
|
||||||
|
value=value,
|
||||||
|
created_at=now,
|
||||||
|
last_accessed_at=now,
|
||||||
|
access_count=1,
|
||||||
|
ttl_seconds=ttl_seconds or self._default_ttl,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cache[key] = entry
|
||||||
|
self._cache.move_to_end(key)
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
|
||||||
|
def put_by_id(
|
||||||
|
self,
|
||||||
|
memory_type: str,
|
||||||
|
memory_id: UUID | str,
|
||||||
|
value: T,
|
||||||
|
scope: str | None = None,
|
||||||
|
ttl_seconds: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Put a memory by type and ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_type: Type of memory
|
||||||
|
memory_id: Memory ID
|
||||||
|
value: Value to cache
|
||||||
|
scope: Optional scope
|
||||||
|
ttl_seconds: Optional TTL override
|
||||||
|
"""
|
||||||
|
key = CacheKey(
|
||||||
|
memory_type=memory_type,
|
||||||
|
memory_id=str(memory_id),
|
||||||
|
scope=scope,
|
||||||
|
)
|
||||||
|
self.put(key, value, ttl_seconds)
|
||||||
|
|
||||||
|
def _evict_if_needed(self) -> None:
|
||||||
|
"""Evict entries if cache is at capacity."""
|
||||||
|
while len(self._cache) >= self._max_size:
|
||||||
|
# Remove least recently used (first item)
|
||||||
|
if self._cache:
|
||||||
|
self._cache.popitem(last=False)
|
||||||
|
self._stats.evictions += 1
|
||||||
|
|
||||||
|
def invalidate(self, key: CacheKey) -> bool:
|
||||||
|
"""
|
||||||
|
Invalidate a specific cache entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key to invalidate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if entry was found and removed
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if key in self._cache:
|
||||||
|
del self._cache[key]
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def invalidate_by_id(
|
||||||
|
self,
|
||||||
|
memory_type: str,
|
||||||
|
memory_id: UUID | str,
|
||||||
|
scope: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Invalidate a memory by type and ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_type: Type of memory
|
||||||
|
memory_id: Memory ID
|
||||||
|
scope: Optional scope
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if entry was found and removed
|
||||||
|
"""
|
||||||
|
key = CacheKey(
|
||||||
|
memory_type=memory_type,
|
||||||
|
memory_id=str(memory_id),
|
||||||
|
scope=scope,
|
||||||
|
)
|
||||||
|
return self.invalidate(key)
|
||||||
|
|
||||||
|
def invalidate_by_type(self, memory_type: str) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate all entries of a memory type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_type: Type of memory to invalidate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries invalidated
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
keys_to_remove = [
|
||||||
|
k for k in self._cache.keys() if k.memory_type == memory_type
|
||||||
|
]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
del self._cache[key]
|
||||||
|
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
return len(keys_to_remove)
|
||||||
|
|
||||||
|
def invalidate_by_scope(self, scope: str) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate all entries in a scope.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope: Scope to invalidate (e.g., project_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries invalidated
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
keys_to_remove = [k for k in self._cache.keys() if k.scope == scope]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
del self._cache[key]
|
||||||
|
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
return len(keys_to_remove)
|
||||||
|
|
||||||
|
def invalidate_pattern(self, pattern: str) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate entries matching a pattern.
|
||||||
|
|
||||||
|
Pattern can include * as wildcard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: Pattern to match (e.g., "episodic:*")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries invalidated
|
||||||
|
"""
|
||||||
|
import fnmatch
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
keys_to_remove = [
|
||||||
|
k for k in self._cache.keys() if fnmatch.fnmatch(str(k), pattern)
|
||||||
|
]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
del self._cache[key]
|
||||||
|
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
return len(keys_to_remove)
|
||||||
|
|
||||||
|
def clear(self) -> int:
|
||||||
|
"""
|
||||||
|
Clear all cache entries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries cleared
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
count = len(self._cache)
|
||||||
|
self._cache.clear()
|
||||||
|
self._stats.current_size = 0
|
||||||
|
logger.info(f"Cleared {count} entries from hot cache")
|
||||||
|
return count
|
||||||
|
|
||||||
|
def cleanup_expired(self) -> int:
|
||||||
|
"""
|
||||||
|
Remove all expired entries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries removed
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
keys_to_remove = [k for k, v in self._cache.items() if v.is_expired()]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
del self._cache[key]
|
||||||
|
self._stats.expirations += 1
|
||||||
|
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
|
||||||
|
if keys_to_remove:
|
||||||
|
logger.debug(f"Cleaned up {len(keys_to_remove)} expired entries")
|
||||||
|
|
||||||
|
return len(keys_to_remove)
|
||||||
|
|
||||||
|
def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
|
||||||
|
"""
|
||||||
|
Get the most frequently accessed memories.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit: Maximum number of memories to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (key, access_count) tuples sorted by access count
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
entries = [
|
||||||
|
(k, v.access_count)
|
||||||
|
for k, v in self._cache.items()
|
||||||
|
if not v.is_expired()
|
||||||
|
]
|
||||||
|
entries.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
return entries[:limit]
|
||||||
|
|
||||||
|
def get_stats(self) -> HotCacheStats:
|
||||||
|
"""Get cache statistics."""
|
||||||
|
with self._lock:
|
||||||
|
self._stats.current_size = len(self._cache)
|
||||||
|
return self._stats
|
||||||
|
|
||||||
|
def reset_stats(self) -> None:
|
||||||
|
"""Reset cache statistics."""
|
||||||
|
with self._lock:
|
||||||
|
self._stats = HotCacheStats(
|
||||||
|
max_size=self._max_size,
|
||||||
|
current_size=len(self._cache),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self) -> int:
|
||||||
|
"""Get current cache size."""
|
||||||
|
return len(self._cache)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_size(self) -> int:
|
||||||
|
"""Get maximum cache size."""
|
||||||
|
return self._max_size
|
||||||
|
|
||||||
|
|
||||||
|
# Factory function for typed caches
|
||||||
|
def create_hot_cache(
|
||||||
|
max_size: int = 10000,
|
||||||
|
default_ttl_seconds: float = 300.0,
|
||||||
|
) -> HotMemoryCache[Any]:
|
||||||
|
"""
|
||||||
|
Create a hot memory cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_size: Maximum number of entries
|
||||||
|
default_ttl_seconds: Default TTL for entries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured HotMemoryCache instance
|
||||||
|
"""
|
||||||
|
return HotMemoryCache(
|
||||||
|
max_size=max_size,
|
||||||
|
default_ttl_seconds=default_ttl_seconds,
|
||||||
|
)
|
||||||
410
backend/app/services/memory/config.py
Normal file
410
backend/app/services/memory/config.py
Normal file
@@ -0,0 +1,410 @@
|
|||||||
|
"""
|
||||||
|
Memory System Configuration.
|
||||||
|
|
||||||
|
Provides Pydantic settings for the Agent Memory System,
|
||||||
|
including storage backends, capacity limits, and consolidation policies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import Field, field_validator, model_validator
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class MemorySettings(BaseSettings):
|
||||||
|
"""
|
||||||
|
Configuration for the Agent Memory System.
|
||||||
|
|
||||||
|
All settings can be overridden via environment variables
|
||||||
|
with the MEM_ prefix.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Working Memory Settings
|
||||||
|
working_memory_backend: str = Field(
|
||||||
|
default="redis",
|
||||||
|
description="Backend for working memory: 'redis' or 'memory'",
|
||||||
|
)
|
||||||
|
working_memory_default_ttl_seconds: int = Field(
|
||||||
|
default=3600,
|
||||||
|
ge=60,
|
||||||
|
le=86400,
|
||||||
|
description="Default TTL for working memory items (1 hour default)",
|
||||||
|
)
|
||||||
|
working_memory_max_items_per_session: int = Field(
|
||||||
|
default=1000,
|
||||||
|
ge=100,
|
||||||
|
le=100000,
|
||||||
|
description="Maximum items per session in working memory",
|
||||||
|
)
|
||||||
|
working_memory_max_value_size_bytes: int = Field(
|
||||||
|
default=1048576, # 1MB
|
||||||
|
ge=1024,
|
||||||
|
le=104857600, # 100MB
|
||||||
|
description="Maximum size of a single value in working memory",
|
||||||
|
)
|
||||||
|
working_memory_checkpoint_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable checkpointing for working memory recovery",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redis Settings (for working memory)
|
||||||
|
redis_url: str = Field(
|
||||||
|
default="redis://localhost:6379/0",
|
||||||
|
description="Redis connection URL",
|
||||||
|
)
|
||||||
|
redis_prefix: str = Field(
|
||||||
|
default="mem",
|
||||||
|
description="Redis key prefix for memory items",
|
||||||
|
)
|
||||||
|
redis_connection_timeout_seconds: int = Field(
|
||||||
|
default=5,
|
||||||
|
ge=1,
|
||||||
|
le=60,
|
||||||
|
description="Redis connection timeout",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Episodic Memory Settings
|
||||||
|
episodic_max_episodes_per_project: int = Field(
|
||||||
|
default=10000,
|
||||||
|
ge=100,
|
||||||
|
le=1000000,
|
||||||
|
description="Maximum episodes to retain per project",
|
||||||
|
)
|
||||||
|
episodic_default_importance: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Default importance score for new episodes",
|
||||||
|
)
|
||||||
|
episodic_retention_days: int = Field(
|
||||||
|
default=365,
|
||||||
|
ge=7,
|
||||||
|
le=3650,
|
||||||
|
description="Days to retain episodes before archival",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Semantic Memory Settings
|
||||||
|
semantic_max_facts_per_project: int = Field(
|
||||||
|
default=50000,
|
||||||
|
ge=1000,
|
||||||
|
le=10000000,
|
||||||
|
description="Maximum facts to retain per project",
|
||||||
|
)
|
||||||
|
semantic_confidence_decay_days: int = Field(
|
||||||
|
default=90,
|
||||||
|
ge=7,
|
||||||
|
le=365,
|
||||||
|
description="Days until confidence decays by 50%",
|
||||||
|
)
|
||||||
|
semantic_min_confidence: float = Field(
|
||||||
|
default=0.1,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Minimum confidence before fact is pruned",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Procedural Memory Settings
|
||||||
|
procedural_max_procedures_per_project: int = Field(
|
||||||
|
default=1000,
|
||||||
|
ge=10,
|
||||||
|
le=100000,
|
||||||
|
description="Maximum procedures per project",
|
||||||
|
)
|
||||||
|
procedural_min_success_rate: float = Field(
|
||||||
|
default=0.3,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Minimum success rate before procedure is pruned",
|
||||||
|
)
|
||||||
|
procedural_min_uses_before_suggest: int = Field(
|
||||||
|
default=3,
|
||||||
|
ge=1,
|
||||||
|
le=100,
|
||||||
|
description="Minimum uses before procedure is suggested",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embedding Settings
|
||||||
|
embedding_model: str = Field(
|
||||||
|
default="text-embedding-3-small",
|
||||||
|
description="Model to use for embeddings",
|
||||||
|
)
|
||||||
|
embedding_dimensions: int = Field(
|
||||||
|
default=1536,
|
||||||
|
ge=256,
|
||||||
|
le=4096,
|
||||||
|
description="Embedding vector dimensions",
|
||||||
|
)
|
||||||
|
embedding_batch_size: int = Field(
|
||||||
|
default=100,
|
||||||
|
ge=1,
|
||||||
|
le=1000,
|
||||||
|
description="Batch size for embedding generation",
|
||||||
|
)
|
||||||
|
embedding_cache_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable caching of embeddings",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retrieval Settings
|
||||||
|
retrieval_default_limit: int = Field(
|
||||||
|
default=10,
|
||||||
|
ge=1,
|
||||||
|
le=100,
|
||||||
|
description="Default limit for retrieval queries",
|
||||||
|
)
|
||||||
|
retrieval_max_limit: int = Field(
|
||||||
|
default=100,
|
||||||
|
ge=10,
|
||||||
|
le=1000,
|
||||||
|
description="Maximum limit for retrieval queries",
|
||||||
|
)
|
||||||
|
retrieval_min_similarity: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Minimum similarity score for retrieval",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Consolidation Settings
|
||||||
|
consolidation_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable automatic memory consolidation",
|
||||||
|
)
|
||||||
|
consolidation_batch_size: int = Field(
|
||||||
|
default=100,
|
||||||
|
ge=10,
|
||||||
|
le=1000,
|
||||||
|
description="Batch size for consolidation jobs",
|
||||||
|
)
|
||||||
|
consolidation_schedule_cron: str = Field(
|
||||||
|
default="0 3 * * *",
|
||||||
|
description="Cron expression for nightly consolidation (3 AM)",
|
||||||
|
)
|
||||||
|
consolidation_working_to_episodic_delay_minutes: int = Field(
|
||||||
|
default=30,
|
||||||
|
ge=5,
|
||||||
|
le=1440,
|
||||||
|
description="Minutes after session end before consolidating to episodic",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pruning Settings
|
||||||
|
pruning_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable automatic memory pruning",
|
||||||
|
)
|
||||||
|
pruning_min_age_days: int = Field(
|
||||||
|
default=7,
|
||||||
|
ge=1,
|
||||||
|
le=365,
|
||||||
|
description="Minimum age before memory can be pruned",
|
||||||
|
)
|
||||||
|
pruning_importance_threshold: float = Field(
|
||||||
|
default=0.2,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Importance threshold below which memory can be pruned",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Caching Settings
|
||||||
|
cache_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable caching for memory retrieval",
|
||||||
|
)
|
||||||
|
cache_ttl_seconds: int = Field(
|
||||||
|
default=300,
|
||||||
|
ge=10,
|
||||||
|
le=3600,
|
||||||
|
description="Cache TTL for retrieval results",
|
||||||
|
)
|
||||||
|
cache_max_items: int = Field(
|
||||||
|
default=10000,
|
||||||
|
ge=100,
|
||||||
|
le=1000000,
|
||||||
|
description="Maximum items in memory cache",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance Settings
|
||||||
|
max_retrieval_time_ms: int = Field(
|
||||||
|
default=100,
|
||||||
|
ge=10,
|
||||||
|
le=5000,
|
||||||
|
description="Target maximum retrieval time in milliseconds",
|
||||||
|
)
|
||||||
|
parallel_retrieval: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable parallel retrieval from multiple memory types",
|
||||||
|
)
|
||||||
|
max_parallel_retrievals: int = Field(
|
||||||
|
default=4,
|
||||||
|
ge=1,
|
||||||
|
le=10,
|
||||||
|
description="Maximum concurrent retrieval operations",
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("working_memory_backend")
|
||||||
|
@classmethod
|
||||||
|
def validate_backend(cls, v: str) -> str:
|
||||||
|
"""Validate working memory backend."""
|
||||||
|
valid_backends = {"redis", "memory"}
|
||||||
|
if v not in valid_backends:
|
||||||
|
raise ValueError(f"backend must be one of: {valid_backends}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("embedding_model")
|
||||||
|
@classmethod
|
||||||
|
def validate_embedding_model(cls, v: str) -> str:
|
||||||
|
"""Validate embedding model name."""
|
||||||
|
valid_models = {
|
||||||
|
"text-embedding-3-small",
|
||||||
|
"text-embedding-3-large",
|
||||||
|
"text-embedding-ada-002",
|
||||||
|
}
|
||||||
|
if v not in valid_models:
|
||||||
|
raise ValueError(f"embedding_model must be one of: {valid_models}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_limits(self) -> "MemorySettings":
|
||||||
|
"""Validate that limits are consistent."""
|
||||||
|
if self.retrieval_default_limit > self.retrieval_max_limit:
|
||||||
|
raise ValueError(
|
||||||
|
f"retrieval_default_limit ({self.retrieval_default_limit}) "
|
||||||
|
f"cannot exceed retrieval_max_limit ({self.retrieval_max_limit})"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_working_memory_config(self) -> dict[str, Any]:
|
||||||
|
"""Get working memory configuration as a dictionary."""
|
||||||
|
return {
|
||||||
|
"backend": self.working_memory_backend,
|
||||||
|
"default_ttl_seconds": self.working_memory_default_ttl_seconds,
|
||||||
|
"max_items_per_session": self.working_memory_max_items_per_session,
|
||||||
|
"max_value_size_bytes": self.working_memory_max_value_size_bytes,
|
||||||
|
"checkpoint_enabled": self.working_memory_checkpoint_enabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_redis_config(self) -> dict[str, Any]:
|
||||||
|
"""Get Redis configuration as a dictionary."""
|
||||||
|
return {
|
||||||
|
"url": self.redis_url,
|
||||||
|
"prefix": self.redis_prefix,
|
||||||
|
"connection_timeout_seconds": self.redis_connection_timeout_seconds,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_embedding_config(self) -> dict[str, Any]:
|
||||||
|
"""Get embedding configuration as a dictionary."""
|
||||||
|
return {
|
||||||
|
"model": self.embedding_model,
|
||||||
|
"dimensions": self.embedding_dimensions,
|
||||||
|
"batch_size": self.embedding_batch_size,
|
||||||
|
"cache_enabled": self.embedding_cache_enabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_consolidation_config(self) -> dict[str, Any]:
|
||||||
|
"""Get consolidation configuration as a dictionary."""
|
||||||
|
return {
|
||||||
|
"enabled": self.consolidation_enabled,
|
||||||
|
"batch_size": self.consolidation_batch_size,
|
||||||
|
"schedule_cron": self.consolidation_schedule_cron,
|
||||||
|
"working_to_episodic_delay_minutes": (
|
||||||
|
self.consolidation_working_to_episodic_delay_minutes
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert settings to dictionary for logging/debugging."""
|
||||||
|
return {
|
||||||
|
"working_memory": self.get_working_memory_config(),
|
||||||
|
"redis": self.get_redis_config(),
|
||||||
|
"episodic": {
|
||||||
|
"max_episodes_per_project": self.episodic_max_episodes_per_project,
|
||||||
|
"default_importance": self.episodic_default_importance,
|
||||||
|
"retention_days": self.episodic_retention_days,
|
||||||
|
},
|
||||||
|
"semantic": {
|
||||||
|
"max_facts_per_project": self.semantic_max_facts_per_project,
|
||||||
|
"confidence_decay_days": self.semantic_confidence_decay_days,
|
||||||
|
"min_confidence": self.semantic_min_confidence,
|
||||||
|
},
|
||||||
|
"procedural": {
|
||||||
|
"max_procedures_per_project": self.procedural_max_procedures_per_project,
|
||||||
|
"min_success_rate": self.procedural_min_success_rate,
|
||||||
|
"min_uses_before_suggest": self.procedural_min_uses_before_suggest,
|
||||||
|
},
|
||||||
|
"embedding": self.get_embedding_config(),
|
||||||
|
"retrieval": {
|
||||||
|
"default_limit": self.retrieval_default_limit,
|
||||||
|
"max_limit": self.retrieval_max_limit,
|
||||||
|
"min_similarity": self.retrieval_min_similarity,
|
||||||
|
},
|
||||||
|
"consolidation": self.get_consolidation_config(),
|
||||||
|
"pruning": {
|
||||||
|
"enabled": self.pruning_enabled,
|
||||||
|
"min_age_days": self.pruning_min_age_days,
|
||||||
|
"importance_threshold": self.pruning_importance_threshold,
|
||||||
|
},
|
||||||
|
"cache": {
|
||||||
|
"enabled": self.cache_enabled,
|
||||||
|
"ttl_seconds": self.cache_ttl_seconds,
|
||||||
|
"max_items": self.cache_max_items,
|
||||||
|
},
|
||||||
|
"performance": {
|
||||||
|
"max_retrieval_time_ms": self.max_retrieval_time_ms,
|
||||||
|
"parallel_retrieval": self.parallel_retrieval,
|
||||||
|
"max_parallel_retrievals": self.max_parallel_retrievals,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"env_prefix": "MEM_",
|
||||||
|
"env_file": ".env",
|
||||||
|
"env_file_encoding": "utf-8",
|
||||||
|
"case_sensitive": False,
|
||||||
|
"extra": "ignore",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Thread-safe singleton pattern
|
||||||
|
_settings: MemorySettings | None = None
|
||||||
|
_settings_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_memory_settings() -> MemorySettings:
|
||||||
|
"""
|
||||||
|
Get the global MemorySettings instance.
|
||||||
|
|
||||||
|
Thread-safe with double-checked locking pattern.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MemorySettings instance
|
||||||
|
"""
|
||||||
|
global _settings
|
||||||
|
if _settings is None:
|
||||||
|
with _settings_lock:
|
||||||
|
if _settings is None:
|
||||||
|
_settings = MemorySettings()
|
||||||
|
return _settings
|
||||||
|
|
||||||
|
|
||||||
|
def reset_memory_settings() -> None:
|
||||||
|
"""
|
||||||
|
Reset the global settings instance.
|
||||||
|
|
||||||
|
Primarily used for testing.
|
||||||
|
"""
|
||||||
|
global _settings
|
||||||
|
with _settings_lock:
|
||||||
|
_settings = None
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_default_settings() -> MemorySettings:
|
||||||
|
"""
|
||||||
|
Get default settings (cached).
|
||||||
|
|
||||||
|
Use this for read-only access to defaults.
|
||||||
|
For mutable access, use get_memory_settings().
|
||||||
|
"""
|
||||||
|
return MemorySettings()
|
||||||
29
backend/app/services/memory/consolidation/__init__.py
Normal file
29
backend/app/services/memory/consolidation/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# app/services/memory/consolidation/__init__.py
|
||||||
|
"""
|
||||||
|
Memory Consolidation.
|
||||||
|
|
||||||
|
Transfers and extracts knowledge between memory tiers:
|
||||||
|
- Working -> Episodic (session end)
|
||||||
|
- Episodic -> Semantic (learn facts)
|
||||||
|
- Episodic -> Procedural (learn procedures)
|
||||||
|
|
||||||
|
Also handles memory pruning and importance-based retention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .service import (
|
||||||
|
ConsolidationConfig,
|
||||||
|
ConsolidationResult,
|
||||||
|
MemoryConsolidationService,
|
||||||
|
NightlyConsolidationResult,
|
||||||
|
SessionConsolidationResult,
|
||||||
|
get_consolidation_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ConsolidationConfig",
|
||||||
|
"ConsolidationResult",
|
||||||
|
"MemoryConsolidationService",
|
||||||
|
"NightlyConsolidationResult",
|
||||||
|
"SessionConsolidationResult",
|
||||||
|
"get_consolidation_service",
|
||||||
|
]
|
||||||
913
backend/app/services/memory/consolidation/service.py
Normal file
913
backend/app/services/memory/consolidation/service.py
Normal file
@@ -0,0 +1,913 @@
|
|||||||
|
# app/services/memory/consolidation/service.py
|
||||||
|
"""
|
||||||
|
Memory Consolidation Service.
|
||||||
|
|
||||||
|
Transfers and extracts knowledge between memory tiers:
|
||||||
|
- Working -> Episodic (session end)
|
||||||
|
- Episodic -> Semantic (learn facts)
|
||||||
|
- Episodic -> Procedural (learn procedures)
|
||||||
|
|
||||||
|
Also handles memory pruning and importance-based retention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.services.memory.episodic.memory import EpisodicMemory
|
||||||
|
from app.services.memory.procedural.memory import ProceduralMemory
|
||||||
|
from app.services.memory.semantic.extraction import FactExtractor, get_fact_extractor
|
||||||
|
from app.services.memory.semantic.memory import SemanticMemory
|
||||||
|
from app.services.memory.types import (
|
||||||
|
Episode,
|
||||||
|
EpisodeCreate,
|
||||||
|
Outcome,
|
||||||
|
ProcedureCreate,
|
||||||
|
TaskState,
|
||||||
|
)
|
||||||
|
from app.services.memory.working.memory import WorkingMemory
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConsolidationConfig:
|
||||||
|
"""Configuration for memory consolidation."""
|
||||||
|
|
||||||
|
# Working -> Episodic thresholds
|
||||||
|
min_steps_for_episode: int = 2
|
||||||
|
min_duration_seconds: float = 5.0
|
||||||
|
|
||||||
|
# Episodic -> Semantic thresholds
|
||||||
|
min_confidence_for_fact: float = 0.6
|
||||||
|
max_facts_per_episode: int = 10
|
||||||
|
reinforce_existing_facts: bool = True
|
||||||
|
|
||||||
|
# Episodic -> Procedural thresholds
|
||||||
|
min_episodes_for_procedure: int = 3
|
||||||
|
min_success_rate_for_procedure: float = 0.7
|
||||||
|
min_steps_for_procedure: int = 2
|
||||||
|
|
||||||
|
# Pruning thresholds
|
||||||
|
max_episode_age_days: int = 90
|
||||||
|
min_importance_to_keep: float = 0.2
|
||||||
|
keep_all_failures: bool = True
|
||||||
|
keep_all_with_lessons: bool = True
|
||||||
|
|
||||||
|
# Batch sizes
|
||||||
|
batch_size: int = 100
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConsolidationResult:
|
||||||
|
"""Result of a consolidation operation."""
|
||||||
|
|
||||||
|
source_type: str
|
||||||
|
target_type: str
|
||||||
|
items_processed: int = 0
|
||||||
|
items_created: int = 0
|
||||||
|
items_updated: int = 0
|
||||||
|
items_skipped: int = 0
|
||||||
|
items_pruned: int = 0
|
||||||
|
errors: list[str] = field(default_factory=list)
|
||||||
|
duration_seconds: float = 0.0
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"source_type": self.source_type,
|
||||||
|
"target_type": self.target_type,
|
||||||
|
"items_processed": self.items_processed,
|
||||||
|
"items_created": self.items_created,
|
||||||
|
"items_updated": self.items_updated,
|
||||||
|
"items_skipped": self.items_skipped,
|
||||||
|
"items_pruned": self.items_pruned,
|
||||||
|
"errors": self.errors,
|
||||||
|
"duration_seconds": self.duration_seconds,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SessionConsolidationResult:
|
||||||
|
"""Result of consolidating a session's working memory to episodic."""
|
||||||
|
|
||||||
|
session_id: str
|
||||||
|
episode_created: bool = False
|
||||||
|
episode_id: UUID | None = None
|
||||||
|
scratchpad_entries: int = 0
|
||||||
|
variables_captured: int = 0
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NightlyConsolidationResult:
|
||||||
|
"""Result of nightly consolidation run."""
|
||||||
|
|
||||||
|
started_at: datetime
|
||||||
|
completed_at: datetime | None = None
|
||||||
|
episodic_to_semantic: ConsolidationResult | None = None
|
||||||
|
episodic_to_procedural: ConsolidationResult | None = None
|
||||||
|
pruning: ConsolidationResult | None = None
|
||||||
|
total_episodes_processed: int = 0
|
||||||
|
total_facts_created: int = 0
|
||||||
|
total_procedures_created: int = 0
|
||||||
|
total_pruned: int = 0
|
||||||
|
errors: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"started_at": self.started_at.isoformat(),
|
||||||
|
"completed_at": self.completed_at.isoformat()
|
||||||
|
if self.completed_at
|
||||||
|
else None,
|
||||||
|
"episodic_to_semantic": (
|
||||||
|
self.episodic_to_semantic.to_dict()
|
||||||
|
if self.episodic_to_semantic
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"episodic_to_procedural": (
|
||||||
|
self.episodic_to_procedural.to_dict()
|
||||||
|
if self.episodic_to_procedural
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"pruning": self.pruning.to_dict() if self.pruning else None,
|
||||||
|
"total_episodes_processed": self.total_episodes_processed,
|
||||||
|
"total_facts_created": self.total_facts_created,
|
||||||
|
"total_procedures_created": self.total_procedures_created,
|
||||||
|
"total_pruned": self.total_pruned,
|
||||||
|
"errors": self.errors,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryConsolidationService:
|
||||||
|
"""
|
||||||
|
Service for consolidating memories between tiers.
|
||||||
|
|
||||||
|
Responsibilities:
|
||||||
|
- Transfer working memory to episodic at session end
|
||||||
|
- Extract facts from episodes to semantic memory
|
||||||
|
- Learn procedures from successful episode patterns
|
||||||
|
- Prune old/low-value memories
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
config: ConsolidationConfig | None = None,
|
||||||
|
embedding_generator: Any | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize consolidation service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
config: Consolidation configuration
|
||||||
|
embedding_generator: Optional embedding generator
|
||||||
|
"""
|
||||||
|
self._session = session
|
||||||
|
self._config = config or ConsolidationConfig()
|
||||||
|
self._embedding_generator = embedding_generator
|
||||||
|
self._fact_extractor: FactExtractor = get_fact_extractor()
|
||||||
|
|
||||||
|
# Memory services (lazy initialized)
|
||||||
|
self._episodic: EpisodicMemory | None = None
|
||||||
|
self._semantic: SemanticMemory | None = None
|
||||||
|
self._procedural: ProceduralMemory | None = None
|
||||||
|
|
||||||
|
async def _get_episodic(self) -> EpisodicMemory:
|
||||||
|
"""Get or create episodic memory service."""
|
||||||
|
if self._episodic is None:
|
||||||
|
self._episodic = await EpisodicMemory.create(
|
||||||
|
self._session, self._embedding_generator
|
||||||
|
)
|
||||||
|
return self._episodic
|
||||||
|
|
||||||
|
async def _get_semantic(self) -> SemanticMemory:
|
||||||
|
"""Get or create semantic memory service."""
|
||||||
|
if self._semantic is None:
|
||||||
|
self._semantic = await SemanticMemory.create(
|
||||||
|
self._session, self._embedding_generator
|
||||||
|
)
|
||||||
|
return self._semantic
|
||||||
|
|
||||||
|
async def _get_procedural(self) -> ProceduralMemory:
|
||||||
|
"""Get or create procedural memory service."""
|
||||||
|
if self._procedural is None:
|
||||||
|
self._procedural = await ProceduralMemory.create(
|
||||||
|
self._session, self._embedding_generator
|
||||||
|
)
|
||||||
|
return self._procedural
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Working -> Episodic Consolidation
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def consolidate_session(
|
||||||
|
self,
|
||||||
|
working_memory: WorkingMemory,
|
||||||
|
project_id: UUID,
|
||||||
|
session_id: str,
|
||||||
|
task_type: str = "session_task",
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
agent_type_id: UUID | None = None,
|
||||||
|
) -> SessionConsolidationResult:
|
||||||
|
"""
|
||||||
|
Consolidate a session's working memory to episodic memory.
|
||||||
|
|
||||||
|
Called at session end to transfer relevant session data
|
||||||
|
into a persistent episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
working_memory: The session's working memory
|
||||||
|
project_id: Project ID
|
||||||
|
session_id: Session ID
|
||||||
|
task_type: Type of task performed
|
||||||
|
agent_instance_id: Optional agent instance
|
||||||
|
agent_type_id: Optional agent type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SessionConsolidationResult with outcome details
|
||||||
|
"""
|
||||||
|
result = SessionConsolidationResult(session_id=session_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get task state
|
||||||
|
task_state = await working_memory.get_task_state()
|
||||||
|
|
||||||
|
# Check if there's enough content to consolidate
|
||||||
|
if not self._should_consolidate_session(task_state):
|
||||||
|
logger.debug(
|
||||||
|
f"Skipping consolidation for session {session_id}: insufficient content"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Gather scratchpad entries
|
||||||
|
scratchpad = await working_memory.get_scratchpad()
|
||||||
|
result.scratchpad_entries = len(scratchpad)
|
||||||
|
|
||||||
|
# Gather user variables
|
||||||
|
all_data = await working_memory.get_all()
|
||||||
|
result.variables_captured = len(all_data)
|
||||||
|
|
||||||
|
# Determine outcome
|
||||||
|
outcome = self._determine_session_outcome(task_state)
|
||||||
|
|
||||||
|
# Build actions from scratchpad and variables
|
||||||
|
actions = self._build_actions_from_session(scratchpad, all_data, task_state)
|
||||||
|
|
||||||
|
# Build context summary
|
||||||
|
context_summary = self._build_context_summary(task_state, all_data)
|
||||||
|
|
||||||
|
# Extract lessons learned
|
||||||
|
lessons = self._extract_session_lessons(task_state, outcome)
|
||||||
|
|
||||||
|
# Calculate importance
|
||||||
|
importance = self._calculate_session_importance(
|
||||||
|
task_state, outcome, actions
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create episode
|
||||||
|
episode_data = EpisodeCreate(
|
||||||
|
project_id=project_id,
|
||||||
|
session_id=session_id,
|
||||||
|
task_type=task_type,
|
||||||
|
task_description=task_state.description
|
||||||
|
if task_state
|
||||||
|
else "Session task",
|
||||||
|
actions=actions,
|
||||||
|
context_summary=context_summary,
|
||||||
|
outcome=outcome,
|
||||||
|
outcome_details=task_state.status if task_state else "",
|
||||||
|
duration_seconds=self._calculate_duration(task_state),
|
||||||
|
tokens_used=0, # Would need to track this in working memory
|
||||||
|
lessons_learned=lessons,
|
||||||
|
importance_score=importance,
|
||||||
|
agent_instance_id=agent_instance_id,
|
||||||
|
agent_type_id=agent_type_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
episodic = await self._get_episodic()
|
||||||
|
episode = await episodic.record_episode(episode_data)
|
||||||
|
|
||||||
|
result.episode_created = True
|
||||||
|
result.episode_id = episode.id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Consolidated session {session_id} to episode {episode.id} "
|
||||||
|
f"({len(actions)} actions, outcome={outcome.value})"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result.error = str(e)
|
||||||
|
logger.exception(f"Failed to consolidate session {session_id}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _should_consolidate_session(self, task_state: TaskState | None) -> bool:
|
||||||
|
"""Check if session has enough content to consolidate."""
|
||||||
|
if task_state is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check minimum steps
|
||||||
|
if task_state.current_step < self._config.min_steps_for_episode:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _determine_session_outcome(self, task_state: TaskState | None) -> Outcome:
|
||||||
|
"""Determine outcome from task state."""
|
||||||
|
if task_state is None:
|
||||||
|
return Outcome.PARTIAL
|
||||||
|
|
||||||
|
status = task_state.status.lower() if task_state.status else ""
|
||||||
|
progress = task_state.progress_percent
|
||||||
|
|
||||||
|
if "success" in status or "complete" in status or progress >= 100:
|
||||||
|
return Outcome.SUCCESS
|
||||||
|
if "fail" in status or "error" in status:
|
||||||
|
return Outcome.FAILURE
|
||||||
|
if progress >= 50:
|
||||||
|
return Outcome.PARTIAL
|
||||||
|
return Outcome.FAILURE
|
||||||
|
|
||||||
|
def _build_actions_from_session(
|
||||||
|
self,
|
||||||
|
scratchpad: list[str],
|
||||||
|
variables: dict[str, Any],
|
||||||
|
task_state: TaskState | None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Build action list from session data."""
|
||||||
|
actions: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
# Add scratchpad entries as actions
|
||||||
|
for i, entry in enumerate(scratchpad):
|
||||||
|
actions.append(
|
||||||
|
{
|
||||||
|
"step": i + 1,
|
||||||
|
"type": "reasoning",
|
||||||
|
"content": entry[:500], # Truncate long entries
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add final state
|
||||||
|
if task_state:
|
||||||
|
actions.append(
|
||||||
|
{
|
||||||
|
"step": len(scratchpad) + 1,
|
||||||
|
"type": "final_state",
|
||||||
|
"current_step": task_state.current_step,
|
||||||
|
"total_steps": task_state.total_steps,
|
||||||
|
"progress": task_state.progress_percent,
|
||||||
|
"status": task_state.status,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return actions
|
||||||
|
|
||||||
|
def _build_context_summary(
|
||||||
|
self,
|
||||||
|
task_state: TaskState | None,
|
||||||
|
variables: dict[str, Any],
|
||||||
|
) -> str:
|
||||||
|
"""Build context summary from session data."""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
if task_state:
|
||||||
|
parts.append(f"Task: {task_state.description}")
|
||||||
|
parts.append(f"Progress: {task_state.progress_percent:.1f}%")
|
||||||
|
parts.append(f"Steps: {task_state.current_step}/{task_state.total_steps}")
|
||||||
|
|
||||||
|
# Include key variables
|
||||||
|
key_vars = {k: v for k, v in variables.items() if len(str(v)) < 100}
|
||||||
|
if key_vars:
|
||||||
|
var_str = ", ".join(f"{k}={v}" for k, v in list(key_vars.items())[:5])
|
||||||
|
parts.append(f"Variables: {var_str}")
|
||||||
|
|
||||||
|
return "; ".join(parts) if parts else "Session completed"
|
||||||
|
|
||||||
|
def _extract_session_lessons(
|
||||||
|
self,
|
||||||
|
task_state: TaskState | None,
|
||||||
|
outcome: Outcome,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Extract lessons from session."""
|
||||||
|
lessons: list[str] = []
|
||||||
|
|
||||||
|
if task_state and task_state.status:
|
||||||
|
if outcome == Outcome.FAILURE:
|
||||||
|
lessons.append(
|
||||||
|
f"Task failed at step {task_state.current_step}: {task_state.status}"
|
||||||
|
)
|
||||||
|
elif outcome == Outcome.SUCCESS:
|
||||||
|
lessons.append(
|
||||||
|
f"Successfully completed in {task_state.current_step} steps"
|
||||||
|
)
|
||||||
|
|
||||||
|
return lessons
|
||||||
|
|
||||||
|
def _calculate_session_importance(
|
||||||
|
self,
|
||||||
|
task_state: TaskState | None,
|
||||||
|
outcome: Outcome,
|
||||||
|
actions: list[dict[str, Any]],
|
||||||
|
) -> float:
|
||||||
|
"""Calculate importance score for session."""
|
||||||
|
score = 0.5 # Base score
|
||||||
|
|
||||||
|
# Failures are important to learn from
|
||||||
|
if outcome == Outcome.FAILURE:
|
||||||
|
score += 0.3
|
||||||
|
|
||||||
|
# Many steps means complex task
|
||||||
|
if task_state and task_state.total_steps >= 5:
|
||||||
|
score += 0.1
|
||||||
|
|
||||||
|
# Many actions means detailed reasoning
|
||||||
|
if len(actions) >= 5:
|
||||||
|
score += 0.1
|
||||||
|
|
||||||
|
return min(1.0, score)
|
||||||
|
|
||||||
|
def _calculate_duration(self, task_state: TaskState | None) -> float:
|
||||||
|
"""Calculate session duration."""
|
||||||
|
if task_state is None:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
if task_state.started_at and task_state.updated_at:
|
||||||
|
delta = task_state.updated_at - task_state.started_at
|
||||||
|
return delta.total_seconds()
|
||||||
|
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Episodic -> Semantic Consolidation
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def consolidate_episodes_to_facts(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
since: datetime | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> ConsolidationResult:
|
||||||
|
"""
|
||||||
|
Extract facts from episodic memories to semantic memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to consolidate
|
||||||
|
since: Only process episodes since this time
|
||||||
|
limit: Maximum episodes to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConsolidationResult with extraction statistics
|
||||||
|
"""
|
||||||
|
start_time = datetime.now(UTC)
|
||||||
|
result = ConsolidationResult(
|
||||||
|
source_type="episodic",
|
||||||
|
target_type="semantic",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
episodic = await self._get_episodic()
|
||||||
|
semantic = await self._get_semantic()
|
||||||
|
|
||||||
|
# Get episodes to process
|
||||||
|
since_time = since or datetime.now(UTC) - timedelta(days=1)
|
||||||
|
episodes = await episodic.get_recent(
|
||||||
|
project_id,
|
||||||
|
limit=limit or self._config.batch_size,
|
||||||
|
since=since_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
for episode in episodes:
|
||||||
|
result.items_processed += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Extract facts using the extractor
|
||||||
|
extracted_facts = self._fact_extractor.extract_from_episode(episode)
|
||||||
|
|
||||||
|
for extracted_fact in extracted_facts:
|
||||||
|
if (
|
||||||
|
extracted_fact.confidence
|
||||||
|
< self._config.min_confidence_for_fact
|
||||||
|
):
|
||||||
|
result.items_skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create fact (store_fact handles deduplication/reinforcement)
|
||||||
|
fact_create = extracted_fact.to_fact_create(
|
||||||
|
project_id=project_id,
|
||||||
|
source_episode_ids=[episode.id],
|
||||||
|
)
|
||||||
|
|
||||||
|
# store_fact automatically reinforces if fact already exists
|
||||||
|
fact = await semantic.store_fact(fact_create)
|
||||||
|
|
||||||
|
# Check if this was a new fact or reinforced existing
|
||||||
|
if fact.reinforcement_count == 1:
|
||||||
|
result.items_created += 1
|
||||||
|
else:
|
||||||
|
result.items_updated += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result.errors.append(f"Episode {episode.id}: {e}")
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to extract facts from episode {episode.id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result.errors.append(f"Consolidation failed: {e}")
|
||||||
|
logger.exception("Failed episodic -> semantic consolidation")
|
||||||
|
|
||||||
|
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Episodic -> Semantic consolidation: "
|
||||||
|
f"{result.items_processed} processed, "
|
||||||
|
f"{result.items_created} created, "
|
||||||
|
f"{result.items_updated} reinforced"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Episodic -> Procedural Consolidation
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def consolidate_episodes_to_procedures(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
agent_type_id: UUID | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> ConsolidationResult:
|
||||||
|
"""
|
||||||
|
Learn procedures from patterns in episodic memories.
|
||||||
|
|
||||||
|
Identifies recurring successful patterns and creates/updates
|
||||||
|
procedures to capture them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to consolidate
|
||||||
|
agent_type_id: Optional filter by agent type
|
||||||
|
since: Only process episodes since this time
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConsolidationResult with procedure statistics
|
||||||
|
"""
|
||||||
|
start_time = datetime.now(UTC)
|
||||||
|
result = ConsolidationResult(
|
||||||
|
source_type="episodic",
|
||||||
|
target_type="procedural",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
episodic = await self._get_episodic()
|
||||||
|
procedural = await self._get_procedural()
|
||||||
|
|
||||||
|
# Get successful episodes
|
||||||
|
since_time = since or datetime.now(UTC) - timedelta(days=7)
|
||||||
|
episodes = await episodic.get_by_outcome(
|
||||||
|
project_id,
|
||||||
|
outcome=Outcome.SUCCESS,
|
||||||
|
limit=self._config.batch_size,
|
||||||
|
agent_instance_id=None, # Get all agent instances
|
||||||
|
)
|
||||||
|
|
||||||
|
# Group by task type
|
||||||
|
task_groups: dict[str, list[Episode]] = {}
|
||||||
|
for episode in episodes:
|
||||||
|
if episode.occurred_at >= since_time:
|
||||||
|
if episode.task_type not in task_groups:
|
||||||
|
task_groups[episode.task_type] = []
|
||||||
|
task_groups[episode.task_type].append(episode)
|
||||||
|
|
||||||
|
result.items_processed = len(episodes)
|
||||||
|
|
||||||
|
# Process each task type group
|
||||||
|
for task_type, group in task_groups.items():
|
||||||
|
if len(group) < self._config.min_episodes_for_procedure:
|
||||||
|
result.items_skipped += len(group)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
procedure_result = await self._learn_procedure_from_episodes(
|
||||||
|
procedural,
|
||||||
|
project_id,
|
||||||
|
agent_type_id,
|
||||||
|
task_type,
|
||||||
|
group,
|
||||||
|
)
|
||||||
|
|
||||||
|
if procedure_result == "created":
|
||||||
|
result.items_created += 1
|
||||||
|
elif procedure_result == "updated":
|
||||||
|
result.items_updated += 1
|
||||||
|
else:
|
||||||
|
result.items_skipped += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result.errors.append(f"Task type '{task_type}': {e}")
|
||||||
|
logger.warning(f"Failed to learn procedure for '{task_type}': {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result.errors.append(f"Consolidation failed: {e}")
|
||||||
|
logger.exception("Failed episodic -> procedural consolidation")
|
||||||
|
|
||||||
|
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Episodic -> Procedural consolidation: "
|
||||||
|
f"{result.items_processed} processed, "
|
||||||
|
f"{result.items_created} created, "
|
||||||
|
f"{result.items_updated} updated"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _learn_procedure_from_episodes(
|
||||||
|
self,
|
||||||
|
procedural: ProceduralMemory,
|
||||||
|
project_id: UUID,
|
||||||
|
agent_type_id: UUID | None,
|
||||||
|
task_type: str,
|
||||||
|
episodes: list[Episode],
|
||||||
|
) -> str:
|
||||||
|
"""Learn or update a procedure from a set of episodes."""
|
||||||
|
# Calculate success rate for this pattern
|
||||||
|
success_count = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
|
||||||
|
total_count = len(episodes)
|
||||||
|
success_rate = success_count / total_count if total_count > 0 else 0
|
||||||
|
|
||||||
|
if success_rate < self._config.min_success_rate_for_procedure:
|
||||||
|
return "skipped"
|
||||||
|
|
||||||
|
# Extract common steps from episodes
|
||||||
|
steps = self._extract_common_steps(episodes)
|
||||||
|
|
||||||
|
if len(steps) < self._config.min_steps_for_procedure:
|
||||||
|
return "skipped"
|
||||||
|
|
||||||
|
# Check for existing procedure
|
||||||
|
matching = await procedural.find_matching(
|
||||||
|
context=task_type,
|
||||||
|
project_id=project_id,
|
||||||
|
agent_type_id=agent_type_id,
|
||||||
|
limit=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if matching:
|
||||||
|
# Update existing procedure with new success
|
||||||
|
await procedural.record_outcome(
|
||||||
|
matching[0].id,
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
return "updated"
|
||||||
|
else:
|
||||||
|
# Create new procedure
|
||||||
|
# Note: success_count starts at 1 in record_procedure
|
||||||
|
procedure_data = ProcedureCreate(
|
||||||
|
project_id=project_id,
|
||||||
|
agent_type_id=agent_type_id,
|
||||||
|
name=f"Procedure for {task_type}",
|
||||||
|
trigger_pattern=task_type,
|
||||||
|
steps=steps,
|
||||||
|
)
|
||||||
|
await procedural.record_procedure(procedure_data)
|
||||||
|
return "created"
|
||||||
|
|
||||||
|
def _extract_common_steps(self, episodes: list[Episode]) -> list[dict[str, Any]]:
|
||||||
|
"""Extract common action steps from multiple episodes."""
|
||||||
|
# Simple heuristic: take the steps from the most successful episode
|
||||||
|
# with the most detailed actions
|
||||||
|
|
||||||
|
best_episode = max(
|
||||||
|
episodes,
|
||||||
|
key=lambda e: (
|
||||||
|
e.outcome == Outcome.SUCCESS,
|
||||||
|
len(e.actions),
|
||||||
|
e.importance_score,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
steps: list[dict[str, Any]] = []
|
||||||
|
for i, action in enumerate(best_episode.actions):
|
||||||
|
step = {
|
||||||
|
"order": i + 1,
|
||||||
|
"action": action.get("type", "action"),
|
||||||
|
"description": action.get("content", str(action))[:500],
|
||||||
|
"parameters": action,
|
||||||
|
}
|
||||||
|
steps.append(step)
|
||||||
|
|
||||||
|
return steps
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Memory Pruning
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def prune_old_episodes(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
max_age_days: int | None = None,
|
||||||
|
min_importance: float | None = None,
|
||||||
|
) -> ConsolidationResult:
|
||||||
|
"""
|
||||||
|
Prune old, low-value episodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to prune
|
||||||
|
max_age_days: Maximum age in days (default from config)
|
||||||
|
min_importance: Minimum importance to keep (default from config)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConsolidationResult with pruning statistics
|
||||||
|
"""
|
||||||
|
start_time = datetime.now(UTC)
|
||||||
|
result = ConsolidationResult(
|
||||||
|
source_type="episodic",
|
||||||
|
target_type="pruned",
|
||||||
|
)
|
||||||
|
|
||||||
|
max_age = max_age_days or self._config.max_episode_age_days
|
||||||
|
min_imp = min_importance or self._config.min_importance_to_keep
|
||||||
|
cutoff_date = datetime.now(UTC) - timedelta(days=max_age)
|
||||||
|
|
||||||
|
try:
|
||||||
|
episodic = await self._get_episodic()
|
||||||
|
|
||||||
|
# Get old episodes
|
||||||
|
# Note: In production, this would use a more efficient query
|
||||||
|
all_episodes = await episodic.get_recent(
|
||||||
|
project_id,
|
||||||
|
limit=self._config.batch_size * 10,
|
||||||
|
since=cutoff_date - timedelta(days=365), # Search past year
|
||||||
|
)
|
||||||
|
|
||||||
|
for episode in all_episodes:
|
||||||
|
result.items_processed += 1
|
||||||
|
|
||||||
|
# Check if should be pruned
|
||||||
|
if not self._should_prune_episode(episode, cutoff_date, min_imp):
|
||||||
|
result.items_skipped += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
deleted = await episodic.delete(episode.id)
|
||||||
|
if deleted:
|
||||||
|
result.items_pruned += 1
|
||||||
|
else:
|
||||||
|
result.items_skipped += 1
|
||||||
|
except Exception as e:
|
||||||
|
result.errors.append(f"Episode {episode.id}: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result.errors.append(f"Pruning failed: {e}")
|
||||||
|
logger.exception("Failed episode pruning")
|
||||||
|
|
||||||
|
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Episode pruning: {result.items_processed} processed, "
|
||||||
|
f"{result.items_pruned} pruned"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _should_prune_episode(
|
||||||
|
self,
|
||||||
|
episode: Episode,
|
||||||
|
cutoff_date: datetime,
|
||||||
|
min_importance: float,
|
||||||
|
) -> bool:
|
||||||
|
"""Determine if an episode should be pruned."""
|
||||||
|
# Keep recent episodes
|
||||||
|
if episode.occurred_at >= cutoff_date:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Keep failures if configured
|
||||||
|
if self._config.keep_all_failures and episode.outcome == Outcome.FAILURE:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Keep episodes with lessons if configured
|
||||||
|
if self._config.keep_all_with_lessons and episode.lessons_learned:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Keep high-importance episodes
|
||||||
|
if episode.importance_score >= min_importance:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Nightly Consolidation
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def run_nightly_consolidation(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
agent_type_id: UUID | None = None,
|
||||||
|
) -> NightlyConsolidationResult:
|
||||||
|
"""
|
||||||
|
Run full nightly consolidation workflow.
|
||||||
|
|
||||||
|
This includes:
|
||||||
|
1. Extract facts from recent episodes
|
||||||
|
2. Learn procedures from successful patterns
|
||||||
|
3. Prune old, low-value memories
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to consolidate
|
||||||
|
agent_type_id: Optional agent type filter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NightlyConsolidationResult with all outcomes
|
||||||
|
"""
|
||||||
|
result = NightlyConsolidationResult(started_at=datetime.now(UTC))
|
||||||
|
|
||||||
|
logger.info(f"Starting nightly consolidation for project {project_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Step 1: Episodic -> Semantic (last 24 hours)
|
||||||
|
since_yesterday = datetime.now(UTC) - timedelta(days=1)
|
||||||
|
result.episodic_to_semantic = await self.consolidate_episodes_to_facts(
|
||||||
|
project_id=project_id,
|
||||||
|
since=since_yesterday,
|
||||||
|
)
|
||||||
|
result.total_facts_created = result.episodic_to_semantic.items_created
|
||||||
|
|
||||||
|
# Step 2: Episodic -> Procedural (last 7 days)
|
||||||
|
since_week = datetime.now(UTC) - timedelta(days=7)
|
||||||
|
result.episodic_to_procedural = (
|
||||||
|
await self.consolidate_episodes_to_procedures(
|
||||||
|
project_id=project_id,
|
||||||
|
agent_type_id=agent_type_id,
|
||||||
|
since=since_week,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result.total_procedures_created = (
|
||||||
|
result.episodic_to_procedural.items_created
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Prune old memories
|
||||||
|
result.pruning = await self.prune_old_episodes(project_id=project_id)
|
||||||
|
result.total_pruned = result.pruning.items_pruned
|
||||||
|
|
||||||
|
# Calculate totals
|
||||||
|
result.total_episodes_processed = (
|
||||||
|
result.episodic_to_semantic.items_processed
|
||||||
|
if result.episodic_to_semantic
|
||||||
|
else 0
|
||||||
|
) + (
|
||||||
|
result.episodic_to_procedural.items_processed
|
||||||
|
if result.episodic_to_procedural
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect all errors
|
||||||
|
if result.episodic_to_semantic and result.episodic_to_semantic.errors:
|
||||||
|
result.errors.extend(result.episodic_to_semantic.errors)
|
||||||
|
if result.episodic_to_procedural and result.episodic_to_procedural.errors:
|
||||||
|
result.errors.extend(result.episodic_to_procedural.errors)
|
||||||
|
if result.pruning and result.pruning.errors:
|
||||||
|
result.errors.extend(result.pruning.errors)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result.errors.append(f"Nightly consolidation failed: {e}")
|
||||||
|
logger.exception("Nightly consolidation failed")
|
||||||
|
|
||||||
|
result.completed_at = datetime.now(UTC)
|
||||||
|
|
||||||
|
duration = (result.completed_at - result.started_at).total_seconds()
|
||||||
|
logger.info(
|
||||||
|
f"Nightly consolidation completed in {duration:.1f}s: "
|
||||||
|
f"{result.total_facts_created} facts, "
|
||||||
|
f"{result.total_procedures_created} procedures, "
|
||||||
|
f"{result.total_pruned} pruned"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# Factory function - no singleton to avoid stale session issues
|
||||||
|
async def get_consolidation_service(
|
||||||
|
session: AsyncSession,
|
||||||
|
config: ConsolidationConfig | None = None,
|
||||||
|
) -> MemoryConsolidationService:
|
||||||
|
"""
|
||||||
|
Create a memory consolidation service for the given session.
|
||||||
|
|
||||||
|
Note: This creates a new instance each time to avoid stale session issues.
|
||||||
|
The service is lightweight and safe to recreate per-request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session (must be active)
|
||||||
|
config: Optional configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MemoryConsolidationService instance
|
||||||
|
"""
|
||||||
|
return MemoryConsolidationService(session=session, config=config)
|
||||||
17
backend/app/services/memory/episodic/__init__.py
Normal file
17
backend/app/services/memory/episodic/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# app/services/memory/episodic/__init__.py
|
||||||
|
"""
|
||||||
|
Episodic Memory Package.
|
||||||
|
|
||||||
|
Provides experiential memory storage and retrieval for agent learning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .memory import EpisodicMemory
|
||||||
|
from .recorder import EpisodeRecorder
|
||||||
|
from .retrieval import EpisodeRetriever, RetrievalStrategy
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EpisodeRecorder",
|
||||||
|
"EpisodeRetriever",
|
||||||
|
"EpisodicMemory",
|
||||||
|
"RetrievalStrategy",
|
||||||
|
]
|
||||||
490
backend/app/services/memory/episodic/memory.py
Normal file
490
backend/app/services/memory/episodic/memory.py
Normal file
@@ -0,0 +1,490 @@
|
|||||||
|
# app/services/memory/episodic/memory.py
|
||||||
|
"""
|
||||||
|
Episodic Memory Implementation.
|
||||||
|
|
||||||
|
Provides experiential memory storage and retrieval for agent learning.
|
||||||
|
Combines episode recording and retrieval into a unified interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.services.memory.types import Episode, EpisodeCreate, Outcome, RetrievalResult
|
||||||
|
|
||||||
|
from .recorder import EpisodeRecorder
|
||||||
|
from .retrieval import EpisodeRetriever, RetrievalStrategy
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodicMemory:
|
||||||
|
"""
|
||||||
|
Episodic Memory Service.
|
||||||
|
|
||||||
|
Provides experiential memory for agent learning:
|
||||||
|
- Record task completions with context
|
||||||
|
- Store failures with error context
|
||||||
|
- Retrieve by semantic similarity
|
||||||
|
- Retrieve by recency, outcome, task type
|
||||||
|
- Track importance scores
|
||||||
|
- Extract lessons learned
|
||||||
|
|
||||||
|
Performance target: <100ms P95 for retrieval
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
embedding_generator: Any | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize episodic memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
embedding_generator: Optional embedding generator for semantic search
|
||||||
|
"""
|
||||||
|
self._session = session
|
||||||
|
self._embedding_generator = embedding_generator
|
||||||
|
self._recorder = EpisodeRecorder(session, embedding_generator)
|
||||||
|
self._retriever = EpisodeRetriever(session, embedding_generator)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(
|
||||||
|
cls,
|
||||||
|
session: AsyncSession,
|
||||||
|
embedding_generator: Any | None = None,
|
||||||
|
) -> "EpisodicMemory":
|
||||||
|
"""
|
||||||
|
Factory method to create EpisodicMemory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
embedding_generator: Optional embedding generator
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured EpisodicMemory instance
|
||||||
|
"""
|
||||||
|
return cls(session=session, embedding_generator=embedding_generator)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Recording Operations
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def record_episode(self, episode: EpisodeCreate) -> Episode:
|
||||||
|
"""
|
||||||
|
Record a new episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode: Episode data to record
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created episode with assigned ID
|
||||||
|
"""
|
||||||
|
return await self._recorder.record(episode)
|
||||||
|
|
||||||
|
async def record_success(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
session_id: str,
|
||||||
|
task_type: str,
|
||||||
|
task_description: str,
|
||||||
|
actions: list[dict[str, Any]],
|
||||||
|
context_summary: str,
|
||||||
|
outcome_details: str = "",
|
||||||
|
duration_seconds: float = 0.0,
|
||||||
|
tokens_used: int = 0,
|
||||||
|
lessons_learned: list[str] | None = None,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
agent_type_id: UUID | None = None,
|
||||||
|
) -> Episode:
|
||||||
|
"""
|
||||||
|
Convenience method to record a successful episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project ID
|
||||||
|
session_id: Session ID
|
||||||
|
task_type: Type of task
|
||||||
|
task_description: Task description
|
||||||
|
actions: Actions taken
|
||||||
|
context_summary: Context summary
|
||||||
|
outcome_details: Optional outcome details
|
||||||
|
duration_seconds: Task duration
|
||||||
|
tokens_used: Tokens consumed
|
||||||
|
lessons_learned: Optional lessons
|
||||||
|
agent_instance_id: Optional agent instance
|
||||||
|
agent_type_id: Optional agent type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created episode
|
||||||
|
"""
|
||||||
|
episode_data = EpisodeCreate(
|
||||||
|
project_id=project_id,
|
||||||
|
session_id=session_id,
|
||||||
|
task_type=task_type,
|
||||||
|
task_description=task_description,
|
||||||
|
actions=actions,
|
||||||
|
context_summary=context_summary,
|
||||||
|
outcome=Outcome.SUCCESS,
|
||||||
|
outcome_details=outcome_details,
|
||||||
|
duration_seconds=duration_seconds,
|
||||||
|
tokens_used=tokens_used,
|
||||||
|
lessons_learned=lessons_learned or [],
|
||||||
|
agent_instance_id=agent_instance_id,
|
||||||
|
agent_type_id=agent_type_id,
|
||||||
|
)
|
||||||
|
return await self.record_episode(episode_data)
|
||||||
|
|
||||||
|
async def record_failure(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
session_id: str,
|
||||||
|
task_type: str,
|
||||||
|
task_description: str,
|
||||||
|
actions: list[dict[str, Any]],
|
||||||
|
context_summary: str,
|
||||||
|
error_details: str,
|
||||||
|
duration_seconds: float = 0.0,
|
||||||
|
tokens_used: int = 0,
|
||||||
|
lessons_learned: list[str] | None = None,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
agent_type_id: UUID | None = None,
|
||||||
|
) -> Episode:
|
||||||
|
"""
|
||||||
|
Convenience method to record a failed episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project ID
|
||||||
|
session_id: Session ID
|
||||||
|
task_type: Type of task
|
||||||
|
task_description: Task description
|
||||||
|
actions: Actions taken before failure
|
||||||
|
context_summary: Context summary
|
||||||
|
error_details: Error details
|
||||||
|
duration_seconds: Task duration
|
||||||
|
tokens_used: Tokens consumed
|
||||||
|
lessons_learned: Optional lessons from failure
|
||||||
|
agent_instance_id: Optional agent instance
|
||||||
|
agent_type_id: Optional agent type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created episode
|
||||||
|
"""
|
||||||
|
episode_data = EpisodeCreate(
|
||||||
|
project_id=project_id,
|
||||||
|
session_id=session_id,
|
||||||
|
task_type=task_type,
|
||||||
|
task_description=task_description,
|
||||||
|
actions=actions,
|
||||||
|
context_summary=context_summary,
|
||||||
|
outcome=Outcome.FAILURE,
|
||||||
|
outcome_details=error_details,
|
||||||
|
duration_seconds=duration_seconds,
|
||||||
|
tokens_used=tokens_used,
|
||||||
|
lessons_learned=lessons_learned or [],
|
||||||
|
agent_instance_id=agent_instance_id,
|
||||||
|
agent_type_id=agent_type_id,
|
||||||
|
)
|
||||||
|
return await self.record_episode(episode_data)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Retrieval Operations
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def search_similar(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
query: str,
|
||||||
|
limit: int = 10,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
) -> list[Episode]:
|
||||||
|
"""
|
||||||
|
Search for semantically similar episodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to search within
|
||||||
|
query: Search query
|
||||||
|
limit: Maximum results
|
||||||
|
agent_instance_id: Optional filter by agent instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of similar episodes
|
||||||
|
"""
|
||||||
|
result = await self._retriever.search_similar(
|
||||||
|
project_id, query, limit, agent_instance_id
|
||||||
|
)
|
||||||
|
return result.items
|
||||||
|
|
||||||
|
async def get_recent(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
limit: int = 10,
|
||||||
|
since: datetime | None = None,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
) -> list[Episode]:
|
||||||
|
"""
|
||||||
|
Get recent episodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to search within
|
||||||
|
limit: Maximum results
|
||||||
|
since: Optional time filter
|
||||||
|
agent_instance_id: Optional filter by agent instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of recent episodes
|
||||||
|
"""
|
||||||
|
result = await self._retriever.get_recent(
|
||||||
|
project_id, limit, since, agent_instance_id
|
||||||
|
)
|
||||||
|
return result.items
|
||||||
|
|
||||||
|
async def get_by_outcome(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
outcome: Outcome,
|
||||||
|
limit: int = 10,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
) -> list[Episode]:
|
||||||
|
"""
|
||||||
|
Get episodes by outcome.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to search within
|
||||||
|
outcome: Outcome to filter by
|
||||||
|
limit: Maximum results
|
||||||
|
agent_instance_id: Optional filter by agent instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of episodes with specified outcome
|
||||||
|
"""
|
||||||
|
result = await self._retriever.get_by_outcome(
|
||||||
|
project_id, outcome, limit, agent_instance_id
|
||||||
|
)
|
||||||
|
return result.items
|
||||||
|
|
||||||
|
async def get_by_task_type(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
task_type: str,
|
||||||
|
limit: int = 10,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
) -> list[Episode]:
|
||||||
|
"""
|
||||||
|
Get episodes by task type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to search within
|
||||||
|
task_type: Task type to filter by
|
||||||
|
limit: Maximum results
|
||||||
|
agent_instance_id: Optional filter by agent instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of episodes with specified task type
|
||||||
|
"""
|
||||||
|
result = await self._retriever.get_by_task_type(
|
||||||
|
project_id, task_type, limit, agent_instance_id
|
||||||
|
)
|
||||||
|
return result.items
|
||||||
|
|
||||||
|
async def get_important(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
limit: int = 10,
|
||||||
|
min_importance: float = 0.7,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
) -> list[Episode]:
|
||||||
|
"""
|
||||||
|
Get high-importance episodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to search within
|
||||||
|
limit: Maximum results
|
||||||
|
min_importance: Minimum importance score
|
||||||
|
agent_instance_id: Optional filter by agent instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of important episodes
|
||||||
|
"""
|
||||||
|
result = await self._retriever.get_important(
|
||||||
|
project_id, limit, min_importance, agent_instance_id
|
||||||
|
)
|
||||||
|
return result.items
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
strategy: RetrievalStrategy = RetrievalStrategy.RECENCY,
|
||||||
|
limit: int = 10,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""
|
||||||
|
Retrieve episodes with full result metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to search within
|
||||||
|
strategy: Retrieval strategy
|
||||||
|
limit: Maximum results
|
||||||
|
**kwargs: Strategy-specific parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RetrievalResult with episodes and metadata
|
||||||
|
"""
|
||||||
|
return await self._retriever.retrieve(project_id, strategy, limit, **kwargs)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Modification Operations
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def get_by_id(self, episode_id: UUID) -> Episode | None:
|
||||||
|
"""Get an episode by ID."""
|
||||||
|
return await self._recorder.get_by_id(episode_id)
|
||||||
|
|
||||||
|
async def update_importance(
|
||||||
|
self,
|
||||||
|
episode_id: UUID,
|
||||||
|
importance_score: float,
|
||||||
|
) -> Episode | None:
|
||||||
|
"""
|
||||||
|
Update an episode's importance score.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode_id: Episode to update
|
||||||
|
importance_score: New importance score (0.0 to 1.0)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated episode or None if not found
|
||||||
|
"""
|
||||||
|
return await self._recorder.update_importance(episode_id, importance_score)
|
||||||
|
|
||||||
|
async def add_lessons(
|
||||||
|
self,
|
||||||
|
episode_id: UUID,
|
||||||
|
lessons: list[str],
|
||||||
|
) -> Episode | None:
|
||||||
|
"""
|
||||||
|
Add lessons learned to an episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode_id: Episode to update
|
||||||
|
lessons: Lessons to add
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated episode or None if not found
|
||||||
|
"""
|
||||||
|
return await self._recorder.add_lessons(episode_id, lessons)
|
||||||
|
|
||||||
|
async def delete(self, episode_id: UUID) -> bool:
|
||||||
|
"""
|
||||||
|
Delete an episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode_id: Episode to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted
|
||||||
|
"""
|
||||||
|
return await self._recorder.delete(episode_id)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Summarization
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def summarize_episodes(
|
||||||
|
self,
|
||||||
|
episode_ids: list[UUID],
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Summarize multiple episodes into a consolidated view.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode_ids: Episodes to summarize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Summary text
|
||||||
|
"""
|
||||||
|
if not episode_ids:
|
||||||
|
return "No episodes to summarize."
|
||||||
|
|
||||||
|
episodes: list[Episode] = []
|
||||||
|
for episode_id in episode_ids:
|
||||||
|
episode = await self.get_by_id(episode_id)
|
||||||
|
if episode:
|
||||||
|
episodes.append(episode)
|
||||||
|
|
||||||
|
if not episodes:
|
||||||
|
return "No episodes found."
|
||||||
|
|
||||||
|
# Build summary
|
||||||
|
lines = [f"Summary of {len(episodes)} episodes:", ""]
|
||||||
|
|
||||||
|
# Outcome breakdown
|
||||||
|
success = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
|
||||||
|
failure = sum(1 for e in episodes if e.outcome == Outcome.FAILURE)
|
||||||
|
partial = sum(1 for e in episodes if e.outcome == Outcome.PARTIAL)
|
||||||
|
lines.append(
|
||||||
|
f"Outcomes: {success} success, {failure} failure, {partial} partial"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Task types
|
||||||
|
task_types = {e.task_type for e in episodes}
|
||||||
|
lines.append(f"Task types: {', '.join(sorted(task_types))}")
|
||||||
|
|
||||||
|
# Aggregate lessons
|
||||||
|
all_lessons: list[str] = []
|
||||||
|
for e in episodes:
|
||||||
|
all_lessons.extend(e.lessons_learned)
|
||||||
|
|
||||||
|
if all_lessons:
|
||||||
|
lines.append("")
|
||||||
|
lines.append("Key lessons learned:")
|
||||||
|
# Deduplicate lessons
|
||||||
|
unique_lessons = list(dict.fromkeys(all_lessons))
|
||||||
|
for lesson in unique_lessons[:10]: # Top 10
|
||||||
|
lines.append(f" - {lesson}")
|
||||||
|
|
||||||
|
# Duration and tokens
|
||||||
|
total_duration = sum(e.duration_seconds for e in episodes)
|
||||||
|
total_tokens = sum(e.tokens_used for e in episodes)
|
||||||
|
lines.append("")
|
||||||
|
lines.append(f"Total duration: {total_duration:.1f}s")
|
||||||
|
lines.append(f"Total tokens: {total_tokens:,}")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Statistics
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
async def get_stats(self, project_id: UUID) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get episode statistics for a project.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to get stats for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with episode statistics
|
||||||
|
"""
|
||||||
|
return await self._recorder.get_stats(project_id)
|
||||||
|
|
||||||
|
async def count(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count episodes for a project.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to count for
|
||||||
|
since: Optional time filter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of episodes
|
||||||
|
"""
|
||||||
|
return await self._recorder.count_by_project(project_id, since)
|
||||||
357
backend/app/services/memory/episodic/recorder.py
Normal file
357
backend/app/services/memory/episodic/recorder.py
Normal file
@@ -0,0 +1,357 @@
|
|||||||
|
# app/services/memory/episodic/recorder.py
|
||||||
|
"""
|
||||||
|
Episode Recording.
|
||||||
|
|
||||||
|
Handles the creation and storage of episodic memories
|
||||||
|
during agent task execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.memory.enums import EpisodeOutcome
|
||||||
|
from app.models.memory.episode import Episode as EpisodeModel
|
||||||
|
from app.services.memory.config import get_memory_settings
|
||||||
|
from app.services.memory.types import Episode, EpisodeCreate, Outcome
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _outcome_to_db(outcome: Outcome) -> EpisodeOutcome:
|
||||||
|
"""Convert service Outcome to database EpisodeOutcome."""
|
||||||
|
return EpisodeOutcome(outcome.value)
|
||||||
|
|
||||||
|
|
||||||
|
def _db_to_outcome(db_outcome: EpisodeOutcome) -> Outcome:
|
||||||
|
"""Convert database EpisodeOutcome to service Outcome."""
|
||||||
|
return Outcome(db_outcome.value)
|
||||||
|
|
||||||
|
|
||||||
|
def _model_to_episode(model: EpisodeModel) -> Episode:
|
||||||
|
"""Convert SQLAlchemy model to Episode dataclass."""
|
||||||
|
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
|
||||||
|
# they return actual values. We use type: ignore to handle this mismatch.
|
||||||
|
return Episode(
|
||||||
|
id=model.id, # type: ignore[arg-type]
|
||||||
|
project_id=model.project_id, # type: ignore[arg-type]
|
||||||
|
agent_instance_id=model.agent_instance_id, # type: ignore[arg-type]
|
||||||
|
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
|
||||||
|
session_id=model.session_id, # type: ignore[arg-type]
|
||||||
|
task_type=model.task_type, # type: ignore[arg-type]
|
||||||
|
task_description=model.task_description, # type: ignore[arg-type]
|
||||||
|
actions=model.actions or [], # type: ignore[arg-type]
|
||||||
|
context_summary=model.context_summary, # type: ignore[arg-type]
|
||||||
|
outcome=_db_to_outcome(model.outcome), # type: ignore[arg-type]
|
||||||
|
outcome_details=model.outcome_details or "", # type: ignore[arg-type]
|
||||||
|
duration_seconds=model.duration_seconds, # type: ignore[arg-type]
|
||||||
|
tokens_used=model.tokens_used, # type: ignore[arg-type]
|
||||||
|
lessons_learned=model.lessons_learned or [], # type: ignore[arg-type]
|
||||||
|
importance_score=model.importance_score, # type: ignore[arg-type]
|
||||||
|
embedding=None, # Don't expose raw embedding
|
||||||
|
occurred_at=model.occurred_at, # type: ignore[arg-type]
|
||||||
|
created_at=model.created_at, # type: ignore[arg-type]
|
||||||
|
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeRecorder:
|
||||||
|
"""
|
||||||
|
Records episodes to the database.
|
||||||
|
|
||||||
|
Handles episode creation, importance scoring,
|
||||||
|
and lesson extraction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
embedding_generator: Any | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize recorder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
embedding_generator: Optional embedding generator for semantic indexing
|
||||||
|
"""
|
||||||
|
self._session = session
|
||||||
|
self._embedding_generator = embedding_generator
|
||||||
|
self._settings = get_memory_settings()
|
||||||
|
|
||||||
|
async def record(self, episode_data: EpisodeCreate) -> Episode:
|
||||||
|
"""
|
||||||
|
Record a new episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode_data: Episode data to record
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created episode
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
# Calculate importance score if not provided
|
||||||
|
importance = episode_data.importance_score
|
||||||
|
if importance == 0.5: # Default value, calculate
|
||||||
|
importance = self._calculate_importance(episode_data)
|
||||||
|
|
||||||
|
# Create the model
|
||||||
|
model = EpisodeModel(
|
||||||
|
id=uuid4(),
|
||||||
|
project_id=episode_data.project_id,
|
||||||
|
agent_instance_id=episode_data.agent_instance_id,
|
||||||
|
agent_type_id=episode_data.agent_type_id,
|
||||||
|
session_id=episode_data.session_id,
|
||||||
|
task_type=episode_data.task_type,
|
||||||
|
task_description=episode_data.task_description,
|
||||||
|
actions=episode_data.actions,
|
||||||
|
context_summary=episode_data.context_summary,
|
||||||
|
outcome=_outcome_to_db(episode_data.outcome),
|
||||||
|
outcome_details=episode_data.outcome_details,
|
||||||
|
duration_seconds=episode_data.duration_seconds,
|
||||||
|
tokens_used=episode_data.tokens_used,
|
||||||
|
lessons_learned=episode_data.lessons_learned,
|
||||||
|
importance_score=importance,
|
||||||
|
occurred_at=now,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate embedding if generator available
|
||||||
|
if self._embedding_generator is not None:
|
||||||
|
try:
|
||||||
|
text_for_embedding = self._create_embedding_text(episode_data)
|
||||||
|
embedding = await self._embedding_generator.generate(text_for_embedding)
|
||||||
|
model.embedding = embedding
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to generate embedding: {e}")
|
||||||
|
|
||||||
|
self._session.add(model)
|
||||||
|
await self._session.flush()
|
||||||
|
await self._session.refresh(model)
|
||||||
|
|
||||||
|
logger.debug(f"Recorded episode {model.id} for task {model.task_type}")
|
||||||
|
return _model_to_episode(model)
|
||||||
|
|
||||||
|
def _calculate_importance(self, episode_data: EpisodeCreate) -> float:
|
||||||
|
"""
|
||||||
|
Calculate importance score for an episode.
|
||||||
|
|
||||||
|
Factors:
|
||||||
|
- Outcome: Failures are more important to learn from
|
||||||
|
- Duration: Longer tasks may be more significant
|
||||||
|
- Token usage: Higher usage may indicate complexity
|
||||||
|
- Lessons learned: Episodes with lessons are more valuable
|
||||||
|
"""
|
||||||
|
score = 0.5 # Base score
|
||||||
|
|
||||||
|
# Outcome factor
|
||||||
|
if episode_data.outcome == Outcome.FAILURE:
|
||||||
|
score += 0.2 # Failures are important for learning
|
||||||
|
elif episode_data.outcome == Outcome.PARTIAL:
|
||||||
|
score += 0.1
|
||||||
|
# Success is default, no adjustment
|
||||||
|
|
||||||
|
# Lessons learned factor
|
||||||
|
if episode_data.lessons_learned:
|
||||||
|
score += min(0.15, len(episode_data.lessons_learned) * 0.05)
|
||||||
|
|
||||||
|
# Duration factor (longer tasks may be more significant)
|
||||||
|
if episode_data.duration_seconds > 60:
|
||||||
|
score += 0.05
|
||||||
|
if episode_data.duration_seconds > 300:
|
||||||
|
score += 0.05
|
||||||
|
|
||||||
|
# Token usage factor (complex tasks)
|
||||||
|
if episode_data.tokens_used > 1000:
|
||||||
|
score += 0.05
|
||||||
|
|
||||||
|
# Clamp to valid range
|
||||||
|
return min(1.0, max(0.0, score))
|
||||||
|
|
||||||
|
def _create_embedding_text(self, episode_data: EpisodeCreate) -> str:
|
||||||
|
"""Create text representation for embedding generation."""
|
||||||
|
parts = [
|
||||||
|
f"Task: {episode_data.task_type}",
|
||||||
|
f"Description: {episode_data.task_description}",
|
||||||
|
f"Context: {episode_data.context_summary}",
|
||||||
|
f"Outcome: {episode_data.outcome.value}",
|
||||||
|
]
|
||||||
|
|
||||||
|
if episode_data.outcome_details:
|
||||||
|
parts.append(f"Details: {episode_data.outcome_details}")
|
||||||
|
|
||||||
|
if episode_data.lessons_learned:
|
||||||
|
parts.append(f"Lessons: {', '.join(episode_data.lessons_learned)}")
|
||||||
|
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
async def get_by_id(self, episode_id: UUID) -> Episode | None:
|
||||||
|
"""Get an episode by ID."""
|
||||||
|
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
|
||||||
|
result = await self._session.execute(query)
|
||||||
|
model = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return _model_to_episode(model)
|
||||||
|
|
||||||
|
async def update_importance(
|
||||||
|
self,
|
||||||
|
episode_id: UUID,
|
||||||
|
importance_score: float,
|
||||||
|
) -> Episode | None:
|
||||||
|
"""
|
||||||
|
Update the importance score of an episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode_id: Episode to update
|
||||||
|
importance_score: New importance score (0.0 to 1.0)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated episode or None if not found
|
||||||
|
"""
|
||||||
|
# Validate score
|
||||||
|
importance_score = min(1.0, max(0.0, importance_score))
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
update(EpisodeModel)
|
||||||
|
.where(EpisodeModel.id == episode_id)
|
||||||
|
.values(
|
||||||
|
importance_score=importance_score,
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
.returning(EpisodeModel)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
model = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
await self._session.flush()
|
||||||
|
return _model_to_episode(model)
|
||||||
|
|
||||||
|
async def add_lessons(
|
||||||
|
self,
|
||||||
|
episode_id: UUID,
|
||||||
|
lessons: list[str],
|
||||||
|
) -> Episode | None:
|
||||||
|
"""
|
||||||
|
Add lessons learned to an episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode_id: Episode to update
|
||||||
|
lessons: New lessons to add
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated episode or None if not found
|
||||||
|
"""
|
||||||
|
# Get current episode
|
||||||
|
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
|
||||||
|
result = await self._session.execute(query)
|
||||||
|
model = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Append lessons
|
||||||
|
current_lessons: list[str] = model.lessons_learned or [] # type: ignore[assignment]
|
||||||
|
updated_lessons = current_lessons + lessons
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
update(EpisodeModel)
|
||||||
|
.where(EpisodeModel.id == episode_id)
|
||||||
|
.values(
|
||||||
|
lessons_learned=updated_lessons,
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
.returning(EpisodeModel)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
model = result.scalar_one_or_none()
|
||||||
|
await self._session.flush()
|
||||||
|
|
||||||
|
return _model_to_episode(model) if model else None
|
||||||
|
|
||||||
|
async def delete(self, episode_id: UUID) -> bool:
|
||||||
|
"""
|
||||||
|
Delete an episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode_id: Episode to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted
|
||||||
|
"""
|
||||||
|
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
|
||||||
|
result = await self._session.execute(query)
|
||||||
|
model = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self._session.delete(model)
|
||||||
|
await self._session.flush()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def count_by_project(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Count episodes for a project."""
|
||||||
|
query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||||
|
if since is not None:
|
||||||
|
query = query.where(EpisodeModel.occurred_at >= since)
|
||||||
|
|
||||||
|
result = await self._session.execute(query)
|
||||||
|
return len(list(result.scalars().all()))
|
||||||
|
|
||||||
|
async def get_stats(self, project_id: UUID) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get statistics for a project's episodes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with episode statistics
|
||||||
|
"""
|
||||||
|
query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||||
|
result = await self._session.execute(query)
|
||||||
|
episodes = list(result.scalars().all())
|
||||||
|
|
||||||
|
if not episodes:
|
||||||
|
return {
|
||||||
|
"total_count": 0,
|
||||||
|
"success_count": 0,
|
||||||
|
"failure_count": 0,
|
||||||
|
"partial_count": 0,
|
||||||
|
"avg_importance": 0.0,
|
||||||
|
"avg_duration": 0.0,
|
||||||
|
"total_tokens": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
success_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.SUCCESS)
|
||||||
|
failure_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.FAILURE)
|
||||||
|
partial_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.PARTIAL)
|
||||||
|
|
||||||
|
avg_importance = sum(e.importance_score for e in episodes) / len(episodes)
|
||||||
|
avg_duration = sum(e.duration_seconds for e in episodes) / len(episodes)
|
||||||
|
total_tokens = sum(e.tokens_used for e in episodes)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_count": len(episodes),
|
||||||
|
"success_count": success_count,
|
||||||
|
"failure_count": failure_count,
|
||||||
|
"partial_count": partial_count,
|
||||||
|
"avg_importance": avg_importance,
|
||||||
|
"avg_duration": avg_duration,
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
}
|
||||||
503
backend/app/services/memory/episodic/retrieval.py
Normal file
503
backend/app/services/memory/episodic/retrieval.py
Normal file
@@ -0,0 +1,503 @@
|
|||||||
|
# app/services/memory/episodic/retrieval.py
|
||||||
|
"""
|
||||||
|
Episode Retrieval Strategies.
|
||||||
|
|
||||||
|
Provides different retrieval strategies for finding relevant episodes:
|
||||||
|
- Semantic similarity (vector search)
|
||||||
|
- Recency-based
|
||||||
|
- Outcome-based filtering
|
||||||
|
- Importance-based ranking
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import and_, desc, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.memory.enums import EpisodeOutcome
|
||||||
|
from app.models.memory.episode import Episode as EpisodeModel
|
||||||
|
from app.services.memory.types import Episode, Outcome, RetrievalResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalStrategy(str, Enum):
|
||||||
|
"""Retrieval strategy types."""
|
||||||
|
|
||||||
|
SEMANTIC = "semantic"
|
||||||
|
RECENCY = "recency"
|
||||||
|
OUTCOME = "outcome"
|
||||||
|
IMPORTANCE = "importance"
|
||||||
|
HYBRID = "hybrid"
|
||||||
|
|
||||||
|
|
||||||
|
def _model_to_episode(model: EpisodeModel) -> Episode:
|
||||||
|
"""Convert SQLAlchemy model to Episode dataclass."""
|
||||||
|
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
|
||||||
|
# they return actual values. We use type: ignore to handle this mismatch.
|
||||||
|
return Episode(
|
||||||
|
id=model.id, # type: ignore[arg-type]
|
||||||
|
project_id=model.project_id, # type: ignore[arg-type]
|
||||||
|
agent_instance_id=model.agent_instance_id, # type: ignore[arg-type]
|
||||||
|
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
|
||||||
|
session_id=model.session_id, # type: ignore[arg-type]
|
||||||
|
task_type=model.task_type, # type: ignore[arg-type]
|
||||||
|
task_description=model.task_description, # type: ignore[arg-type]
|
||||||
|
actions=model.actions or [], # type: ignore[arg-type]
|
||||||
|
context_summary=model.context_summary, # type: ignore[arg-type]
|
||||||
|
outcome=Outcome(model.outcome.value),
|
||||||
|
outcome_details=model.outcome_details or "", # type: ignore[arg-type]
|
||||||
|
duration_seconds=model.duration_seconds, # type: ignore[arg-type]
|
||||||
|
tokens_used=model.tokens_used, # type: ignore[arg-type]
|
||||||
|
lessons_learned=model.lessons_learned or [], # type: ignore[arg-type]
|
||||||
|
importance_score=model.importance_score, # type: ignore[arg-type]
|
||||||
|
embedding=None, # Don't expose raw embedding
|
||||||
|
occurred_at=model.occurred_at, # type: ignore[arg-type]
|
||||||
|
created_at=model.created_at, # type: ignore[arg-type]
|
||||||
|
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRetriever(ABC):
|
||||||
|
"""Abstract base class for episode retrieval strategies."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
project_id: UUID,
|
||||||
|
limit: int = 10,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""Retrieve episodes based on the strategy."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class RecencyRetriever(BaseRetriever):
|
||||||
|
"""Retrieves episodes by recency (most recent first)."""
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
project_id: UUID,
|
||||||
|
limit: int = 10,
|
||||||
|
*,
|
||||||
|
since: datetime | None = None,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""Retrieve most recent episodes."""
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
query = (
|
||||||
|
select(EpisodeModel)
|
||||||
|
.where(EpisodeModel.project_id == project_id)
|
||||||
|
.order_by(desc(EpisodeModel.occurred_at))
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
if since is not None:
|
||||||
|
query = query.where(EpisodeModel.occurred_at >= since)
|
||||||
|
|
||||||
|
if agent_instance_id is not None:
|
||||||
|
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
models = list(result.scalars().all())
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||||
|
if since is not None:
|
||||||
|
count_query = count_query.where(EpisodeModel.occurred_at >= since)
|
||||||
|
count_result = await session.execute(count_query)
|
||||||
|
total_count = len(list(count_result.scalars().all()))
|
||||||
|
|
||||||
|
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
|
||||||
|
return RetrievalResult(
|
||||||
|
items=[_model_to_episode(m) for m in models],
|
||||||
|
total_count=total_count,
|
||||||
|
query="recency",
|
||||||
|
retrieval_type=RetrievalStrategy.RECENCY.value,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
metadata={"since": since.isoformat() if since else None},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OutcomeRetriever(BaseRetriever):
|
||||||
|
"""Retrieves episodes filtered by outcome."""
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
project_id: UUID,
|
||||||
|
limit: int = 10,
|
||||||
|
*,
|
||||||
|
outcome: Outcome | None = None,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""Retrieve episodes by outcome."""
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
query = (
|
||||||
|
select(EpisodeModel)
|
||||||
|
.where(EpisodeModel.project_id == project_id)
|
||||||
|
.order_by(desc(EpisodeModel.occurred_at))
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
if outcome is not None:
|
||||||
|
db_outcome = EpisodeOutcome(outcome.value)
|
||||||
|
query = query.where(EpisodeModel.outcome == db_outcome)
|
||||||
|
|
||||||
|
if agent_instance_id is not None:
|
||||||
|
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
models = list(result.scalars().all())
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||||
|
if outcome is not None:
|
||||||
|
count_query = count_query.where(
|
||||||
|
EpisodeModel.outcome == EpisodeOutcome(outcome.value)
|
||||||
|
)
|
||||||
|
count_result = await session.execute(count_query)
|
||||||
|
total_count = len(list(count_result.scalars().all()))
|
||||||
|
|
||||||
|
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
|
||||||
|
return RetrievalResult(
|
||||||
|
items=[_model_to_episode(m) for m in models],
|
||||||
|
total_count=total_count,
|
||||||
|
query=f"outcome:{outcome.value if outcome else 'all'}",
|
||||||
|
retrieval_type=RetrievalStrategy.OUTCOME.value,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
metadata={"outcome": outcome.value if outcome else None},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskTypeRetriever(BaseRetriever):
|
||||||
|
"""Retrieves episodes filtered by task type."""
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
project_id: UUID,
|
||||||
|
limit: int = 10,
|
||||||
|
*,
|
||||||
|
task_type: str | None = None,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""Retrieve episodes by task type."""
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
query = (
|
||||||
|
select(EpisodeModel)
|
||||||
|
.where(EpisodeModel.project_id == project_id)
|
||||||
|
.order_by(desc(EpisodeModel.occurred_at))
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
if task_type is not None:
|
||||||
|
query = query.where(EpisodeModel.task_type == task_type)
|
||||||
|
|
||||||
|
if agent_instance_id is not None:
|
||||||
|
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
models = list(result.scalars().all())
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||||
|
if task_type is not None:
|
||||||
|
count_query = count_query.where(EpisodeModel.task_type == task_type)
|
||||||
|
count_result = await session.execute(count_query)
|
||||||
|
total_count = len(list(count_result.scalars().all()))
|
||||||
|
|
||||||
|
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
|
||||||
|
return RetrievalResult(
|
||||||
|
items=[_model_to_episode(m) for m in models],
|
||||||
|
total_count=total_count,
|
||||||
|
query=f"task_type:{task_type or 'all'}",
|
||||||
|
retrieval_type="task_type",
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
metadata={"task_type": task_type},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImportanceRetriever(BaseRetriever):
|
||||||
|
"""Retrieves episodes ranked by importance score."""
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
project_id: UUID,
|
||||||
|
limit: int = 10,
|
||||||
|
*,
|
||||||
|
min_importance: float = 0.0,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""Retrieve episodes by importance."""
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
query = (
|
||||||
|
select(EpisodeModel)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
EpisodeModel.project_id == project_id,
|
||||||
|
EpisodeModel.importance_score >= min_importance,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.order_by(desc(EpisodeModel.importance_score))
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
if agent_instance_id is not None:
|
||||||
|
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
models = list(result.scalars().all())
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_query = select(EpisodeModel).where(
|
||||||
|
and_(
|
||||||
|
EpisodeModel.project_id == project_id,
|
||||||
|
EpisodeModel.importance_score >= min_importance,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
count_result = await session.execute(count_query)
|
||||||
|
total_count = len(list(count_result.scalars().all()))
|
||||||
|
|
||||||
|
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
|
||||||
|
return RetrievalResult(
|
||||||
|
items=[_model_to_episode(m) for m in models],
|
||||||
|
total_count=total_count,
|
||||||
|
query=f"importance>={min_importance}",
|
||||||
|
retrieval_type=RetrievalStrategy.IMPORTANCE.value,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
metadata={"min_importance": min_importance},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticRetriever(BaseRetriever):
|
||||||
|
"""Retrieves episodes by semantic similarity using vector search."""
|
||||||
|
|
||||||
|
def __init__(self, embedding_generator: Any | None = None) -> None:
|
||||||
|
"""Initialize with optional embedding generator."""
|
||||||
|
self._embedding_generator = embedding_generator
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
project_id: UUID,
|
||||||
|
limit: int = 10,
|
||||||
|
*,
|
||||||
|
query_text: str | None = None,
|
||||||
|
query_embedding: list[float] | None = None,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""Retrieve episodes by semantic similarity."""
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# If no embedding provided, fall back to recency
|
||||||
|
if query_embedding is None and query_text is None:
|
||||||
|
logger.warning(
|
||||||
|
"No query provided for semantic search, falling back to recency"
|
||||||
|
)
|
||||||
|
recency = RecencyRetriever()
|
||||||
|
fallback_result = await recency.retrieve(
|
||||||
|
session, project_id, limit, agent_instance_id=agent_instance_id
|
||||||
|
)
|
||||||
|
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
return RetrievalResult(
|
||||||
|
items=fallback_result.items,
|
||||||
|
total_count=fallback_result.total_count,
|
||||||
|
query="no_query",
|
||||||
|
retrieval_type=RetrievalStrategy.SEMANTIC.value,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
metadata={"fallback": "recency", "reason": "no_query"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate embedding if needed
|
||||||
|
embedding = query_embedding
|
||||||
|
if embedding is None and query_text is not None:
|
||||||
|
if self._embedding_generator is not None:
|
||||||
|
embedding = await self._embedding_generator.generate(query_text)
|
||||||
|
else:
|
||||||
|
logger.warning("No embedding generator, falling back to recency")
|
||||||
|
recency = RecencyRetriever()
|
||||||
|
fallback_result = await recency.retrieve(
|
||||||
|
session, project_id, limit, agent_instance_id=agent_instance_id
|
||||||
|
)
|
||||||
|
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
return RetrievalResult(
|
||||||
|
items=fallback_result.items,
|
||||||
|
total_count=fallback_result.total_count,
|
||||||
|
query=query_text,
|
||||||
|
retrieval_type=RetrievalStrategy.SEMANTIC.value,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
metadata={
|
||||||
|
"fallback": "recency",
|
||||||
|
"reason": "no_embedding_generator",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# For now, use recency if vector search not available
|
||||||
|
# TODO: Implement proper pgvector similarity search when integrated
|
||||||
|
logger.debug("Vector search not yet implemented, using recency fallback")
|
||||||
|
recency = RecencyRetriever()
|
||||||
|
result = await recency.retrieve(
|
||||||
|
session, project_id, limit, agent_instance_id=agent_instance_id
|
||||||
|
)
|
||||||
|
|
||||||
|
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
|
||||||
|
return RetrievalResult(
|
||||||
|
items=result.items,
|
||||||
|
total_count=result.total_count,
|
||||||
|
query=query_text or "embedding",
|
||||||
|
retrieval_type=RetrievalStrategy.SEMANTIC.value,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
metadata={"fallback": "recency"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeRetriever:
|
||||||
|
"""
|
||||||
|
Unified episode retrieval service.
|
||||||
|
|
||||||
|
Provides a single interface for all retrieval strategies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
embedding_generator: Any | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize retriever with database session."""
|
||||||
|
self._session = session
|
||||||
|
self._retrievers: dict[RetrievalStrategy, BaseRetriever] = {
|
||||||
|
RetrievalStrategy.RECENCY: RecencyRetriever(),
|
||||||
|
RetrievalStrategy.OUTCOME: OutcomeRetriever(),
|
||||||
|
RetrievalStrategy.IMPORTANCE: ImportanceRetriever(),
|
||||||
|
RetrievalStrategy.SEMANTIC: SemanticRetriever(embedding_generator),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
strategy: RetrievalStrategy = RetrievalStrategy.RECENCY,
|
||||||
|
limit: int = 10,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""
|
||||||
|
Retrieve episodes using the specified strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project to search within
|
||||||
|
strategy: Retrieval strategy to use
|
||||||
|
limit: Maximum number of episodes to return
|
||||||
|
**kwargs: Strategy-specific parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RetrievalResult containing matching episodes
|
||||||
|
"""
|
||||||
|
retriever = self._retrievers.get(strategy)
|
||||||
|
if retriever is None:
|
||||||
|
raise ValueError(f"Unknown retrieval strategy: {strategy}")
|
||||||
|
|
||||||
|
return await retriever.retrieve(self._session, project_id, limit, **kwargs)
|
||||||
|
|
||||||
|
async def get_recent(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
limit: int = 10,
|
||||||
|
since: datetime | None = None,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""Get recent episodes."""
|
||||||
|
return await self.retrieve(
|
||||||
|
project_id,
|
||||||
|
RetrievalStrategy.RECENCY,
|
||||||
|
limit,
|
||||||
|
since=since,
|
||||||
|
agent_instance_id=agent_instance_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_by_outcome(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
outcome: Outcome,
|
||||||
|
limit: int = 10,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""Get episodes by outcome."""
|
||||||
|
return await self.retrieve(
|
||||||
|
project_id,
|
||||||
|
RetrievalStrategy.OUTCOME,
|
||||||
|
limit,
|
||||||
|
outcome=outcome,
|
||||||
|
agent_instance_id=agent_instance_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_by_task_type(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
task_type: str,
|
||||||
|
limit: int = 10,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""Get episodes by task type."""
|
||||||
|
retriever = TaskTypeRetriever()
|
||||||
|
return await retriever.retrieve(
|
||||||
|
self._session,
|
||||||
|
project_id,
|
||||||
|
limit,
|
||||||
|
task_type=task_type,
|
||||||
|
agent_instance_id=agent_instance_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_important(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
limit: int = 10,
|
||||||
|
min_importance: float = 0.7,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""Get high-importance episodes."""
|
||||||
|
return await self.retrieve(
|
||||||
|
project_id,
|
||||||
|
RetrievalStrategy.IMPORTANCE,
|
||||||
|
limit,
|
||||||
|
min_importance=min_importance,
|
||||||
|
agent_instance_id=agent_instance_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def search_similar(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
query: str,
|
||||||
|
limit: int = 10,
|
||||||
|
agent_instance_id: UUID | None = None,
|
||||||
|
) -> RetrievalResult[Episode]:
|
||||||
|
"""Search for semantically similar episodes."""
|
||||||
|
return await self.retrieve(
|
||||||
|
project_id,
|
||||||
|
RetrievalStrategy.SEMANTIC,
|
||||||
|
limit,
|
||||||
|
query_text=query,
|
||||||
|
agent_instance_id=agent_instance_id,
|
||||||
|
)
|
||||||
222
backend/app/services/memory/exceptions.py
Normal file
222
backend/app/services/memory/exceptions.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""
|
||||||
|
Memory System Exceptions
|
||||||
|
|
||||||
|
Custom exception classes for the Agent Memory System.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryError(Exception):
|
||||||
|
"""Base exception for all memory-related errors."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
*,
|
||||||
|
memory_type: str | None = None,
|
||||||
|
scope_type: str | None = None,
|
||||||
|
scope_id: str | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
self.memory_type = memory_type
|
||||||
|
self.scope_type = scope_type
|
||||||
|
self.scope_id = scope_id
|
||||||
|
self.details = details or {}
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryNotFoundError(MemoryError):
|
||||||
|
"""Raised when a memory item is not found."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Memory not found",
|
||||||
|
*,
|
||||||
|
memory_id: UUID | str | None = None,
|
||||||
|
key: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.memory_id = memory_id
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryCapacityError(MemoryError):
|
||||||
|
"""Raised when memory capacity limits are exceeded."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Memory capacity exceeded",
|
||||||
|
*,
|
||||||
|
current_size: int = 0,
|
||||||
|
max_size: int = 0,
|
||||||
|
item_count: int = 0,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.current_size = current_size
|
||||||
|
self.max_size = max_size
|
||||||
|
self.item_count = item_count
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryExpiredError(MemoryError):
|
||||||
|
"""Raised when attempting to access expired memory."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Memory has expired",
|
||||||
|
*,
|
||||||
|
key: str | None = None,
|
||||||
|
expired_at: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.key = key
|
||||||
|
self.expired_at = expired_at
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryStorageError(MemoryError):
|
||||||
|
"""Raised when memory storage operations fail."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Memory storage operation failed",
|
||||||
|
*,
|
||||||
|
operation: str | None = None,
|
||||||
|
backend: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.operation = operation
|
||||||
|
self.backend = backend
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryConnectionError(MemoryError):
|
||||||
|
"""Raised when memory storage connection fails."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Memory connection failed",
|
||||||
|
*,
|
||||||
|
backend: str | None = None,
|
||||||
|
host: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.backend = backend
|
||||||
|
self.host = host
|
||||||
|
|
||||||
|
|
||||||
|
class MemorySerializationError(MemoryError):
|
||||||
|
"""Raised when memory serialization/deserialization fails."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Memory serialization failed",
|
||||||
|
*,
|
||||||
|
content_type: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.content_type = content_type
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryScopeError(MemoryError):
|
||||||
|
"""Raised when memory scope operations fail."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Memory scope error",
|
||||||
|
*,
|
||||||
|
requested_scope: str | None = None,
|
||||||
|
allowed_scopes: list[str] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.requested_scope = requested_scope
|
||||||
|
self.allowed_scopes = allowed_scopes or []
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryConsolidationError(MemoryError):
|
||||||
|
"""Raised when memory consolidation fails."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Memory consolidation failed",
|
||||||
|
*,
|
||||||
|
source_type: str | None = None,
|
||||||
|
target_type: str | None = None,
|
||||||
|
items_processed: int = 0,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.source_type = source_type
|
||||||
|
self.target_type = target_type
|
||||||
|
self.items_processed = items_processed
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRetrievalError(MemoryError):
|
||||||
|
"""Raised when memory retrieval fails."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Memory retrieval failed",
|
||||||
|
*,
|
||||||
|
query: str | None = None,
|
||||||
|
retrieval_type: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.query = query
|
||||||
|
self.retrieval_type = retrieval_type
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingError(MemoryError):
|
||||||
|
"""Raised when embedding generation fails."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Embedding generation failed",
|
||||||
|
*,
|
||||||
|
content_length: int = 0,
|
||||||
|
model: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.content_length = content_length
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointError(MemoryError):
|
||||||
|
"""Raised when checkpoint operations fail."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Checkpoint operation failed",
|
||||||
|
*,
|
||||||
|
checkpoint_id: str | None = None,
|
||||||
|
operation: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.checkpoint_id = checkpoint_id
|
||||||
|
self.operation = operation
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryConflictError(MemoryError):
|
||||||
|
"""Raised when there's a conflict in memory (e.g., contradictory facts)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Memory conflict detected",
|
||||||
|
*,
|
||||||
|
conflicting_ids: list[str | UUID] | None = None,
|
||||||
|
conflict_type: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, **kwargs)
|
||||||
|
self.conflicting_ids = conflicting_ids or []
|
||||||
|
self.conflict_type = conflict_type
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user