forked from cardosofelipe/pragma-stack
Compare commits
108 Commits
664415111a
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ad3d20cf2 | ||
|
|
8623eb56f5 | ||
|
|
3cb6c8d13b | ||
|
|
8e16e2645e | ||
|
|
82c3a6ba47 | ||
|
|
b6c38cac88 | ||
|
|
51404216ae | ||
|
|
3f23bc3db3 | ||
|
|
a0ec5fa2cc | ||
|
|
f262d08be2 | ||
|
|
b3f371e0a3 | ||
|
|
93cc37224c | ||
|
|
5717bffd63 | ||
|
|
9339ea30a1 | ||
|
|
79cb6bfd7b | ||
|
|
45025bb2f1 | ||
|
|
3c6b14d2bf | ||
|
|
6b21a6fadd | ||
|
|
600657adc4 | ||
|
|
c9d0d079b3 | ||
|
|
4c8f81368c | ||
|
|
efbe91ce14 | ||
|
|
5d646779c9 | ||
|
|
5a4d93df26 | ||
|
|
7ef217be39 | ||
|
|
20159c5865 | ||
|
|
f9a72fcb34 | ||
|
|
fcb0a5f86a | ||
|
|
92782bcb05 | ||
|
|
1dcf99ee38 | ||
|
|
70009676a3 | ||
|
|
192237e69b | ||
|
|
3edce9cd26 | ||
|
|
35aea2d73a | ||
|
|
d0f32d04f7 | ||
|
|
da85a8aba8 | ||
|
|
f8bd1011e9 | ||
|
|
f057c2f0b6 | ||
|
|
33ec889fc4 | ||
|
|
74b8c65741 | ||
|
|
b232298c61 | ||
|
|
cf6291ac8e | ||
|
|
e3fe0439fd | ||
|
|
57680c3772 | ||
|
|
997cfaa03a | ||
|
|
6954774e36 | ||
|
|
30e5c68304 | ||
|
|
0b24d4c6cc | ||
|
|
1670e05e0d | ||
|
|
999b7ac03f | ||
|
|
48ecb40f18 | ||
|
|
b818f17418 | ||
|
|
e946787a61 | ||
|
|
3554efe66a | ||
|
|
bd988f76b0 | ||
|
|
4974233169 | ||
|
|
c9d8c0835c | ||
|
|
085a748929 | ||
|
|
4b149b8a52 | ||
|
|
ad0c06851d | ||
|
|
49359b1416 | ||
|
|
911d950c15 | ||
|
|
b2a3ac60e0 | ||
|
|
dea092e1bb | ||
|
|
4154dd5268 | ||
|
|
db12937495 | ||
|
|
81e1456631 | ||
|
|
58e78d8700 | ||
|
|
5e80139afa | ||
|
|
60ebeaa582 | ||
|
|
758052dcff | ||
|
|
1628eacf2b | ||
|
|
2bea057fb1 | ||
|
|
9e54f16e56 | ||
|
|
96e6400bd8 | ||
|
|
6c7b72f130 | ||
|
|
027ebfc332 | ||
|
|
c2466ab401 | ||
|
|
7828d35e06 | ||
|
|
6b07e62f00 | ||
|
|
0d2005ddcb | ||
|
|
dfa75e682e | ||
|
|
22ecb5e989 | ||
|
|
2ab69f8561 | ||
|
|
95342cc94d | ||
|
|
f6194b3e19 | ||
|
|
6bb376a336 | ||
|
|
cd7a9ccbdf | ||
|
|
953af52d0e | ||
|
|
e6e98d4ed1 | ||
|
|
ca5f5e3383 | ||
|
|
d0fc7f37ff | ||
|
|
18d717e996 | ||
|
|
f482559e15 | ||
|
|
6e8b0b022a | ||
|
|
746fb7b181 | ||
|
|
caf283bed2 | ||
|
|
520c06175e | ||
|
|
065e43c5a9 | ||
|
|
c8b88dadc3 | ||
|
|
015f2de6c6 | ||
|
|
f36bfb3781 | ||
|
|
ef659cd72d | ||
|
|
728edd1453 | ||
|
|
498c0a0e94 | ||
|
|
e5975fa5d0 | ||
|
|
731a188a76 | ||
|
|
fe2104822e |
61
.githooks/pre-commit
Executable file
61
.githooks/pre-commit
Executable file
@@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
# Pre-commit hook to enforce validation before commits on protected branches
|
||||
# Install: git config core.hooksPath .githooks
|
||||
|
||||
set -e
|
||||
|
||||
# Get the current branch name
|
||||
BRANCH=$(git rev-parse --abbrev-ref HEAD)
|
||||
|
||||
# Protected branches that require validation
|
||||
PROTECTED_BRANCHES="main dev"
|
||||
|
||||
# Check if we're on a protected branch
|
||||
is_protected() {
|
||||
for branch in $PROTECTED_BRANCHES; do
|
||||
if [ "$BRANCH" = "$branch" ]; then
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
return 1
|
||||
}
|
||||
|
||||
if is_protected; then
|
||||
echo "🔒 Committing to protected branch '$BRANCH' - running validation..."
|
||||
|
||||
# Check if we have backend changes
|
||||
if git diff --cached --name-only | grep -q "^backend/"; then
|
||||
echo "📦 Backend changes detected - running make validate..."
|
||||
cd backend
|
||||
if ! make validate; then
|
||||
echo ""
|
||||
echo "❌ Backend validation failed!"
|
||||
echo " Please fix the issues and try again."
|
||||
echo " Run 'cd backend && make validate' to see errors."
|
||||
exit 1
|
||||
fi
|
||||
cd ..
|
||||
echo "✅ Backend validation passed!"
|
||||
fi
|
||||
|
||||
# Check if we have frontend changes
|
||||
if git diff --cached --name-only | grep -q "^frontend/"; then
|
||||
echo "🎨 Frontend changes detected - running npm run validate..."
|
||||
cd frontend
|
||||
if ! npm run validate 2>/dev/null; then
|
||||
echo ""
|
||||
echo "❌ Frontend validation failed!"
|
||||
echo " Please fix the issues and try again."
|
||||
echo " Run 'cd frontend && npm run validate' to see errors."
|
||||
exit 1
|
||||
fi
|
||||
cd ..
|
||||
echo "✅ Frontend validation passed!"
|
||||
fi
|
||||
|
||||
echo "🎉 All validations passed! Proceeding with commit..."
|
||||
else
|
||||
echo "📝 Committing to feature branch '$BRANCH' - skipping validation (run manually if needed)"
|
||||
fi
|
||||
|
||||
exit 0
|
||||
31
CLAUDE.md
31
CLAUDE.md
@@ -83,6 +83,37 @@ docs/
|
||||
3. **Testing Required**: All code must be tested, aim for >90% coverage
|
||||
4. **Code Review**: Must pass multi-agent review before merge
|
||||
5. **No Direct Commits**: Never commit directly to `main` or `dev`
|
||||
6. **Stack Verification**: ALWAYS run the full stack before considering work done (see below)
|
||||
|
||||
### CRITICAL: Stack Verification Before Merge
|
||||
|
||||
**This is NON-NEGOTIABLE. A feature with 100% test coverage that crashes on startup is WORTHLESS.**
|
||||
|
||||
Before considering ANY issue complete:
|
||||
|
||||
```bash
|
||||
# 1. Start the dev stack
|
||||
make dev
|
||||
|
||||
# 2. Wait for backend to be healthy, check logs
|
||||
docker compose -f docker-compose.dev.yml logs backend --tail=100
|
||||
|
||||
# 3. Start frontend
|
||||
cd frontend && npm run dev
|
||||
|
||||
# 4. Verify both are running without errors
|
||||
```
|
||||
|
||||
**The issue is NOT done if:**
|
||||
- Backend crashes on startup (import errors, missing dependencies)
|
||||
- Frontend fails to compile or render
|
||||
- Health checks fail
|
||||
- Any error appears in logs
|
||||
|
||||
**Why this matters:**
|
||||
- Tests run in isolation and may pass despite broken imports
|
||||
- Docker builds cache layers and may hide dependency issues
|
||||
- A single `ModuleNotFoundError` renders all test coverage meaningless
|
||||
|
||||
### Common Commands
|
||||
|
||||
|
||||
110
Makefile
110
Makefile
@@ -1,18 +1,34 @@
|
||||
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
|
||||
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all format-all
|
||||
|
||||
VERSION ?= latest
|
||||
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "FastAPI + Next.js Full-Stack Template"
|
||||
@echo "Syndarix - AI-Powered Software Consulting Agency"
|
||||
@echo ""
|
||||
@echo "Development:"
|
||||
@echo " make dev - Start backend + db (frontend runs separately)"
|
||||
@echo " make dev - Start backend + db + MCP servers (frontend runs separately)"
|
||||
@echo " make dev-full - Start all services including frontend"
|
||||
@echo " make down - Stop all services"
|
||||
@echo " make logs-dev - Follow dev container logs"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@echo " make test - Run all tests (backend + MCP servers)"
|
||||
@echo " make test-backend - Run backend tests only"
|
||||
@echo " make test-mcp - Run MCP server tests only"
|
||||
@echo " make test-frontend - Run frontend tests only"
|
||||
@echo " make test-cov - Run all tests with coverage reports"
|
||||
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||
@echo ""
|
||||
@echo "Formatting:"
|
||||
@echo " make format-all - Format code in backend + MCP servers + frontend"
|
||||
@echo ""
|
||||
@echo "Validation:"
|
||||
@echo " make validate - Validate backend + MCP servers (lint, type-check, test)"
|
||||
@echo " make validate-all - Validate everything including frontend"
|
||||
@echo ""
|
||||
@echo "Database:"
|
||||
@echo " make drop-db - Drop and recreate empty database"
|
||||
@echo " make reset-db - Drop database and apply all migrations"
|
||||
@@ -29,6 +45,8 @@ help:
|
||||
@echo ""
|
||||
@echo "Subdirectory commands:"
|
||||
@echo " cd backend && make help - Backend-specific commands"
|
||||
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
|
||||
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
|
||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||
|
||||
# ============================================================================
|
||||
@@ -99,3 +117,91 @@ clean:
|
||||
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
|
||||
clean-slate:
|
||||
docker compose -f docker-compose.dev.yml down -v --remove-orphans
|
||||
|
||||
# ============================================================================
|
||||
# Testing
|
||||
# ============================================================================
|
||||
|
||||
test: test-backend test-mcp
|
||||
@echo ""
|
||||
@echo "All tests passed!"
|
||||
|
||||
test-backend:
|
||||
@echo "Running backend tests..."
|
||||
@cd backend && IS_TEST=True uv run pytest tests/ -v
|
||||
|
||||
test-mcp:
|
||||
@echo "Running MCP server tests..."
|
||||
@echo ""
|
||||
@echo "=== LLM Gateway ==="
|
||||
@cd mcp-servers/llm-gateway && uv run pytest tests/ -v
|
||||
@echo ""
|
||||
@echo "=== Knowledge Base ==="
|
||||
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v
|
||||
|
||||
test-frontend:
|
||||
@echo "Running frontend tests..."
|
||||
@cd frontend && npm test
|
||||
|
||||
test-all: test test-frontend
|
||||
@echo ""
|
||||
@echo "All tests (backend + MCP + frontend) passed!"
|
||||
|
||||
test-cov:
|
||||
@echo "Running all tests with coverage..."
|
||||
@echo ""
|
||||
@echo "=== Backend Coverage ==="
|
||||
@cd backend && IS_TEST=True uv run pytest tests/ -v --cov=app --cov-report=term-missing
|
||||
@echo ""
|
||||
@echo "=== LLM Gateway Coverage ==="
|
||||
@cd mcp-servers/llm-gateway && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||
@echo ""
|
||||
@echo "=== Knowledge Base Coverage ==="
|
||||
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||
|
||||
test-integration:
|
||||
@echo "Running MCP integration tests..."
|
||||
@echo "Note: Requires running stack (make dev first)"
|
||||
@cd backend && RUN_INTEGRATION_TESTS=true IS_TEST=True uv run pytest tests/integration/ -v
|
||||
|
||||
# ============================================================================
|
||||
# Formatting
|
||||
# ============================================================================
|
||||
|
||||
format-all:
|
||||
@echo "Formatting backend..."
|
||||
@cd backend && make format
|
||||
@echo ""
|
||||
@echo "Formatting LLM Gateway..."
|
||||
@cd mcp-servers/llm-gateway && make format
|
||||
@echo ""
|
||||
@echo "Formatting Knowledge Base..."
|
||||
@cd mcp-servers/knowledge-base && make format
|
||||
@echo ""
|
||||
@echo "Formatting frontend..."
|
||||
@cd frontend && npm run format
|
||||
@echo ""
|
||||
@echo "All code formatted!"
|
||||
|
||||
# ============================================================================
|
||||
# Validation (lint + type-check + test)
|
||||
# ============================================================================
|
||||
|
||||
validate:
|
||||
@echo "Validating backend..."
|
||||
@cd backend && make validate
|
||||
@echo ""
|
||||
@echo "Validating LLM Gateway..."
|
||||
@cd mcp-servers/llm-gateway && make validate
|
||||
@echo ""
|
||||
@echo "Validating Knowledge Base..."
|
||||
@cd mcp-servers/knowledge-base && make validate
|
||||
@echo ""
|
||||
@echo "All validations passed!"
|
||||
|
||||
validate-all: validate
|
||||
@echo ""
|
||||
@echo "Validating frontend..."
|
||||
@cd frontend && npm run validate
|
||||
@echo ""
|
||||
@echo "Full validation passed!"
|
||||
|
||||
@@ -7,7 +7,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONPATH=/app \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
UV_NO_CACHE=1 \
|
||||
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||
VIRTUAL_ENV=/opt/venv \
|
||||
PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Install system dependencies and uv
|
||||
RUN apt-get update && \
|
||||
@@ -20,7 +23,7 @@ RUN apt-get update && \
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
# Install dependencies using uv (development mode with dev dependencies)
|
||||
# Install dependencies using uv into /opt/venv (outside /app to survive bind mounts)
|
||||
RUN uv sync --extra dev --frozen
|
||||
|
||||
# Copy application code
|
||||
@@ -45,7 +48,10 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONPATH=/app \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
UV_NO_CACHE=1 \
|
||||
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||
VIRTUAL_ENV=/opt/venv \
|
||||
PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Install system dependencies and uv
|
||||
RUN apt-get update && \
|
||||
@@ -58,7 +64,7 @@ RUN apt-get update && \
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
# Install only production dependencies using uv (no dev dependencies)
|
||||
# Install only production dependencies using uv into /opt/venv
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
# Copy application code
|
||||
@@ -67,7 +73,7 @@ COPY entrypoint.sh /usr/local/bin/
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
# Set ownership to non-root user
|
||||
RUN chown -R appuser:appuser /app
|
||||
RUN chown -R appuser:appuser /app /opt/venv
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
@@ -77,4 +83,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all
|
||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all test-integration
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@@ -22,6 +22,7 @@ help:
|
||||
@echo " make test-cov - Run pytest with coverage report"
|
||||
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
|
||||
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
|
||||
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||
@echo " make test-all - Run all tests (unit + E2E)"
|
||||
@echo " make check-docker - Check if Docker is available"
|
||||
@echo ""
|
||||
@@ -79,9 +80,18 @@ test:
|
||||
|
||||
test-cov:
|
||||
@echo "🧪 Running tests with coverage..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 20
|
||||
@echo "📊 Coverage report generated in htmlcov/index.html"
|
||||
|
||||
# ============================================================================
|
||||
# Integration Testing (requires running stack: make dev)
|
||||
# ============================================================================
|
||||
|
||||
test-integration:
|
||||
@echo "🧪 Running MCP integration tests..."
|
||||
@echo "Note: Requires running stack (make dev from project root)"
|
||||
@RUN_INTEGRATION_TESTS=true IS_TEST=True PYTHONPATH=. uv run pytest tests/integration/ -v
|
||||
|
||||
# ============================================================================
|
||||
# E2E Testing (requires Docker)
|
||||
# ============================================================================
|
||||
|
||||
512
backend/app/alembic/versions/0005_add_memory_system_tables.py
Normal file
512
backend/app/alembic/versions/0005_add_memory_system_tables.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""Add Agent Memory System tables
|
||||
|
||||
Revision ID: 0005
|
||||
Revises: 0004
|
||||
Create Date: 2025-01-05
|
||||
|
||||
This migration creates the Agent Memory System tables:
|
||||
- working_memory: Key-value storage with TTL for active sessions
|
||||
- episodes: Experiential memories from task executions
|
||||
- facts: Semantic knowledge triples with confidence scores
|
||||
- procedures: Learned skills and procedures
|
||||
- memory_consolidation_log: Tracks consolidation jobs
|
||||
|
||||
See Issue #88: Database Schema & Storage Layer
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0005"
|
||||
down_revision: str | None = "0004"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create Agent Memory System tables."""
|
||||
|
||||
# =========================================================================
|
||||
# Create ENUM types for memory system
|
||||
# =========================================================================
|
||||
|
||||
# Scope type enum
|
||||
scope_type_enum = postgresql.ENUM(
|
||||
"global",
|
||||
"project",
|
||||
"agent_type",
|
||||
"agent_instance",
|
||||
"session",
|
||||
name="scope_type",
|
||||
create_type=False,
|
||||
)
|
||||
scope_type_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Episode outcome enum
|
||||
episode_outcome_enum = postgresql.ENUM(
|
||||
"success",
|
||||
"failure",
|
||||
"partial",
|
||||
name="episode_outcome",
|
||||
create_type=False,
|
||||
)
|
||||
episode_outcome_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Consolidation type enum
|
||||
consolidation_type_enum = postgresql.ENUM(
|
||||
"working_to_episodic",
|
||||
"episodic_to_semantic",
|
||||
"episodic_to_procedural",
|
||||
"pruning",
|
||||
name="consolidation_type",
|
||||
create_type=False,
|
||||
)
|
||||
consolidation_type_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Consolidation status enum
|
||||
consolidation_status_enum = postgresql.ENUM(
|
||||
"pending",
|
||||
"running",
|
||||
"completed",
|
||||
"failed",
|
||||
name="consolidation_status",
|
||||
create_type=False,
|
||||
)
|
||||
consolidation_status_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# =========================================================================
|
||||
# Create working_memory table
|
||||
# Key-value storage with TTL for active sessions
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"working_memory",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"scope_type",
|
||||
scope_type_enum,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("scope_id", sa.String(255), nullable=False),
|
||||
sa.Column("key", sa.String(255), nullable=False),
|
||||
sa.Column("value", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Working memory indexes
|
||||
op.create_index(
|
||||
"ix_working_memory_scope_type",
|
||||
"working_memory",
|
||||
["scope_type"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_working_memory_scope_id",
|
||||
"working_memory",
|
||||
["scope_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_working_memory_scope_key",
|
||||
"working_memory",
|
||||
["scope_type", "scope_id", "key"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_working_memory_expires",
|
||||
"working_memory",
|
||||
["expires_at"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_working_memory_scope_list",
|
||||
"working_memory",
|
||||
["scope_type", "scope_id"],
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create episodes table
|
||||
# Experiential memories from task executions
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"episodes",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("agent_instance_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("session_id", sa.String(255), nullable=False),
|
||||
sa.Column("task_type", sa.String(100), nullable=False),
|
||||
sa.Column("task_description", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"actions",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
sa.Column("context_summary", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"outcome",
|
||||
episode_outcome_enum,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("outcome_details", sa.Text(), nullable=True),
|
||||
sa.Column("duration_seconds", sa.Float(), nullable=False, server_default="0.0"),
|
||||
sa.Column("tokens_used", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"lessons_learned",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
sa.Column("importance_score", sa.Float(), nullable=False, server_default="0.5"),
|
||||
# Vector embedding - using TEXT as fallback, will be VECTOR(1536) when pgvector is available
|
||||
sa.Column("embedding", sa.Text(), nullable=True),
|
||||
sa.Column("occurred_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["project_id"],
|
||||
["projects.id"],
|
||||
name="fk_episodes_project",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_instance_id"],
|
||||
["agent_instances.id"],
|
||||
name="fk_episodes_agent_instance",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_type_id"],
|
||||
["agent_types.id"],
|
||||
name="fk_episodes_agent_type",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
)
|
||||
|
||||
# Episode indexes
|
||||
op.create_index("ix_episodes_project_id", "episodes", ["project_id"])
|
||||
op.create_index("ix_episodes_agent_instance_id", "episodes", ["agent_instance_id"])
|
||||
op.create_index("ix_episodes_agent_type_id", "episodes", ["agent_type_id"])
|
||||
op.create_index("ix_episodes_session_id", "episodes", ["session_id"])
|
||||
op.create_index("ix_episodes_task_type", "episodes", ["task_type"])
|
||||
op.create_index("ix_episodes_outcome", "episodes", ["outcome"])
|
||||
op.create_index("ix_episodes_importance_score", "episodes", ["importance_score"])
|
||||
op.create_index("ix_episodes_occurred_at", "episodes", ["occurred_at"])
|
||||
op.create_index("ix_episodes_project_task", "episodes", ["project_id", "task_type"])
|
||||
op.create_index(
|
||||
"ix_episodes_project_outcome", "episodes", ["project_id", "outcome"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_episodes_agent_task", "episodes", ["agent_instance_id", "task_type"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_episodes_project_time", "episodes", ["project_id", "occurred_at"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_episodes_importance_time",
|
||||
"episodes",
|
||||
["importance_score", "occurred_at"],
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create facts table
|
||||
# Semantic knowledge triples with confidence scores
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"facts",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"project_id", postgresql.UUID(as_uuid=True), nullable=True
|
||||
), # NULL for global facts
|
||||
sa.Column("subject", sa.String(500), nullable=False),
|
||||
sa.Column("predicate", sa.String(255), nullable=False),
|
||||
sa.Column("object", sa.Text(), nullable=False),
|
||||
sa.Column("confidence", sa.Float(), nullable=False, server_default="0.8"),
|
||||
# Source episode IDs stored as JSON array of UUID strings for cross-db compatibility
|
||||
sa.Column(
|
||||
"source_episode_ids",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
sa.Column("first_learned", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("last_reinforced", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"reinforcement_count", sa.Integer(), nullable=False, server_default="1"
|
||||
),
|
||||
# Vector embedding
|
||||
sa.Column("embedding", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["project_id"],
|
||||
["projects.id"],
|
||||
name="fk_facts_project",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
|
||||
# Fact indexes
|
||||
op.create_index("ix_facts_project_id", "facts", ["project_id"])
|
||||
op.create_index("ix_facts_subject", "facts", ["subject"])
|
||||
op.create_index("ix_facts_predicate", "facts", ["predicate"])
|
||||
op.create_index("ix_facts_confidence", "facts", ["confidence"])
|
||||
op.create_index("ix_facts_subject_predicate", "facts", ["subject", "predicate"])
|
||||
op.create_index("ix_facts_project_subject", "facts", ["project_id", "subject"])
|
||||
op.create_index(
|
||||
"ix_facts_confidence_time", "facts", ["confidence", "last_reinforced"]
|
||||
)
|
||||
# Unique constraint for triples within project scope
|
||||
op.create_index(
|
||||
"ix_facts_unique_triple",
|
||||
"facts",
|
||||
["project_id", "subject", "predicate", "object"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("project_id IS NOT NULL"),
|
||||
)
|
||||
# Unique constraint for global facts (project_id IS NULL)
|
||||
op.create_index(
|
||||
"ix_facts_unique_triple_global",
|
||||
"facts",
|
||||
["subject", "predicate", "object"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("project_id IS NULL"),
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create procedures table
|
||||
# Learned skills and procedures
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"procedures",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("trigger_pattern", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"steps",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
sa.Column("success_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("failure_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("last_used", sa.DateTime(timezone=True), nullable=True),
|
||||
# Vector embedding
|
||||
sa.Column("embedding", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["project_id"],
|
||||
["projects.id"],
|
||||
name="fk_procedures_project",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_type_id"],
|
||||
["agent_types.id"],
|
||||
name="fk_procedures_agent_type",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
)
|
||||
|
||||
# Procedure indexes
|
||||
op.create_index("ix_procedures_project_id", "procedures", ["project_id"])
|
||||
op.create_index("ix_procedures_agent_type_id", "procedures", ["agent_type_id"])
|
||||
op.create_index("ix_procedures_name", "procedures", ["name"])
|
||||
op.create_index("ix_procedures_last_used", "procedures", ["last_used"])
|
||||
op.create_index(
|
||||
"ix_procedures_unique_name",
|
||||
"procedures",
|
||||
["project_id", "agent_type_id", "name"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index("ix_procedures_project_name", "procedures", ["project_id", "name"])
|
||||
# Note: agent_type_id already indexed via ix_procedures_agent_type_id (line 354)
|
||||
op.create_index(
|
||||
"ix_procedures_success_rate",
|
||||
"procedures",
|
||||
["success_count", "failure_count"],
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Add check constraints for data integrity
|
||||
# =========================================================================
|
||||
|
||||
# Episode constraints
|
||||
op.create_check_constraint(
|
||||
"ck_episodes_importance_range",
|
||||
"episodes",
|
||||
"importance_score >= 0.0 AND importance_score <= 1.0",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_episodes_duration_positive",
|
||||
"episodes",
|
||||
"duration_seconds >= 0.0",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_episodes_tokens_positive",
|
||||
"episodes",
|
||||
"tokens_used >= 0",
|
||||
)
|
||||
|
||||
# Fact constraints
|
||||
op.create_check_constraint(
|
||||
"ck_facts_confidence_range",
|
||||
"facts",
|
||||
"confidence >= 0.0 AND confidence <= 1.0",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_facts_reinforcement_positive",
|
||||
"facts",
|
||||
"reinforcement_count >= 1",
|
||||
)
|
||||
|
||||
# Procedure constraints
|
||||
op.create_check_constraint(
|
||||
"ck_procedures_success_positive",
|
||||
"procedures",
|
||||
"success_count >= 0",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_procedures_failure_positive",
|
||||
"procedures",
|
||||
"failure_count >= 0",
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create memory_consolidation_log table
|
||||
# Tracks consolidation jobs
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"memory_consolidation_log",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"consolidation_type",
|
||||
consolidation_type_enum,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("source_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("result_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("started_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
consolidation_status_enum,
|
||||
nullable=False,
|
||||
server_default="pending",
|
||||
),
|
||||
sa.Column("error", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Consolidation log indexes
|
||||
op.create_index(
|
||||
"ix_consolidation_type",
|
||||
"memory_consolidation_log",
|
||||
["consolidation_type"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_consolidation_status",
|
||||
"memory_consolidation_log",
|
||||
["status"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_consolidation_type_status",
|
||||
"memory_consolidation_log",
|
||||
["consolidation_type", "status"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_consolidation_started",
|
||||
"memory_consolidation_log",
|
||||
["started_at"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop Agent Memory System tables."""
|
||||
|
||||
# Drop check constraints first
|
||||
op.drop_constraint("ck_procedures_failure_positive", "procedures", type_="check")
|
||||
op.drop_constraint("ck_procedures_success_positive", "procedures", type_="check")
|
||||
op.drop_constraint("ck_facts_reinforcement_positive", "facts", type_="check")
|
||||
op.drop_constraint("ck_facts_confidence_range", "facts", type_="check")
|
||||
op.drop_constraint("ck_episodes_tokens_positive", "episodes", type_="check")
|
||||
op.drop_constraint("ck_episodes_duration_positive", "episodes", type_="check")
|
||||
op.drop_constraint("ck_episodes_importance_range", "episodes", type_="check")
|
||||
|
||||
# Drop unique indexes for global facts
|
||||
op.drop_index("ix_facts_unique_triple_global", "facts")
|
||||
|
||||
# Drop tables in reverse order (dependencies first)
|
||||
op.drop_table("memory_consolidation_log")
|
||||
op.drop_table("procedures")
|
||||
op.drop_table("facts")
|
||||
op.drop_table("episodes")
|
||||
op.drop_table("working_memory")
|
||||
|
||||
# Drop ENUM types
|
||||
op.execute("DROP TYPE IF EXISTS consolidation_status")
|
||||
op.execute("DROP TYPE IF EXISTS consolidation_type")
|
||||
op.execute("DROP TYPE IF EXISTS episode_outcome")
|
||||
op.execute("DROP TYPE IF EXISTS scope_type")
|
||||
52
backend/app/alembic/versions/0006_add_abandoned_outcome.py
Normal file
52
backend/app/alembic/versions/0006_add_abandoned_outcome.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Add ABANDONED to episode_outcome enum
|
||||
|
||||
Revision ID: 0006
|
||||
Revises: 0005
|
||||
Create Date: 2025-01-06
|
||||
|
||||
This migration adds the 'abandoned' value to the episode_outcome enum type.
|
||||
This allows episodes to track when a task was abandoned (not completed,
|
||||
but not necessarily a failure either - e.g., user cancelled, session timeout).
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0006"
|
||||
down_revision: str | None = "0005"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add 'abandoned' value to episode_outcome enum."""
|
||||
# PostgreSQL ALTER TYPE ADD VALUE is safe and non-blocking
|
||||
op.execute("ALTER TYPE episode_outcome ADD VALUE IF NOT EXISTS 'abandoned'")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove 'abandoned' from episode_outcome enum.
|
||||
|
||||
Note: PostgreSQL doesn't support removing values from enums directly.
|
||||
This downgrade converts any 'abandoned' episodes to 'failure' and
|
||||
recreates the enum without 'abandoned'.
|
||||
"""
|
||||
# Convert any abandoned episodes to failure first
|
||||
op.execute("""
|
||||
UPDATE episodes
|
||||
SET outcome = 'failure'
|
||||
WHERE outcome = 'abandoned'
|
||||
""")
|
||||
|
||||
# Recreate the enum without abandoned
|
||||
# This is complex in PostgreSQL - requires creating new type, updating columns, dropping old
|
||||
op.execute("ALTER TYPE episode_outcome RENAME TO episode_outcome_old")
|
||||
op.execute("CREATE TYPE episode_outcome AS ENUM ('success', 'failure', 'partial')")
|
||||
op.execute("""
|
||||
ALTER TABLE episodes
|
||||
ALTER COLUMN outcome TYPE episode_outcome
|
||||
USING outcome::text::episode_outcome
|
||||
""")
|
||||
op.execute("DROP TYPE episode_outcome_old")
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Add category and display fields to agent_types table
|
||||
|
||||
Revision ID: 0007
|
||||
Revises: 0006
|
||||
Create Date: 2026-01-06
|
||||
|
||||
This migration adds:
|
||||
- category: String(50) for grouping agents by role type
|
||||
- icon: String(50) for Lucide icon identifier
|
||||
- color: String(7) for hex color code
|
||||
- sort_order: Integer for display ordering within categories
|
||||
- typical_tasks: JSONB list of tasks this agent excels at
|
||||
- collaboration_hints: JSONB list of agent slugs that work well together
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0007"
|
||||
down_revision: str | None = "0006"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add category and display fields to agent_types table."""
|
||||
# Add new columns
|
||||
op.add_column(
|
||||
"agent_types",
|
||||
sa.Column("category", sa.String(length=50), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"agent_types",
|
||||
sa.Column("icon", sa.String(length=50), nullable=True, server_default="bot"),
|
||||
)
|
||||
op.add_column(
|
||||
"agent_types",
|
||||
sa.Column(
|
||||
"color", sa.String(length=7), nullable=True, server_default="#3B82F6"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"agent_types",
|
||||
sa.Column("sort_order", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
op.add_column(
|
||||
"agent_types",
|
||||
sa.Column(
|
||||
"typical_tasks",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"agent_types",
|
||||
sa.Column(
|
||||
"collaboration_hints",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
)
|
||||
|
||||
# Add indexes for category and sort_order
|
||||
op.create_index("ix_agent_types_category", "agent_types", ["category"])
|
||||
op.create_index("ix_agent_types_sort_order", "agent_types", ["sort_order"])
|
||||
op.create_index(
|
||||
"ix_agent_types_category_sort", "agent_types", ["category", "sort_order"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove category and display fields from agent_types table."""
|
||||
# Drop indexes
|
||||
op.drop_index("ix_agent_types_category_sort", table_name="agent_types")
|
||||
op.drop_index("ix_agent_types_sort_order", table_name="agent_types")
|
||||
op.drop_index("ix_agent_types_category", table_name="agent_types")
|
||||
|
||||
# Drop columns
|
||||
op.drop_column("agent_types", "collaboration_hints")
|
||||
op.drop_column("agent_types", "typical_tasks")
|
||||
op.drop_column("agent_types", "sort_order")
|
||||
op.drop_column("agent_types", "color")
|
||||
op.drop_column("agent_types", "icon")
|
||||
op.drop_column("agent_types", "category")
|
||||
@@ -5,8 +5,10 @@ from app.api.routes import (
|
||||
agent_types,
|
||||
agents,
|
||||
auth,
|
||||
context,
|
||||
events,
|
||||
issues,
|
||||
mcp,
|
||||
oauth,
|
||||
oauth_provider,
|
||||
organizations,
|
||||
@@ -31,6 +33,12 @@ api_router.include_router(
|
||||
# SSE events router - no prefix, routes define full paths
|
||||
api_router.include_router(events.router, tags=["Events"])
|
||||
|
||||
# MCP (Model Context Protocol) router
|
||||
api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"])
|
||||
|
||||
# Context Management Engine router
|
||||
api_router.include_router(context.router, prefix="/context", tags=["Context"])
|
||||
|
||||
# Syndarix domain routers
|
||||
api_router.include_router(projects.router, prefix="/projects", tags=["Projects"])
|
||||
api_router.include_router(
|
||||
|
||||
@@ -81,6 +81,13 @@ def _build_agent_type_response(
|
||||
mcp_servers=agent_type.mcp_servers,
|
||||
tool_permissions=agent_type.tool_permissions,
|
||||
is_active=agent_type.is_active,
|
||||
# Category and display fields
|
||||
category=agent_type.category,
|
||||
icon=agent_type.icon,
|
||||
color=agent_type.color,
|
||||
sort_order=agent_type.sort_order,
|
||||
typical_tasks=agent_type.typical_tasks or [],
|
||||
collaboration_hints=agent_type.collaboration_hints or [],
|
||||
created_at=agent_type.created_at,
|
||||
updated_at=agent_type.updated_at,
|
||||
instance_count=instance_count,
|
||||
@@ -300,6 +307,7 @@ async def list_agent_types(
|
||||
request: Request,
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: bool = Query(True, description="Filter by active status"),
|
||||
category: str | None = Query(None, description="Filter by category"),
|
||||
search: str | None = Query(None, description="Search by name, slug, description"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -314,6 +322,7 @@ async def list_agent_types(
|
||||
request: FastAPI request object
|
||||
pagination: Pagination parameters (page, limit)
|
||||
is_active: Filter by active status (default: True)
|
||||
category: Filter by category (e.g., "development", "design")
|
||||
search: Optional search term for name, slug, description
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
@@ -328,6 +337,7 @@ async def list_agent_types(
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
is_active=is_active,
|
||||
category=category,
|
||||
search=search,
|
||||
)
|
||||
|
||||
@@ -354,6 +364,51 @@ async def list_agent_types(
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/grouped",
|
||||
response_model=dict[str, list[AgentTypeResponse]],
|
||||
summary="List Agent Types Grouped by Category",
|
||||
description="Get all agent types organized by category",
|
||||
operation_id="list_agent_types_grouped",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def list_agent_types_grouped(
|
||||
request: Request,
|
||||
is_active: bool = Query(True, description="Filter by active status"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get agent types grouped by category.
|
||||
|
||||
Returns a dictionary where keys are category names and values
|
||||
are lists of agent types, sorted by sort_order within each category.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
is_active: Filter by active status (default: True)
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary mapping category to list of agent types
|
||||
"""
|
||||
try:
|
||||
grouped = await agent_type_crud.get_grouped_by_category(db, is_active=is_active)
|
||||
|
||||
# Transform to response objects
|
||||
result: dict[str, list[AgentTypeResponse]] = {}
|
||||
for category, types in grouped.items():
|
||||
result[category] = [
|
||||
_build_agent_type_response(t, instance_count=0) for t in types
|
||||
]
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting grouped agent types: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{agent_type_id}",
|
||||
response_model=AgentTypeResponse,
|
||||
|
||||
411
backend/app/api/routes/context.py
Normal file
411
backend/app/api/routes/context.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""
|
||||
Context Management API Endpoints.
|
||||
|
||||
Provides REST endpoints for context assembly and optimization
|
||||
for LLM requests using the ContextEngine.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.models.user import User
|
||||
from app.services.context import (
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
ContextEngine,
|
||||
ContextSettings,
|
||||
create_context_engine,
|
||||
get_context_settings,
|
||||
)
|
||||
from app.services.mcp import MCPClientManager, get_mcp_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Singleton Engine Management
|
||||
# ============================================================================
|
||||
|
||||
_context_engine: ContextEngine | None = None
|
||||
|
||||
|
||||
def _get_or_create_engine(
|
||||
mcp: MCPClientManager,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> ContextEngine:
|
||||
"""Get or create the singleton ContextEngine."""
|
||||
global _context_engine
|
||||
if _context_engine is None:
|
||||
_context_engine = create_context_engine(
|
||||
mcp_manager=mcp,
|
||||
redis=None, # Optional: add Redis caching later
|
||||
settings=settings or get_context_settings(),
|
||||
)
|
||||
logger.info("ContextEngine initialized")
|
||||
else:
|
||||
# Ensure MCP manager is up to date
|
||||
_context_engine.set_mcp_manager(mcp)
|
||||
return _context_engine
|
||||
|
||||
|
||||
async def get_context_engine(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ContextEngine:
|
||||
"""FastAPI dependency to get the ContextEngine."""
|
||||
return _get_or_create_engine(mcp)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ConversationTurn(BaseModel):
|
||||
"""A single conversation turn."""
|
||||
|
||||
role: str = Field(..., description="Role: 'user' or 'assistant'")
|
||||
content: str = Field(..., description="Message content")
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""A tool execution result."""
|
||||
|
||||
tool_name: str = Field(..., description="Name of the tool")
|
||||
content: str | dict[str, Any] = Field(..., description="Tool result content")
|
||||
status: str = Field(default="success", description="Execution status")
|
||||
|
||||
|
||||
class AssembleContextRequest(BaseModel):
|
||||
"""Request to assemble context for an LLM request."""
|
||||
|
||||
project_id: str = Field(..., description="Project identifier")
|
||||
agent_id: str = Field(..., description="Agent identifier")
|
||||
query: str = Field(..., description="User's query or current request")
|
||||
model: str = Field(
|
||||
default="claude-3-sonnet",
|
||||
description="Target model name",
|
||||
)
|
||||
max_tokens: int | None = Field(
|
||||
None,
|
||||
description="Maximum context tokens (uses model default if None)",
|
||||
)
|
||||
system_prompt: str | None = Field(
|
||||
None,
|
||||
description="System prompt/instructions",
|
||||
)
|
||||
task_description: str | None = Field(
|
||||
None,
|
||||
description="Current task description",
|
||||
)
|
||||
knowledge_query: str | None = Field(
|
||||
None,
|
||||
description="Query for knowledge base search",
|
||||
)
|
||||
knowledge_limit: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=50,
|
||||
description="Max number of knowledge results",
|
||||
)
|
||||
conversation_history: list[ConversationTurn] | None = Field(
|
||||
None,
|
||||
description="Previous conversation turns",
|
||||
)
|
||||
tool_results: list[ToolResult] | None = Field(
|
||||
None,
|
||||
description="Tool execution results to include",
|
||||
)
|
||||
compress: bool = Field(
|
||||
default=True,
|
||||
description="Whether to apply compression",
|
||||
)
|
||||
use_cache: bool = Field(
|
||||
default=True,
|
||||
description="Whether to use caching",
|
||||
)
|
||||
|
||||
|
||||
class AssembledContextResponse(BaseModel):
|
||||
"""Response containing assembled context."""
|
||||
|
||||
content: str = Field(..., description="Assembled context content")
|
||||
total_tokens: int = Field(..., description="Total token count")
|
||||
context_count: int = Field(..., description="Number of context items included")
|
||||
compressed: bool = Field(..., description="Whether compression was applied")
|
||||
budget_used_percent: float = Field(
|
||||
...,
|
||||
description="Percentage of token budget used",
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional metadata",
|
||||
)
|
||||
|
||||
|
||||
class TokenCountRequest(BaseModel):
|
||||
"""Request to count tokens in content."""
|
||||
|
||||
content: str = Field(..., description="Content to count tokens in")
|
||||
model: str | None = Field(
|
||||
None,
|
||||
description="Model for model-specific tokenization",
|
||||
)
|
||||
|
||||
|
||||
class TokenCountResponse(BaseModel):
|
||||
"""Response containing token count."""
|
||||
|
||||
token_count: int = Field(..., description="Number of tokens")
|
||||
model: str | None = Field(None, description="Model used for counting")
|
||||
|
||||
|
||||
class BudgetInfoResponse(BaseModel):
|
||||
"""Response containing budget information for a model."""
|
||||
|
||||
model: str = Field(..., description="Model name")
|
||||
total_tokens: int = Field(..., description="Total token budget")
|
||||
system_tokens: int = Field(..., description="Tokens reserved for system")
|
||||
knowledge_tokens: int = Field(..., description="Tokens for knowledge")
|
||||
conversation_tokens: int = Field(..., description="Tokens for conversation")
|
||||
tool_tokens: int = Field(..., description="Tokens for tool results")
|
||||
response_reserve: int = Field(..., description="Tokens reserved for response")
|
||||
|
||||
|
||||
class ContextEngineStatsResponse(BaseModel):
|
||||
"""Response containing engine statistics."""
|
||||
|
||||
cache: dict[str, Any] = Field(..., description="Cache statistics")
|
||||
settings: dict[str, Any] = Field(..., description="Current settings")
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Health check response."""
|
||||
|
||||
status: str = Field(..., description="Health status")
|
||||
mcp_connected: bool = Field(..., description="Whether MCP is connected")
|
||||
cache_enabled: bool = Field(..., description="Whether caching is enabled")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/health",
|
||||
response_model=HealthResponse,
|
||||
summary="Context Engine Health",
|
||||
description="Check health status of the context engine.",
|
||||
)
|
||||
async def health_check(
|
||||
engine: ContextEngine = Depends(get_context_engine),
|
||||
) -> HealthResponse:
|
||||
"""Check context engine health."""
|
||||
stats = await engine.get_stats()
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
mcp_connected=engine._mcp is not None,
|
||||
cache_enabled=stats.get("settings", {}).get("cache_enabled", False),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/assemble",
|
||||
response_model=AssembledContextResponse,
|
||||
summary="Assemble Context",
|
||||
description="Assemble optimized context for an LLM request.",
|
||||
)
|
||||
async def assemble_context(
|
||||
request: AssembleContextRequest,
|
||||
current_user: User = Depends(require_superuser),
|
||||
engine: ContextEngine = Depends(get_context_engine),
|
||||
) -> AssembledContextResponse:
|
||||
"""
|
||||
Assemble optimized context for an LLM request.
|
||||
|
||||
This endpoint gathers context from various sources, scores and ranks them,
|
||||
compresses if needed, and formats for the target model.
|
||||
"""
|
||||
logger.info(
|
||||
"Context assembly for project=%s agent=%s by user=%s",
|
||||
request.project_id,
|
||||
request.agent_id,
|
||||
current_user.id,
|
||||
)
|
||||
|
||||
# Convert conversation history to dict format
|
||||
conversation_history = None
|
||||
if request.conversation_history:
|
||||
conversation_history = [
|
||||
{"role": turn.role, "content": turn.content}
|
||||
for turn in request.conversation_history
|
||||
]
|
||||
|
||||
# Convert tool results to dict format
|
||||
tool_results = None
|
||||
if request.tool_results:
|
||||
tool_results = [
|
||||
{
|
||||
"tool_name": tr.tool_name,
|
||||
"content": tr.content,
|
||||
"status": tr.status,
|
||||
}
|
||||
for tr in request.tool_results
|
||||
]
|
||||
|
||||
try:
|
||||
result = await engine.assemble_context(
|
||||
project_id=request.project_id,
|
||||
agent_id=request.agent_id,
|
||||
query=request.query,
|
||||
model=request.model,
|
||||
max_tokens=request.max_tokens,
|
||||
system_prompt=request.system_prompt,
|
||||
task_description=request.task_description,
|
||||
knowledge_query=request.knowledge_query,
|
||||
knowledge_limit=request.knowledge_limit,
|
||||
conversation_history=conversation_history,
|
||||
tool_results=tool_results,
|
||||
compress=request.compress,
|
||||
use_cache=request.use_cache,
|
||||
)
|
||||
|
||||
# Calculate budget usage percentage
|
||||
budget = await engine.get_budget_for_model(request.model, request.max_tokens)
|
||||
budget_used_percent = (result.total_tokens / budget.total) * 100
|
||||
|
||||
# Check if compression was applied (from metadata if available)
|
||||
was_compressed = result.metadata.get("compressed_contexts", 0) > 0
|
||||
|
||||
return AssembledContextResponse(
|
||||
content=result.content,
|
||||
total_tokens=result.total_tokens,
|
||||
context_count=result.context_count,
|
||||
compressed=was_compressed,
|
||||
budget_used_percent=round(budget_used_percent, 2),
|
||||
metadata={
|
||||
"model": request.model,
|
||||
"query": request.query,
|
||||
"knowledge_included": bool(request.knowledge_query),
|
||||
"conversation_turns": len(request.conversation_history or []),
|
||||
"excluded_count": result.excluded_count,
|
||||
"assembly_time_ms": result.assembly_time_ms,
|
||||
},
|
||||
)
|
||||
|
||||
except AssemblyTimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail=f"Context assembly timed out: {e}",
|
||||
) from e
|
||||
except BudgetExceededError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||
detail=f"Token budget exceeded: {e}",
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.exception("Context assembly failed")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Context assembly failed: {e}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/count-tokens",
|
||||
response_model=TokenCountResponse,
|
||||
summary="Count Tokens",
|
||||
description="Count tokens in content using the LLM Gateway.",
|
||||
)
|
||||
async def count_tokens(
|
||||
request: TokenCountRequest,
|
||||
engine: ContextEngine = Depends(get_context_engine),
|
||||
) -> TokenCountResponse:
|
||||
"""Count tokens in content."""
|
||||
try:
|
||||
count = await engine.count_tokens(
|
||||
content=request.content,
|
||||
model=request.model,
|
||||
)
|
||||
return TokenCountResponse(
|
||||
token_count=count,
|
||||
model=request.model,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Token counting failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Token counting failed: {e}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/budget/{model}",
|
||||
response_model=BudgetInfoResponse,
|
||||
summary="Get Token Budget",
|
||||
description="Get token budget allocation for a specific model.",
|
||||
)
|
||||
async def get_budget(
|
||||
model: str,
|
||||
max_tokens: Annotated[int | None, Query(description="Custom max tokens")] = None,
|
||||
engine: ContextEngine = Depends(get_context_engine),
|
||||
) -> BudgetInfoResponse:
|
||||
"""Get token budget information for a model."""
|
||||
budget = await engine.get_budget_for_model(model, max_tokens)
|
||||
return BudgetInfoResponse(
|
||||
model=model,
|
||||
total_tokens=budget.total,
|
||||
system_tokens=budget.system,
|
||||
knowledge_tokens=budget.knowledge,
|
||||
conversation_tokens=budget.conversation,
|
||||
tool_tokens=budget.tools,
|
||||
response_reserve=budget.response_reserve,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/stats",
|
||||
response_model=ContextEngineStatsResponse,
|
||||
summary="Engine Statistics",
|
||||
description="Get context engine statistics and configuration.",
|
||||
)
|
||||
async def get_stats(
|
||||
current_user: User = Depends(require_superuser),
|
||||
engine: ContextEngine = Depends(get_context_engine),
|
||||
) -> ContextEngineStatsResponse:
|
||||
"""Get engine statistics."""
|
||||
stats = await engine.get_stats()
|
||||
return ContextEngineStatsResponse(
|
||||
cache=stats.get("cache", {}),
|
||||
settings=stats.get("settings", {}),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/cache/invalidate",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Invalidate Cache (Admin Only)",
|
||||
description="Invalidate context cache entries.",
|
||||
)
|
||||
async def invalidate_cache(
|
||||
project_id: Annotated[
|
||||
str | None, Query(description="Project to invalidate")
|
||||
] = None,
|
||||
pattern: Annotated[str | None, Query(description="Pattern to match")] = None,
|
||||
current_user: User = Depends(require_superuser),
|
||||
engine: ContextEngine = Depends(get_context_engine),
|
||||
) -> None:
|
||||
"""Invalidate cache entries."""
|
||||
logger.info(
|
||||
"Cache invalidation by user %s: project=%s pattern=%s",
|
||||
current_user.id,
|
||||
project_id,
|
||||
pattern,
|
||||
)
|
||||
await engine.invalidate_cache(project_id=project_id, pattern=pattern)
|
||||
446
backend/app/api/routes/mcp.py
Normal file
446
backend/app/api/routes/mcp.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) API Endpoints
|
||||
|
||||
Provides REST endpoints for managing MCP server connections
|
||||
and executing tool calls.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.dependencies.permissions import require_superuser
|
||||
from app.models.user import User
|
||||
from app.services.mcp import (
|
||||
MCPCircuitOpenError,
|
||||
MCPClientManager,
|
||||
MCPConnectionError,
|
||||
MCPError,
|
||||
MCPServerNotFoundError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
get_mcp_client,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Server name validation pattern: alphanumeric, hyphens, underscores, 1-64 chars
|
||||
SERVER_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
|
||||
|
||||
# Type alias for validated server name path parameter
|
||||
ServerNamePath = Annotated[
|
||||
str,
|
||||
Path(
|
||||
description="MCP server name",
|
||||
min_length=1,
|
||||
max_length=64,
|
||||
pattern=r"^[a-zA-Z0-9_-]+$",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Request/Response Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ServerInfo(BaseModel):
|
||||
"""Information about an MCP server."""
|
||||
|
||||
name: str = Field(..., description="Server name")
|
||||
url: str = Field(..., description="Server URL")
|
||||
enabled: bool = Field(..., description="Whether server is enabled")
|
||||
timeout: int = Field(..., description="Request timeout in seconds")
|
||||
transport: str = Field(..., description="Transport type (http, stdio, sse)")
|
||||
description: str | None = Field(None, description="Server description")
|
||||
|
||||
|
||||
class ServerListResponse(BaseModel):
|
||||
"""Response containing list of MCP servers."""
|
||||
|
||||
servers: list[ServerInfo]
|
||||
total: int
|
||||
|
||||
|
||||
class ToolInfoResponse(BaseModel):
|
||||
"""Information about an MCP tool."""
|
||||
|
||||
name: str = Field(..., description="Tool name")
|
||||
description: str | None = Field(None, description="Tool description")
|
||||
server_name: str | None = Field(None, description="Server providing the tool")
|
||||
input_schema: dict[str, Any] | None = Field(
|
||||
None, description="JSON schema for input"
|
||||
)
|
||||
|
||||
|
||||
class ToolListResponse(BaseModel):
|
||||
"""Response containing list of tools."""
|
||||
|
||||
tools: list[ToolInfoResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class ServerHealthStatus(BaseModel):
|
||||
"""Health status for a server."""
|
||||
|
||||
name: str
|
||||
healthy: bool
|
||||
state: str
|
||||
url: str
|
||||
error: str | None = None
|
||||
tools_count: int = 0
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
|
||||
"""Response containing health status of all servers."""
|
||||
|
||||
servers: dict[str, ServerHealthStatus]
|
||||
healthy_count: int
|
||||
unhealthy_count: int
|
||||
total: int
|
||||
|
||||
|
||||
class ToolCallRequest(BaseModel):
|
||||
"""Request to execute a tool."""
|
||||
|
||||
server: str = Field(..., description="MCP server name")
|
||||
tool: str = Field(..., description="Tool name to execute")
|
||||
arguments: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Tool arguments",
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
None,
|
||||
description="Optional timeout override in seconds",
|
||||
)
|
||||
|
||||
|
||||
class ToolCallResponse(BaseModel):
|
||||
"""Response from tool execution."""
|
||||
|
||||
success: bool
|
||||
data: Any | None = None
|
||||
error: str | None = None
|
||||
error_code: str | None = None
|
||||
tool_name: str | None = None
|
||||
server_name: str | None = None
|
||||
execution_time_ms: float = 0.0
|
||||
request_id: str | None = None
|
||||
|
||||
|
||||
class CircuitBreakerStatus(BaseModel):
|
||||
"""Status of a circuit breaker."""
|
||||
|
||||
server_name: str
|
||||
state: str
|
||||
failure_count: int
|
||||
|
||||
|
||||
class CircuitBreakerListResponse(BaseModel):
|
||||
"""Response containing circuit breaker statuses."""
|
||||
|
||||
circuit_breakers: list[CircuitBreakerStatus]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Endpoints
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/servers",
|
||||
response_model=ServerListResponse,
|
||||
summary="List MCP Servers",
|
||||
description="Get list of all registered MCP servers with their configurations.",
|
||||
)
|
||||
async def list_servers(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ServerListResponse:
|
||||
"""List all registered MCP servers."""
|
||||
servers = []
|
||||
|
||||
for name in mcp.list_servers():
|
||||
try:
|
||||
config = mcp.get_server_config(name)
|
||||
servers.append(
|
||||
ServerInfo(
|
||||
name=name,
|
||||
url=config.url,
|
||||
enabled=config.enabled,
|
||||
timeout=config.timeout,
|
||||
transport=config.transport.value,
|
||||
description=config.description,
|
||||
)
|
||||
)
|
||||
except MCPServerNotFoundError:
|
||||
continue
|
||||
|
||||
return ServerListResponse(
|
||||
servers=servers,
|
||||
total=len(servers),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/servers/{server_name}/tools",
|
||||
response_model=ToolListResponse,
|
||||
summary="List Server Tools",
|
||||
description="Get list of tools available on a specific MCP server.",
|
||||
)
|
||||
async def list_server_tools(
|
||||
server_name: ServerNamePath,
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ToolListResponse:
|
||||
"""List all tools available on a specific server."""
|
||||
try:
|
||||
tools = await mcp.list_tools(server_name)
|
||||
return ToolListResponse(
|
||||
tools=[
|
||||
ToolInfoResponse(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
server_name=t.server_name,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
for t in tools
|
||||
],
|
||||
total=len(tools),
|
||||
)
|
||||
except MCPServerNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Server not found: {server_name}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tools",
|
||||
response_model=ToolListResponse,
|
||||
summary="List All Tools",
|
||||
description="Get list of all tools from all MCP servers.",
|
||||
)
|
||||
async def list_all_tools(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ToolListResponse:
|
||||
"""List all tools from all servers."""
|
||||
tools = await mcp.list_all_tools()
|
||||
return ToolListResponse(
|
||||
tools=[
|
||||
ToolInfoResponse(
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
server_name=t.server_name,
|
||||
input_schema=t.input_schema,
|
||||
)
|
||||
for t in tools
|
||||
],
|
||||
total=len(tools),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/health",
|
||||
response_model=HealthCheckResponse,
|
||||
summary="Health Check",
|
||||
description="Check health status of all MCP servers.",
|
||||
)
|
||||
async def health_check(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> HealthCheckResponse:
|
||||
"""Perform health check on all MCP servers."""
|
||||
health_results = await mcp.health_check()
|
||||
|
||||
servers = {
|
||||
name: ServerHealthStatus(
|
||||
name=status.name,
|
||||
healthy=status.healthy,
|
||||
state=status.state,
|
||||
url=status.url,
|
||||
error=status.error,
|
||||
tools_count=status.tools_count,
|
||||
)
|
||||
for name, status in health_results.items()
|
||||
}
|
||||
|
||||
healthy_count = sum(1 for s in servers.values() if s.healthy)
|
||||
unhealthy_count = len(servers) - healthy_count
|
||||
|
||||
return HealthCheckResponse(
|
||||
servers=servers,
|
||||
healthy_count=healthy_count,
|
||||
unhealthy_count=unhealthy_count,
|
||||
total=len(servers),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/call",
|
||||
response_model=ToolCallResponse,
|
||||
summary="Execute Tool (Admin Only)",
|
||||
description="Execute a tool on an MCP server. Requires superuser privileges.",
|
||||
)
|
||||
async def call_tool(
|
||||
request: ToolCallRequest,
|
||||
current_user: User = Depends(require_superuser),
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> ToolCallResponse:
|
||||
"""
|
||||
Execute a tool on an MCP server.
|
||||
|
||||
This endpoint is restricted to superusers for direct tool execution.
|
||||
Normal tool execution should go through agent workflows.
|
||||
"""
|
||||
logger.info(
|
||||
"Tool call by user %s: %s.%s",
|
||||
current_user.id,
|
||||
request.server,
|
||||
request.tool,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await mcp.call_tool(
|
||||
server=request.server,
|
||||
tool=request.tool,
|
||||
args=request.arguments,
|
||||
timeout=request.timeout,
|
||||
)
|
||||
|
||||
return ToolCallResponse(
|
||||
success=result.success,
|
||||
data=result.data,
|
||||
error=result.error,
|
||||
error_code=result.error_code,
|
||||
tool_name=result.tool_name,
|
||||
server_name=result.server_name,
|
||||
execution_time_ms=result.execution_time_ms,
|
||||
request_id=result.request_id,
|
||||
)
|
||||
|
||||
except MCPCircuitOpenError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Server temporarily unavailable: {e.server_name}",
|
||||
) from e
|
||||
except MCPToolNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Tool not found: {e.tool_name}",
|
||||
) from e
|
||||
except MCPServerNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Server not found: {e.server_name}",
|
||||
) from e
|
||||
except MCPTimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except MCPConnectionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except MCPToolError as e:
|
||||
# Tool errors are returned in the response, not as HTTP errors
|
||||
return ToolCallResponse(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_code=e.error_code,
|
||||
tool_name=e.tool_name,
|
||||
server_name=e.server_name,
|
||||
)
|
||||
except MCPError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/circuit-breakers",
|
||||
response_model=CircuitBreakerListResponse,
|
||||
summary="List Circuit Breakers",
|
||||
description="Get status of all circuit breakers.",
|
||||
)
|
||||
async def list_circuit_breakers(
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> CircuitBreakerListResponse:
|
||||
"""Get status of all circuit breakers."""
|
||||
status_dict = mcp.get_circuit_breaker_status()
|
||||
|
||||
return CircuitBreakerListResponse(
|
||||
circuit_breakers=[
|
||||
CircuitBreakerStatus(
|
||||
server_name=name,
|
||||
state=info.get("state", "unknown"),
|
||||
failure_count=info.get("failure_count", 0),
|
||||
)
|
||||
for name, info in status_dict.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/circuit-breakers/{server_name}/reset",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Reset Circuit Breaker (Admin Only)",
|
||||
description="Manually reset a circuit breaker for a server.",
|
||||
)
|
||||
async def reset_circuit_breaker(
|
||||
server_name: ServerNamePath,
|
||||
current_user: User = Depends(require_superuser),
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> None:
|
||||
"""Manually reset a circuit breaker."""
|
||||
logger.info(
|
||||
"Circuit breaker reset by user %s for server %s",
|
||||
current_user.id,
|
||||
server_name,
|
||||
)
|
||||
|
||||
success = await mcp.reset_circuit_breaker(server_name)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No circuit breaker found for server: {server_name}",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/servers/{server_name}/reconnect",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Reconnect to Server (Admin Only)",
|
||||
description="Force reconnection to an MCP server.",
|
||||
)
|
||||
async def reconnect_server(
|
||||
server_name: ServerNamePath,
|
||||
current_user: User = Depends(require_superuser),
|
||||
mcp: MCPClientManager = Depends(get_mcp_client),
|
||||
) -> None:
|
||||
"""Force reconnection to an MCP server."""
|
||||
logger.info(
|
||||
"Reconnect requested by user %s for server %s",
|
||||
current_user.id,
|
||||
server_name,
|
||||
)
|
||||
|
||||
try:
|
||||
await mcp.disconnect(server_name)
|
||||
await mcp.connect(server_name)
|
||||
except MCPServerNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Server not found: {server_name}",
|
||||
) from e
|
||||
except MCPConnectionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Failed to reconnect: {e}",
|
||||
) from e
|
||||
@@ -1,366 +0,0 @@
|
||||
{
|
||||
"organizations": [
|
||||
{
|
||||
"name": "Acme Corp",
|
||||
"slug": "acme-corp",
|
||||
"description": "A leading provider of coyote-catching equipment."
|
||||
},
|
||||
{
|
||||
"name": "Globex Corporation",
|
||||
"slug": "globex",
|
||||
"description": "We own the East Coast."
|
||||
},
|
||||
{
|
||||
"name": "Soylent Corp",
|
||||
"slug": "soylent",
|
||||
"description": "Making food for the future."
|
||||
},
|
||||
{
|
||||
"name": "Initech",
|
||||
"slug": "initech",
|
||||
"description": "Software for the soul."
|
||||
},
|
||||
{
|
||||
"name": "Umbrella Corporation",
|
||||
"slug": "umbrella",
|
||||
"description": "Our business is life itself."
|
||||
},
|
||||
{
|
||||
"name": "Massive Dynamic",
|
||||
"slug": "massive-dynamic",
|
||||
"description": "What don't we do?"
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"email": "demo@example.com",
|
||||
"password": "DemoPass1234!",
|
||||
"first_name": "Demo",
|
||||
"last_name": "User",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "alice@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Alice",
|
||||
"last_name": "Smith",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "bob@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Bob",
|
||||
"last_name": "Jones",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "charlie@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Charlie",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "diana@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Diana",
|
||||
"last_name": "Prince",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "carol@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Carol",
|
||||
"last_name": "Williams",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dan@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dan",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ellen@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ellen",
|
||||
"last_name": "Ripley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "fred@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Fred",
|
||||
"last_name": "Flintstone",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dave@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dave",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "gina@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Gina",
|
||||
"last_name": "Torres",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "harry@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Harry",
|
||||
"last_name": "Potter",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "eve@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Eve",
|
||||
"last_name": "Davis",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "iris@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Iris",
|
||||
"last_name": "West",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "jack@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Jack",
|
||||
"last_name": "Sparrow",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "frank@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Frank",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "george@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "George",
|
||||
"last_name": "Costanza",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "kate@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Kate",
|
||||
"last_name": "Bishop",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "leo@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Leo",
|
||||
"last_name": "Messi",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "mary@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Mary",
|
||||
"last_name": "Jane",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "nathan@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Nathan",
|
||||
"last_name": "Drake",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "olivia@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Olivia",
|
||||
"last_name": "Dunham",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "peter@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Peter",
|
||||
"last_name": "Parker",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "quinn@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Quinn",
|
||||
"last_name": "Mallory",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "grace@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Grace",
|
||||
"last_name": "Hopper",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "heidi@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Heidi",
|
||||
"last_name": "Klum",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ivan@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ivan",
|
||||
"last_name": "Drago",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "rachel@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Rachel",
|
||||
"last_name": "Green",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "sam@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Sam",
|
||||
"last_name": "Wilson",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "tony@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Tony",
|
||||
"last_name": "Stark",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "una@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Una",
|
||||
"last_name": "Chin-Riley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "victor@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Victor",
|
||||
"last_name": "Von Doom",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "wanda@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Wanda",
|
||||
"last_name": "Maximoff",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -43,6 +43,13 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
mcp_servers=obj_in.mcp_servers,
|
||||
tool_permissions=obj_in.tool_permissions,
|
||||
is_active=obj_in.is_active,
|
||||
# Category and display fields
|
||||
category=obj_in.category.value if obj_in.category else None,
|
||||
icon=obj_in.icon,
|
||||
color=obj_in.color,
|
||||
sort_order=obj_in.sort_order,
|
||||
typical_tasks=obj_in.typical_tasks,
|
||||
collaboration_hints=obj_in.collaboration_hints,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
@@ -68,6 +75,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
category: str | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
@@ -85,6 +93,9 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
if is_active is not None:
|
||||
query = query.where(AgentType.is_active == is_active)
|
||||
|
||||
if category:
|
||||
query = query.where(AgentType.category == category)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
AgentType.name.ilike(f"%{search}%"),
|
||||
@@ -162,6 +173,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
category: str | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
@@ -177,6 +189,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
category=category,
|
||||
search=search,
|
||||
)
|
||||
|
||||
@@ -260,6 +273,44 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_grouped_by_category(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
is_active: bool = True,
|
||||
) -> dict[str, list[AgentType]]:
|
||||
"""
|
||||
Get agent types grouped by category, sorted by sort_order within each group.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
is_active: Filter by active status (default: True)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping category to list of agent types
|
||||
"""
|
||||
try:
|
||||
query = (
|
||||
select(AgentType)
|
||||
.where(AgentType.is_active == is_active)
|
||||
.order_by(AgentType.category, AgentType.sort_order, AgentType.name)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
agent_types = list(result.scalars().all())
|
||||
|
||||
# Group by category
|
||||
grouped: dict[str, list[AgentType]] = {}
|
||||
for at in agent_types:
|
||||
cat: str = str(at.category) if at.category else "uncategorized"
|
||||
if cat not in grouped:
|
||||
grouped[cat] = []
|
||||
grouped[cat].append(at)
|
||||
|
||||
return grouped
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting grouped agent types: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
agent_type = CRUDAgentType(AgentType)
|
||||
|
||||
@@ -3,27 +3,48 @@
|
||||
Async database initialization script.
|
||||
|
||||
Creates the first superuser if configured and doesn't already exist.
|
||||
Seeds default agent types (production data) and demo data (when DEMO_MODE is enabled).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import SessionLocal, engine
|
||||
from app.crud.syndarix.agent_type import agent_type as agent_type_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.organization import Organization
|
||||
from app.models.syndarix import AgentInstance, AgentType, Issue, Project, Sprint
|
||||
from app.models.syndarix.enums import (
|
||||
AgentStatus,
|
||||
AutonomyLevel,
|
||||
ClientMode,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
IssueType,
|
||||
ProjectComplexity,
|
||||
ProjectStatus,
|
||||
SprintStatus,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization
|
||||
from app.schemas.syndarix import AgentTypeCreate
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Data file paths
|
||||
DATA_DIR = Path(__file__).parent.parent / "data"
|
||||
DEFAULT_AGENT_TYPES_PATH = DATA_DIR / "default_agent_types.json"
|
||||
DEMO_DATA_PATH = DATA_DIR / "demo_data.json"
|
||||
|
||||
|
||||
async def init_db() -> User | None:
|
||||
"""
|
||||
@@ -54,8 +75,7 @@ async def init_db() -> User | None:
|
||||
|
||||
if existing_user:
|
||||
logger.info(f"Superuser already exists: {existing_user.email}")
|
||||
return existing_user
|
||||
|
||||
else:
|
||||
# Create superuser if doesn't exist
|
||||
user_in = UserCreate(
|
||||
email=superuser_email,
|
||||
@@ -65,17 +85,19 @@ async def init_db() -> User | None:
|
||||
is_superuser=True,
|
||||
)
|
||||
|
||||
user = await user_crud.create(session, obj_in=user_in)
|
||||
existing_user = await user_crud.create(session, obj_in=user_in)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.refresh(existing_user)
|
||||
logger.info(f"Created first superuser: {existing_user.email}")
|
||||
|
||||
logger.info(f"Created first superuser: {user.email}")
|
||||
# ALWAYS load default agent types (production data)
|
||||
await load_default_agent_types(session)
|
||||
|
||||
# Create demo data if in demo mode
|
||||
# Only load demo data if in demo mode
|
||||
if settings.DEMO_MODE:
|
||||
await load_demo_data(session)
|
||||
|
||||
return user
|
||||
return existing_user
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
@@ -88,26 +110,96 @@ def _load_json_file(path: Path):
|
||||
return json.load(f)
|
||||
|
||||
|
||||
async def load_demo_data(session):
|
||||
"""Load demo data from JSON file."""
|
||||
demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
|
||||
if not demo_data_path.exists():
|
||||
logger.warning(f"Demo data file not found: {demo_data_path}")
|
||||
async def load_default_agent_types(session: AsyncSession) -> None:
|
||||
"""
|
||||
Load default agent types from JSON file.
|
||||
|
||||
These are production defaults - created only if they don't exist, never overwritten.
|
||||
This allows users to customize agent types without worrying about server restarts.
|
||||
"""
|
||||
if not DEFAULT_AGENT_TYPES_PATH.exists():
|
||||
logger.warning(
|
||||
f"Default agent types file not found: {DEFAULT_AGENT_TYPES_PATH}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Use asyncio.to_thread to avoid blocking the event loop
|
||||
data = await asyncio.to_thread(_load_json_file, demo_data_path)
|
||||
data = await asyncio.to_thread(_load_json_file, DEFAULT_AGENT_TYPES_PATH)
|
||||
|
||||
# Create Organizations
|
||||
org_map = {}
|
||||
for org_data in data.get("organizations", []):
|
||||
# Check if org exists
|
||||
result = await session.execute(
|
||||
text("SELECT * FROM organizations WHERE slug = :slug"),
|
||||
{"slug": org_data["slug"]},
|
||||
for agent_type_data in data:
|
||||
slug = agent_type_data["slug"]
|
||||
|
||||
# Check if agent type already exists
|
||||
existing = await agent_type_crud.get_by_slug(session, slug=slug)
|
||||
|
||||
if existing:
|
||||
logger.debug(f"Agent type already exists: {agent_type_data['name']}")
|
||||
continue
|
||||
|
||||
# Create the agent type
|
||||
agent_type_in = AgentTypeCreate(
|
||||
name=agent_type_data["name"],
|
||||
slug=slug,
|
||||
description=agent_type_data.get("description"),
|
||||
expertise=agent_type_data.get("expertise", []),
|
||||
personality_prompt=agent_type_data["personality_prompt"],
|
||||
primary_model=agent_type_data["primary_model"],
|
||||
fallback_models=agent_type_data.get("fallback_models", []),
|
||||
model_params=agent_type_data.get("model_params", {}),
|
||||
mcp_servers=agent_type_data.get("mcp_servers", []),
|
||||
tool_permissions=agent_type_data.get("tool_permissions", {}),
|
||||
is_active=agent_type_data.get("is_active", True),
|
||||
# Category and display fields
|
||||
category=agent_type_data.get("category"),
|
||||
icon=agent_type_data.get("icon", "bot"),
|
||||
color=agent_type_data.get("color", "#3B82F6"),
|
||||
sort_order=agent_type_data.get("sort_order", 0),
|
||||
typical_tasks=agent_type_data.get("typical_tasks", []),
|
||||
collaboration_hints=agent_type_data.get("collaboration_hints", []),
|
||||
)
|
||||
existing_org = result.first()
|
||||
|
||||
await agent_type_crud.create(session, obj_in=agent_type_in)
|
||||
logger.info(f"Created default agent type: {agent_type_data['name']}")
|
||||
|
||||
logger.info("Default agent types loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading default agent types: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def load_demo_data(session: AsyncSession) -> None:
|
||||
"""
|
||||
Load demo data from JSON file.
|
||||
|
||||
Only runs when DEMO_MODE is enabled. Creates demo organizations, users,
|
||||
projects, sprints, agent instances, and issues.
|
||||
"""
|
||||
if not DEMO_DATA_PATH.exists():
|
||||
logger.warning(f"Demo data file not found: {DEMO_DATA_PATH}")
|
||||
return
|
||||
|
||||
try:
|
||||
data = await asyncio.to_thread(_load_json_file, DEMO_DATA_PATH)
|
||||
|
||||
# Build lookup maps for FK resolution
|
||||
org_map: dict[str, Organization] = {}
|
||||
user_map: dict[str, User] = {}
|
||||
project_map: dict[str, Project] = {}
|
||||
sprint_map: dict[str, Sprint] = {} # key: "project_slug:sprint_number"
|
||||
agent_type_map: dict[str, AgentType] = {}
|
||||
agent_instance_map: dict[
|
||||
str, AgentInstance
|
||||
] = {} # key: "project_slug:agent_name"
|
||||
|
||||
# ========================
|
||||
# 1. Create Organizations
|
||||
# ========================
|
||||
for org_data in data.get("organizations", []):
|
||||
org_result = await session.execute(
|
||||
select(Organization).where(Organization.slug == org_data["slug"])
|
||||
)
|
||||
existing_org = org_result.scalar_one_or_none()
|
||||
|
||||
if not existing_org:
|
||||
org = Organization(
|
||||
@@ -117,29 +209,20 @@ async def load_demo_data(session):
|
||||
is_active=True,
|
||||
)
|
||||
session.add(org)
|
||||
await session.flush() # Flush to get ID
|
||||
org_map[org.slug] = org
|
||||
await session.flush()
|
||||
org_map[str(org.slug)] = org
|
||||
logger.info(f"Created demo organization: {org.name}")
|
||||
else:
|
||||
# We can't easily get the ORM object from raw SQL result for map without querying again or mapping
|
||||
# So let's just query it properly if we need it for relationships
|
||||
# But for simplicity in this script, let's just assume we created it or it exists.
|
||||
# To properly map for users, we need the ID.
|
||||
# Let's use a simpler approach: just try to create, if slug conflict, skip.
|
||||
pass
|
||||
org_map[str(existing_org.slug)] = existing_org
|
||||
|
||||
# Re-query all orgs to build map for users
|
||||
result = await session.execute(select(Organization))
|
||||
orgs = result.scalars().all()
|
||||
org_map = {org.slug: org for org in orgs}
|
||||
|
||||
# Create Users
|
||||
# ========================
|
||||
# 2. Create Users
|
||||
# ========================
|
||||
for user_data in data.get("users", []):
|
||||
existing_user = await user_crud.get_by_email(
|
||||
session, email=user_data["email"]
|
||||
)
|
||||
if not existing_user:
|
||||
# Create user
|
||||
user_in = UserCreate(
|
||||
email=user_data["email"],
|
||||
password=user_data["password"],
|
||||
@@ -151,17 +234,13 @@ async def load_demo_data(session):
|
||||
user = await user_crud.create(session, obj_in=user_in)
|
||||
|
||||
# Randomize created_at for demo data (last 30 days)
|
||||
# This makes the charts look more realistic
|
||||
days_ago = random.randint(0, 30) # noqa: S311
|
||||
random_time = datetime.now(UTC) - timedelta(days=days_ago)
|
||||
# Add some random hours/minutes variation
|
||||
random_time = random_time.replace(
|
||||
hour=random.randint(0, 23), # noqa: S311
|
||||
minute=random.randint(0, 59), # noqa: S311
|
||||
)
|
||||
|
||||
# Update the timestamp and is_active directly in the database
|
||||
# We do this to ensure the values are persisted correctly
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE users SET created_at = :created_at, is_active = :is_active WHERE id = :user_id"
|
||||
@@ -174,7 +253,7 @@ async def load_demo_data(session):
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created demo user: {user.email} (created {days_ago} days ago, active={user_data.get('is_active', True)})"
|
||||
f"Created demo user: {user.email} (created {days_ago} days ago)"
|
||||
)
|
||||
|
||||
# Add to organization if specified
|
||||
@@ -182,19 +261,228 @@ async def load_demo_data(session):
|
||||
role = user_data.get("role")
|
||||
if org_slug and org_slug in org_map and role:
|
||||
org = org_map[org_slug]
|
||||
# Check if membership exists (it shouldn't for new user)
|
||||
member = UserOrganization(
|
||||
user_id=user.id, organization_id=org.id, role=role
|
||||
)
|
||||
session.add(member)
|
||||
logger.info(f"Added {user.email} to {org.name} as {role}")
|
||||
|
||||
user_map[str(user.email)] = user
|
||||
else:
|
||||
logger.info(f"Demo user already exists: {existing_user.email}")
|
||||
user_map[str(existing_user.email)] = existing_user
|
||||
logger.debug(f"Demo user already exists: {existing_user.email}")
|
||||
|
||||
await session.flush()
|
||||
|
||||
# Add admin user to map with special "__admin__" key
|
||||
# This allows demo data to reference the admin user as owner
|
||||
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
||||
admin_user = await user_crud.get_by_email(session, email=superuser_email)
|
||||
if admin_user:
|
||||
user_map["__admin__"] = admin_user
|
||||
user_map[str(admin_user.email)] = admin_user
|
||||
logger.debug(f"Added admin user to map: {admin_user.email}")
|
||||
|
||||
# ========================
|
||||
# 3. Load Agent Types Map (for FK resolution)
|
||||
# ========================
|
||||
agent_types_result = await session.execute(select(AgentType))
|
||||
for at in agent_types_result.scalars().all():
|
||||
agent_type_map[str(at.slug)] = at
|
||||
|
||||
# ========================
|
||||
# 4. Create Projects
|
||||
# ========================
|
||||
for project_data in data.get("projects", []):
|
||||
project_result = await session.execute(
|
||||
select(Project).where(Project.slug == project_data["slug"])
|
||||
)
|
||||
existing_project = project_result.scalar_one_or_none()
|
||||
|
||||
if not existing_project:
|
||||
# Resolve owner email to user ID
|
||||
owner_id = None
|
||||
owner_email = project_data.get("owner_email")
|
||||
if owner_email and owner_email in user_map:
|
||||
owner_id = user_map[owner_email].id
|
||||
|
||||
project = Project(
|
||||
name=project_data["name"],
|
||||
slug=project_data["slug"],
|
||||
description=project_data.get("description"),
|
||||
owner_id=owner_id,
|
||||
autonomy_level=AutonomyLevel(
|
||||
project_data.get("autonomy_level", "milestone")
|
||||
),
|
||||
status=ProjectStatus(project_data.get("status", "active")),
|
||||
complexity=ProjectComplexity(
|
||||
project_data.get("complexity", "medium")
|
||||
),
|
||||
client_mode=ClientMode(project_data.get("client_mode", "auto")),
|
||||
settings=project_data.get("settings", {}),
|
||||
)
|
||||
session.add(project)
|
||||
await session.flush()
|
||||
project_map[str(project.slug)] = project
|
||||
logger.info(f"Created demo project: {project.name}")
|
||||
else:
|
||||
project_map[str(existing_project.slug)] = existing_project
|
||||
logger.debug(f"Demo project already exists: {existing_project.name}")
|
||||
|
||||
# ========================
|
||||
# 5. Create Sprints
|
||||
# ========================
|
||||
for sprint_data in data.get("sprints", []):
|
||||
project_slug = sprint_data["project_slug"]
|
||||
sprint_number = sprint_data["number"]
|
||||
sprint_key = f"{project_slug}:{sprint_number}"
|
||||
|
||||
if project_slug not in project_map:
|
||||
logger.warning(f"Project not found for sprint: {project_slug}")
|
||||
continue
|
||||
|
||||
sprint_project = project_map[project_slug]
|
||||
|
||||
# Check if sprint exists
|
||||
sprint_result = await session.execute(
|
||||
select(Sprint).where(
|
||||
Sprint.project_id == sprint_project.id,
|
||||
Sprint.number == sprint_number,
|
||||
)
|
||||
)
|
||||
existing_sprint = sprint_result.scalar_one_or_none()
|
||||
|
||||
if not existing_sprint:
|
||||
sprint = Sprint(
|
||||
project_id=sprint_project.id,
|
||||
name=sprint_data["name"],
|
||||
number=sprint_number,
|
||||
goal=sprint_data.get("goal"),
|
||||
start_date=date.fromisoformat(sprint_data["start_date"]),
|
||||
end_date=date.fromisoformat(sprint_data["end_date"]),
|
||||
status=SprintStatus(sprint_data.get("status", "planned")),
|
||||
planned_points=sprint_data.get("planned_points"),
|
||||
)
|
||||
session.add(sprint)
|
||||
await session.flush()
|
||||
sprint_map[sprint_key] = sprint
|
||||
logger.info(
|
||||
f"Created demo sprint: {sprint.name} for {sprint_project.name}"
|
||||
)
|
||||
else:
|
||||
sprint_map[sprint_key] = existing_sprint
|
||||
logger.debug(f"Demo sprint already exists: {existing_sprint.name}")
|
||||
|
||||
# ========================
|
||||
# 6. Create Agent Instances
|
||||
# ========================
|
||||
for agent_data in data.get("agent_instances", []):
|
||||
project_slug = agent_data["project_slug"]
|
||||
agent_type_slug = agent_data["agent_type_slug"]
|
||||
agent_name = agent_data["name"]
|
||||
agent_key = f"{project_slug}:{agent_name}"
|
||||
|
||||
if project_slug not in project_map:
|
||||
logger.warning(f"Project not found for agent: {project_slug}")
|
||||
continue
|
||||
|
||||
if agent_type_slug not in agent_type_map:
|
||||
logger.warning(f"Agent type not found: {agent_type_slug}")
|
||||
continue
|
||||
|
||||
agent_project = project_map[project_slug]
|
||||
agent_type = agent_type_map[agent_type_slug]
|
||||
|
||||
# Check if agent instance exists (by name within project)
|
||||
agent_result = await session.execute(
|
||||
select(AgentInstance).where(
|
||||
AgentInstance.project_id == agent_project.id,
|
||||
AgentInstance.name == agent_name,
|
||||
)
|
||||
)
|
||||
existing_agent = agent_result.scalar_one_or_none()
|
||||
|
||||
if not existing_agent:
|
||||
agent_instance = AgentInstance(
|
||||
project_id=agent_project.id,
|
||||
agent_type_id=agent_type.id,
|
||||
name=agent_name,
|
||||
status=AgentStatus(agent_data.get("status", "idle")),
|
||||
current_task=agent_data.get("current_task"),
|
||||
)
|
||||
session.add(agent_instance)
|
||||
await session.flush()
|
||||
agent_instance_map[agent_key] = agent_instance
|
||||
logger.info(
|
||||
f"Created demo agent: {agent_name} ({agent_type.name}) "
|
||||
f"for {agent_project.name}"
|
||||
)
|
||||
else:
|
||||
agent_instance_map[agent_key] = existing_agent
|
||||
logger.debug(f"Demo agent already exists: {existing_agent.name}")
|
||||
|
||||
# ========================
|
||||
# 7. Create Issues
|
||||
# ========================
|
||||
for issue_data in data.get("issues", []):
|
||||
project_slug = issue_data["project_slug"]
|
||||
|
||||
if project_slug not in project_map:
|
||||
logger.warning(f"Project not found for issue: {project_slug}")
|
||||
continue
|
||||
|
||||
issue_project = project_map[project_slug]
|
||||
|
||||
# Check if issue exists (by title within project - simple heuristic)
|
||||
issue_result = await session.execute(
|
||||
select(Issue).where(
|
||||
Issue.project_id == issue_project.id,
|
||||
Issue.title == issue_data["title"],
|
||||
)
|
||||
)
|
||||
existing_issue = issue_result.scalar_one_or_none()
|
||||
|
||||
if not existing_issue:
|
||||
# Resolve sprint
|
||||
sprint_id = None
|
||||
sprint_number = issue_data.get("sprint_number")
|
||||
if sprint_number:
|
||||
sprint_key = f"{project_slug}:{sprint_number}"
|
||||
if sprint_key in sprint_map:
|
||||
sprint_id = sprint_map[sprint_key].id
|
||||
|
||||
# Resolve assigned agent
|
||||
assigned_agent_id = None
|
||||
assigned_agent_name = issue_data.get("assigned_agent_name")
|
||||
if assigned_agent_name:
|
||||
agent_key = f"{project_slug}:{assigned_agent_name}"
|
||||
if agent_key in agent_instance_map:
|
||||
assigned_agent_id = agent_instance_map[agent_key].id
|
||||
|
||||
issue = Issue(
|
||||
project_id=issue_project.id,
|
||||
sprint_id=sprint_id,
|
||||
type=IssueType(issue_data.get("type", "task")),
|
||||
title=issue_data["title"],
|
||||
body=issue_data.get("body", ""),
|
||||
status=IssueStatus(issue_data.get("status", "open")),
|
||||
priority=IssuePriority(issue_data.get("priority", "medium")),
|
||||
labels=issue_data.get("labels", []),
|
||||
story_points=issue_data.get("story_points"),
|
||||
assigned_agent_id=assigned_agent_id,
|
||||
)
|
||||
session.add(issue)
|
||||
logger.info(f"Created demo issue: {issue.title[:50]}...")
|
||||
else:
|
||||
logger.debug(
|
||||
f"Demo issue already exists: {existing_issue.title[:50]}..."
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
logger.info("Demo data loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Error loading demo data: {e}")
|
||||
raise
|
||||
|
||||
@@ -210,12 +498,12 @@ async def main():
|
||||
try:
|
||||
user = await init_db()
|
||||
if user:
|
||||
print("✓ Database initialized successfully")
|
||||
print(f"✓ Superuser: {user.email}")
|
||||
print("Database initialized successfully")
|
||||
print(f"Superuser: {user.email}")
|
||||
else:
|
||||
print("✗ Failed to initialize database")
|
||||
print("Failed to initialize database")
|
||||
except Exception as e:
|
||||
print(f"✗ Error initializing database: {e}")
|
||||
print(f"Error initializing database: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Close the engine
|
||||
|
||||
@@ -8,6 +8,19 @@ from app.core.database import Base
|
||||
|
||||
from .base import TimestampMixin, UUIDMixin
|
||||
|
||||
# Memory system models
|
||||
from .memory import (
|
||||
ConsolidationStatus,
|
||||
ConsolidationType,
|
||||
Episode,
|
||||
EpisodeOutcome,
|
||||
Fact,
|
||||
MemoryConsolidationLog,
|
||||
Procedure,
|
||||
ScopeType,
|
||||
WorkingMemory,
|
||||
)
|
||||
|
||||
# OAuth models (client mode - authenticate via Google/GitHub)
|
||||
from .oauth_account import OAuthAccount
|
||||
|
||||
@@ -37,7 +50,14 @@ __all__ = [
|
||||
"AgentInstance",
|
||||
"AgentType",
|
||||
"Base",
|
||||
# Memory models
|
||||
"ConsolidationStatus",
|
||||
"ConsolidationType",
|
||||
"Episode",
|
||||
"EpisodeOutcome",
|
||||
"Fact",
|
||||
"Issue",
|
||||
"MemoryConsolidationLog",
|
||||
"OAuthAccount",
|
||||
"OAuthAuthorizationCode",
|
||||
"OAuthClient",
|
||||
@@ -46,11 +66,14 @@ __all__ = [
|
||||
"OAuthState",
|
||||
"Organization",
|
||||
"OrganizationRole",
|
||||
"Procedure",
|
||||
"Project",
|
||||
"ScopeType",
|
||||
"Sprint",
|
||||
"TimestampMixin",
|
||||
"UUIDMixin",
|
||||
"User",
|
||||
"UserOrganization",
|
||||
"UserSession",
|
||||
"WorkingMemory",
|
||||
]
|
||||
|
||||
32
backend/app/models/memory/__init__.py
Normal file
32
backend/app/models/memory/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# app/models/memory/__init__.py
|
||||
"""
|
||||
Memory System Database Models.
|
||||
|
||||
Provides SQLAlchemy models for the Agent Memory System:
|
||||
- WorkingMemory: Key-value storage with TTL
|
||||
- Episode: Experiential memories
|
||||
- Fact: Semantic knowledge triples
|
||||
- Procedure: Learned skills
|
||||
- MemoryConsolidationLog: Consolidation job tracking
|
||||
"""
|
||||
|
||||
from .consolidation import MemoryConsolidationLog
|
||||
from .enums import ConsolidationStatus, ConsolidationType, EpisodeOutcome, ScopeType
|
||||
from .episode import Episode
|
||||
from .fact import Fact
|
||||
from .procedure import Procedure
|
||||
from .working_memory import WorkingMemory
|
||||
|
||||
__all__ = [
|
||||
# Enums
|
||||
"ConsolidationStatus",
|
||||
"ConsolidationType",
|
||||
# Models
|
||||
"Episode",
|
||||
"EpisodeOutcome",
|
||||
"Fact",
|
||||
"MemoryConsolidationLog",
|
||||
"Procedure",
|
||||
"ScopeType",
|
||||
"WorkingMemory",
|
||||
]
|
||||
72
backend/app/models/memory/consolidation.py
Normal file
72
backend/app/models/memory/consolidation.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# app/models/memory/consolidation.py
|
||||
"""
|
||||
Memory Consolidation Log database model.
|
||||
|
||||
Tracks memory consolidation jobs that transfer knowledge
|
||||
between memory tiers.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum, Index, Integer, Text
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import ConsolidationStatus, ConsolidationType
|
||||
|
||||
|
||||
class MemoryConsolidationLog(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Memory consolidation job log.
|
||||
|
||||
Tracks consolidation operations:
|
||||
- Working -> Episodic (session end)
|
||||
- Episodic -> Semantic (fact extraction)
|
||||
- Episodic -> Procedural (procedure learning)
|
||||
- Pruning (removing low-value memories)
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_consolidation_log"
|
||||
|
||||
# Consolidation type
|
||||
consolidation_type: Column[ConsolidationType] = Column(
|
||||
Enum(ConsolidationType),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Counts
|
||||
source_count = Column(Integer, nullable=False, default=0)
|
||||
result_count = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# Timing
|
||||
started_at = Column(DateTime(timezone=True), nullable=False)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Status
|
||||
status: Column[ConsolidationStatus] = Column(
|
||||
Enum(ConsolidationStatus),
|
||||
nullable=False,
|
||||
default=ConsolidationStatus.PENDING,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Error details if failed
|
||||
error = Column(Text, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
# Query patterns
|
||||
Index("ix_consolidation_type_status", "consolidation_type", "status"),
|
||||
Index("ix_consolidation_started", "started_at"),
|
||||
)
|
||||
|
||||
@property
|
||||
def duration_seconds(self) -> float | None:
|
||||
"""Calculate duration of the consolidation job."""
|
||||
if self.completed_at is None or self.started_at is None:
|
||||
return None
|
||||
return (self.completed_at - self.started_at).total_seconds()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<MemoryConsolidationLog {self.id} "
|
||||
f"type={self.consolidation_type.value} status={self.status.value}>"
|
||||
)
|
||||
73
backend/app/models/memory/enums.py
Normal file
73
backend/app/models/memory/enums.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# app/models/memory/enums.py
|
||||
"""
|
||||
Enums for Memory System database models.
|
||||
|
||||
These enums define the database-level constraints for memory types
|
||||
and scoping levels.
|
||||
"""
|
||||
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
|
||||
class ScopeType(str, PyEnum):
|
||||
"""
|
||||
Memory scope levels matching the memory service types.
|
||||
|
||||
GLOBAL: System-wide memories accessible by all
|
||||
PROJECT: Project-scoped memories
|
||||
AGENT_TYPE: Type-specific memories (shared by instances of same type)
|
||||
AGENT_INSTANCE: Instance-specific memories
|
||||
SESSION: Session-scoped ephemeral memories
|
||||
"""
|
||||
|
||||
GLOBAL = "global"
|
||||
PROJECT = "project"
|
||||
AGENT_TYPE = "agent_type"
|
||||
AGENT_INSTANCE = "agent_instance"
|
||||
SESSION = "session"
|
||||
|
||||
|
||||
class EpisodeOutcome(str, PyEnum):
|
||||
"""
|
||||
Outcome of an episode (task execution).
|
||||
|
||||
SUCCESS: Task completed successfully
|
||||
FAILURE: Task failed
|
||||
PARTIAL: Task partially completed
|
||||
"""
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
PARTIAL = "partial"
|
||||
|
||||
|
||||
class ConsolidationType(str, PyEnum):
|
||||
"""
|
||||
Types of memory consolidation operations.
|
||||
|
||||
WORKING_TO_EPISODIC: Transfer session state to episodic
|
||||
EPISODIC_TO_SEMANTIC: Extract facts from episodes
|
||||
EPISODIC_TO_PROCEDURAL: Extract procedures from episodes
|
||||
PRUNING: Remove low-value memories
|
||||
"""
|
||||
|
||||
WORKING_TO_EPISODIC = "working_to_episodic"
|
||||
EPISODIC_TO_SEMANTIC = "episodic_to_semantic"
|
||||
EPISODIC_TO_PROCEDURAL = "episodic_to_procedural"
|
||||
PRUNING = "pruning"
|
||||
|
||||
|
||||
class ConsolidationStatus(str, PyEnum):
|
||||
"""
|
||||
Status of a consolidation job.
|
||||
|
||||
PENDING: Job is queued
|
||||
RUNNING: Job is currently executing
|
||||
COMPLETED: Job finished successfully
|
||||
FAILED: Job failed with errors
|
||||
"""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
139
backend/app/models/memory/episode.py
Normal file
139
backend/app/models/memory/episode.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# app/models/memory/episode.py
|
||||
"""
|
||||
Episode database model.
|
||||
|
||||
Stores experiential memories - records of past task executions
|
||||
with context, actions, outcomes, and lessons learned.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
Enum,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Index,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import EpisodeOutcome
|
||||
|
||||
# Import pgvector type - will be available after migration enables extension
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
# Fallback for environments without pgvector
|
||||
Vector = None
|
||||
|
||||
|
||||
class Episode(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Episodic memory model.
|
||||
|
||||
Records experiential memories from agent task execution:
|
||||
- What task was performed
|
||||
- What actions were taken
|
||||
- What was the outcome
|
||||
- What lessons were learned
|
||||
"""
|
||||
|
||||
__tablename__ = "episodes"
|
||||
|
||||
# Foreign keys
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
agent_instance_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_instances.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
agent_type_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_types.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Session reference
|
||||
session_id = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Task information
|
||||
task_type = Column(String(100), nullable=False, index=True)
|
||||
task_description = Column(Text, nullable=False)
|
||||
|
||||
# Actions taken (list of action dictionaries)
|
||||
actions = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Context summary
|
||||
context_summary = Column(Text, nullable=False)
|
||||
|
||||
# Outcome
|
||||
outcome: Column[EpisodeOutcome] = Column(
|
||||
Enum(EpisodeOutcome),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
outcome_details = Column(Text, nullable=True)
|
||||
|
||||
# Metrics
|
||||
duration_seconds = Column(Float, nullable=False, default=0.0)
|
||||
tokens_used = Column(BigInteger, nullable=False, default=0)
|
||||
|
||||
# Learning
|
||||
lessons_learned = Column(JSONB, default=list, nullable=False)
|
||||
importance_score = Column(Float, nullable=False, default=0.5, index=True)
|
||||
|
||||
# Vector embedding for semantic search
|
||||
# Using 1536 dimensions for OpenAI text-embedding-3-small
|
||||
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
|
||||
|
||||
# When the episode occurred
|
||||
occurred_at = Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", foreign_keys=[project_id])
|
||||
agent_instance = relationship("AgentInstance", foreign_keys=[agent_instance_id])
|
||||
agent_type = relationship("AgentType", foreign_keys=[agent_type_id])
|
||||
|
||||
__table_args__ = (
|
||||
# Primary query patterns
|
||||
Index("ix_episodes_project_task", "project_id", "task_type"),
|
||||
Index("ix_episodes_project_outcome", "project_id", "outcome"),
|
||||
Index("ix_episodes_agent_task", "agent_instance_id", "task_type"),
|
||||
Index("ix_episodes_project_time", "project_id", "occurred_at"),
|
||||
# For importance-based pruning
|
||||
Index("ix_episodes_importance_time", "importance_score", "occurred_at"),
|
||||
# Data integrity constraints
|
||||
CheckConstraint(
|
||||
"importance_score >= 0.0 AND importance_score <= 1.0",
|
||||
name="ck_episodes_importance_range",
|
||||
),
|
||||
CheckConstraint(
|
||||
"duration_seconds >= 0.0",
|
||||
name="ck_episodes_duration_positive",
|
||||
),
|
||||
CheckConstraint(
|
||||
"tokens_used >= 0",
|
||||
name="ck_episodes_tokens_positive",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Episode {self.id} task={self.task_type} outcome={self.outcome.value}>"
|
||||
120
backend/app/models/memory/fact.py
Normal file
120
backend/app/models/memory/fact.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# app/models/memory/fact.py
|
||||
"""
|
||||
Fact database model.
|
||||
|
||||
Stores semantic memories - learned facts in subject-predicate-object
|
||||
triple format with confidence scores and source tracking.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
# Import pgvector type
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
Vector = None
|
||||
|
||||
|
||||
class Fact(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Semantic memory model.
|
||||
|
||||
Stores learned facts as subject-predicate-object triples:
|
||||
- "FastAPI" - "uses" - "Starlette framework"
|
||||
- "Project Alpha" - "requires" - "OAuth authentication"
|
||||
|
||||
Facts have confidence scores that decay over time and can be
|
||||
reinforced when the same fact is learned again.
|
||||
"""
|
||||
|
||||
__tablename__ = "facts"
|
||||
|
||||
# Scoping: project_id is NULL for global facts
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Triple format
|
||||
subject = Column(String(500), nullable=False, index=True)
|
||||
predicate = Column(String(255), nullable=False, index=True)
|
||||
object = Column(Text, nullable=False)
|
||||
|
||||
# Confidence score (0.0 to 1.0)
|
||||
confidence = Column(Float, nullable=False, default=0.8, index=True)
|
||||
|
||||
# Source tracking: which episodes contributed to this fact (stored as JSONB array of UUID strings)
|
||||
source_episode_ids: Column[list] = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Learning history
|
||||
first_learned = Column(DateTime(timezone=True), nullable=False)
|
||||
last_reinforced = Column(DateTime(timezone=True), nullable=False)
|
||||
reinforcement_count = Column(Integer, nullable=False, default=1)
|
||||
|
||||
# Vector embedding for semantic search
|
||||
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", foreign_keys=[project_id])
|
||||
|
||||
__table_args__ = (
|
||||
# Unique constraint on triple within project scope
|
||||
Index(
|
||||
"ix_facts_unique_triple",
|
||||
"project_id",
|
||||
"subject",
|
||||
"predicate",
|
||||
"object",
|
||||
unique=True,
|
||||
postgresql_where=text("project_id IS NOT NULL"),
|
||||
),
|
||||
# Unique constraint on triple for global facts (project_id IS NULL)
|
||||
Index(
|
||||
"ix_facts_unique_triple_global",
|
||||
"subject",
|
||||
"predicate",
|
||||
"object",
|
||||
unique=True,
|
||||
postgresql_where=text("project_id IS NULL"),
|
||||
),
|
||||
# Query patterns
|
||||
Index("ix_facts_subject_predicate", "subject", "predicate"),
|
||||
Index("ix_facts_project_subject", "project_id", "subject"),
|
||||
Index("ix_facts_confidence_time", "confidence", "last_reinforced"),
|
||||
# Note: subject already has index=True on Column definition, no need for explicit index
|
||||
# Data integrity constraints
|
||||
CheckConstraint(
|
||||
"confidence >= 0.0 AND confidence <= 1.0",
|
||||
name="ck_facts_confidence_range",
|
||||
),
|
||||
CheckConstraint(
|
||||
"reinforcement_count >= 1",
|
||||
name="ck_facts_reinforcement_positive",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Fact {self.id} '{self.subject}' - '{self.predicate}' - "
|
||||
f"'{self.object[:50]}...' conf={self.confidence:.2f}>"
|
||||
)
|
||||
129
backend/app/models/memory/procedure.py
Normal file
129
backend/app/models/memory/procedure.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# app/models/memory/procedure.py
|
||||
"""
|
||||
Procedure database model.
|
||||
|
||||
Stores procedural memories - learned skills and procedures
|
||||
derived from successful task execution patterns.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
# Import pgvector type
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
Vector = None
|
||||
|
||||
|
||||
class Procedure(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Procedural memory model.
|
||||
|
||||
Stores learned procedures (skills) extracted from successful
|
||||
task execution patterns:
|
||||
- Name and trigger pattern for matching
|
||||
- Step-by-step actions
|
||||
- Success/failure tracking
|
||||
"""
|
||||
|
||||
__tablename__ = "procedures"
|
||||
|
||||
# Scoping
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
agent_type_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_types.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Procedure identification
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
trigger_pattern = Column(Text, nullable=False)
|
||||
|
||||
# Steps as JSON array of step objects
|
||||
# Each step: {order, action, parameters, expected_outcome, fallback_action}
|
||||
steps = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Success tracking
|
||||
success_count = Column(Integer, nullable=False, default=0)
|
||||
failure_count = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# Usage tracking
|
||||
last_used = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Vector embedding for semantic matching
|
||||
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", foreign_keys=[project_id])
|
||||
agent_type = relationship("AgentType", foreign_keys=[agent_type_id])
|
||||
|
||||
__table_args__ = (
|
||||
# Unique procedure name within scope
|
||||
Index(
|
||||
"ix_procedures_unique_name",
|
||||
"project_id",
|
||||
"agent_type_id",
|
||||
"name",
|
||||
unique=True,
|
||||
),
|
||||
# Query patterns
|
||||
Index("ix_procedures_project_name", "project_id", "name"),
|
||||
# Note: agent_type_id already has index=True on Column definition
|
||||
# For finding best procedures
|
||||
Index("ix_procedures_success_rate", "success_count", "failure_count"),
|
||||
# Data integrity constraints
|
||||
CheckConstraint(
|
||||
"success_count >= 0",
|
||||
name="ck_procedures_success_positive",
|
||||
),
|
||||
CheckConstraint(
|
||||
"failure_count >= 0",
|
||||
name="ck_procedures_failure_positive",
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""Calculate the success rate of this procedure."""
|
||||
# Snapshot values to avoid race conditions in concurrent access
|
||||
success = self.success_count
|
||||
failure = self.failure_count
|
||||
total = success + failure
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return success / total
|
||||
|
||||
@property
|
||||
def total_uses(self) -> int:
|
||||
"""Get total number of times this procedure was used."""
|
||||
# Snapshot values for consistency
|
||||
return self.success_count + self.failure_count
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Procedure {self.name} ({self.id}) success_rate={self.success_rate:.2%}>"
|
||||
)
|
||||
58
backend/app/models/memory/working_memory.py
Normal file
58
backend/app/models/memory/working_memory.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# app/models/memory/working_memory.py
|
||||
"""
|
||||
Working Memory database model.
|
||||
|
||||
Stores ephemeral key-value data for active sessions with TTL support.
|
||||
Used as database backup when Redis is unavailable.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum, Index, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import ScopeType
|
||||
|
||||
|
||||
class WorkingMemory(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Working memory storage table.
|
||||
|
||||
Provides database-backed working memory as fallback when
|
||||
Redis is unavailable. Supports TTL-based expiration.
|
||||
"""
|
||||
|
||||
__tablename__ = "working_memory"
|
||||
|
||||
# Scoping
|
||||
scope_type: Column[ScopeType] = Column(
|
||||
Enum(ScopeType),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
scope_id = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Key-value storage
|
||||
key = Column(String(255), nullable=False)
|
||||
value = Column(JSONB, nullable=False)
|
||||
|
||||
# TTL support
|
||||
expires_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
# Primary lookup: scope + key
|
||||
Index(
|
||||
"ix_working_memory_scope_key",
|
||||
"scope_type",
|
||||
"scope_id",
|
||||
"key",
|
||||
unique=True,
|
||||
),
|
||||
# For cleanup of expired entries
|
||||
Index("ix_working_memory_expires", "expires_at"),
|
||||
# For listing all keys in a scope
|
||||
Index("ix_working_memory_scope_list", "scope_type", "scope_id"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<WorkingMemory {self.scope_type.value}:{self.scope_id}:{self.key}>"
|
||||
@@ -62,7 +62,11 @@ class AgentInstance(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Status tracking
|
||||
status: Column[AgentStatus] = Column(
|
||||
Enum(AgentStatus),
|
||||
Enum(
|
||||
AgentStatus,
|
||||
name="agent_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=AgentStatus.IDLE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
|
||||
@@ -6,7 +6,7 @@ An AgentType is a template that defines the capabilities, personality,
|
||||
and model configuration for agent instances.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Boolean, Column, Index, String, Text
|
||||
from sqlalchemy import Boolean, Column, Index, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -56,6 +56,24 @@ class AgentType(Base, UUIDMixin, TimestampMixin):
|
||||
# Whether this agent type is available for new instances
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Category for grouping agents (development, design, quality, etc.)
|
||||
category = Column(String(50), nullable=True, index=True)
|
||||
|
||||
# Lucide icon identifier for UI display (e.g., "code", "palette", "shield")
|
||||
icon = Column(String(50), nullable=True, default="bot")
|
||||
|
||||
# Hex color code for visual distinction (e.g., "#3B82F6")
|
||||
color = Column(String(7), nullable=True, default="#3B82F6")
|
||||
|
||||
# Display ordering within category (lower = first)
|
||||
sort_order = Column(Integer, nullable=False, default=0, index=True)
|
||||
|
||||
# List of typical tasks this agent excels at
|
||||
typical_tasks = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# List of agent slugs that collaborate well with this type
|
||||
collaboration_hints = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Relationships
|
||||
instances = relationship(
|
||||
"AgentInstance",
|
||||
@@ -66,6 +84,7 @@ class AgentType(Base, UUIDMixin, TimestampMixin):
|
||||
__table_args__ = (
|
||||
Index("ix_agent_types_slug_active", "slug", "is_active"),
|
||||
Index("ix_agent_types_name_active", "name", "is_active"),
|
||||
Index("ix_agent_types_category_sort", "category", "sort_order"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
@@ -167,3 +167,29 @@ class SprintStatus(str, PyEnum):
|
||||
IN_REVIEW = "in_review"
|
||||
COMPLETED = "completed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class AgentTypeCategory(str, PyEnum):
|
||||
"""
|
||||
Category classification for agent types.
|
||||
|
||||
Used for grouping and filtering agents in the UI.
|
||||
|
||||
DEVELOPMENT: Product, project, and engineering roles
|
||||
DESIGN: UI/UX and design research roles
|
||||
QUALITY: QA and security engineering
|
||||
OPERATIONS: DevOps and MLOps
|
||||
AI_ML: Machine learning and AI specialists
|
||||
DATA: Data science and engineering
|
||||
LEADERSHIP: Technical leadership roles
|
||||
DOMAIN_EXPERT: Industry and domain specialists
|
||||
"""
|
||||
|
||||
DEVELOPMENT = "development"
|
||||
DESIGN = "design"
|
||||
QUALITY = "quality"
|
||||
OPERATIONS = "operations"
|
||||
AI_ML = "ai_ml"
|
||||
DATA = "data"
|
||||
LEADERSHIP = "leadership"
|
||||
DOMAIN_EXPERT = "domain_expert"
|
||||
|
||||
@@ -59,7 +59,9 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Issue type (Epic, Story, Task, Bug)
|
||||
type: Column[IssueType] = Column(
|
||||
Enum(IssueType),
|
||||
Enum(
|
||||
IssueType, name="issue_type", values_callable=lambda x: [e.value for e in x]
|
||||
),
|
||||
default=IssueType.TASK,
|
||||
nullable=False,
|
||||
index=True,
|
||||
@@ -78,14 +80,22 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Status and priority
|
||||
status: Column[IssueStatus] = Column(
|
||||
Enum(IssueStatus),
|
||||
Enum(
|
||||
IssueStatus,
|
||||
name="issue_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=IssueStatus.OPEN,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
priority: Column[IssuePriority] = Column(
|
||||
Enum(IssuePriority),
|
||||
Enum(
|
||||
IssuePriority,
|
||||
name="issue_priority",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=IssuePriority.MEDIUM,
|
||||
nullable=False,
|
||||
index=True,
|
||||
@@ -132,7 +142,11 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Sync status with external tracker
|
||||
sync_status: Column[SyncStatus] = Column(
|
||||
Enum(SyncStatus),
|
||||
Enum(
|
||||
SyncStatus,
|
||||
name="sync_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=SyncStatus.SYNCED,
|
||||
nullable=False,
|
||||
# Note: Index defined in __table_args__ as ix_issues_sync_status
|
||||
|
||||
@@ -35,28 +35,44 @@ class Project(Base, UUIDMixin, TimestampMixin):
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
autonomy_level: Column[AutonomyLevel] = Column(
|
||||
Enum(AutonomyLevel),
|
||||
Enum(
|
||||
AutonomyLevel,
|
||||
name="autonomy_level",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=AutonomyLevel.MILESTONE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
status: Column[ProjectStatus] = Column(
|
||||
Enum(ProjectStatus),
|
||||
Enum(
|
||||
ProjectStatus,
|
||||
name="project_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=ProjectStatus.ACTIVE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
complexity: Column[ProjectComplexity] = Column(
|
||||
Enum(ProjectComplexity),
|
||||
Enum(
|
||||
ProjectComplexity,
|
||||
name="project_complexity",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=ProjectComplexity.MEDIUM,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
client_mode: Column[ClientMode] = Column(
|
||||
Enum(ClientMode),
|
||||
Enum(
|
||||
ClientMode,
|
||||
name="client_mode",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=ClientMode.AUTO,
|
||||
nullable=False,
|
||||
index=True,
|
||||
|
||||
@@ -57,7 +57,11 @@ class Sprint(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Status
|
||||
status: Column[SprintStatus] = Column(
|
||||
Enum(SprintStatus),
|
||||
Enum(
|
||||
SprintStatus,
|
||||
name="sprint_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=SprintStatus.PLANNED,
|
||||
nullable=False,
|
||||
index=True,
|
||||
|
||||
@@ -10,6 +10,8 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from app.models.syndarix.enums import AgentTypeCategory
|
||||
|
||||
|
||||
class AgentTypeBase(BaseModel):
|
||||
"""Base agent type schema with common fields."""
|
||||
@@ -26,6 +28,14 @@ class AgentTypeBase(BaseModel):
|
||||
tool_permissions: dict[str, Any] = Field(default_factory=dict)
|
||||
is_active: bool = True
|
||||
|
||||
# Category and display fields
|
||||
category: AgentTypeCategory | None = None
|
||||
icon: str | None = Field(None, max_length=50)
|
||||
color: str | None = Field(None, pattern=r"^#[0-9A-Fa-f]{6}$")
|
||||
sort_order: int = Field(default=0, ge=0, le=1000)
|
||||
typical_tasks: list[str] = Field(default_factory=list)
|
||||
collaboration_hints: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
@@ -62,6 +72,18 @@ class AgentTypeBase(BaseModel):
|
||||
"""Validate MCP server list."""
|
||||
return [s.strip() for s in v if s.strip()]
|
||||
|
||||
@field_validator("typical_tasks")
|
||||
@classmethod
|
||||
def validate_typical_tasks(cls, v: list[str]) -> list[str]:
|
||||
"""Validate and normalize typical tasks list."""
|
||||
return [t.strip() for t in v if t.strip()]
|
||||
|
||||
@field_validator("collaboration_hints")
|
||||
@classmethod
|
||||
def validate_collaboration_hints(cls, v: list[str]) -> list[str]:
|
||||
"""Validate and normalize collaboration hints (agent slugs)."""
|
||||
return [h.strip().lower() for h in v if h.strip()]
|
||||
|
||||
|
||||
class AgentTypeCreate(AgentTypeBase):
|
||||
"""Schema for creating a new agent type."""
|
||||
@@ -87,6 +109,14 @@ class AgentTypeUpdate(BaseModel):
|
||||
tool_permissions: dict[str, Any] | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
# Category and display fields (all optional for updates)
|
||||
category: AgentTypeCategory | None = None
|
||||
icon: str | None = Field(None, max_length=50)
|
||||
color: str | None = Field(None, pattern=r"^#[0-9A-Fa-f]{6}$")
|
||||
sort_order: int | None = Field(None, ge=0, le=1000)
|
||||
typical_tasks: list[str] | None = None
|
||||
collaboration_hints: list[str] | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
@@ -119,6 +149,22 @@ class AgentTypeUpdate(BaseModel):
|
||||
return v
|
||||
return [e.strip().lower() for e in v if e.strip()]
|
||||
|
||||
@field_validator("typical_tasks")
|
||||
@classmethod
|
||||
def validate_typical_tasks(cls, v: list[str] | None) -> list[str] | None:
|
||||
"""Validate and normalize typical tasks list."""
|
||||
if v is None:
|
||||
return v
|
||||
return [t.strip() for t in v if t.strip()]
|
||||
|
||||
@field_validator("collaboration_hints")
|
||||
@classmethod
|
||||
def validate_collaboration_hints(cls, v: list[str] | None) -> list[str] | None:
|
||||
"""Validate and normalize collaboration hints (agent slugs)."""
|
||||
if v is None:
|
||||
return v
|
||||
return [h.strip().lower() for h in v if h.strip()]
|
||||
|
||||
|
||||
class AgentTypeInDB(AgentTypeBase):
|
||||
"""Schema for agent type in database."""
|
||||
|
||||
182
backend/app/services/context/__init__.py
Normal file
182
backend/app/services/context/__init__.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Context Management Engine
|
||||
|
||||
Sophisticated context assembly and optimization for LLM requests.
|
||||
Provides intelligent context selection, token budget management,
|
||||
and model-specific formatting.
|
||||
|
||||
Usage:
|
||||
from app.services.context import (
|
||||
ContextSettings,
|
||||
get_context_settings,
|
||||
SystemContext,
|
||||
KnowledgeContext,
|
||||
ConversationContext,
|
||||
TaskContext,
|
||||
ToolContext,
|
||||
TokenBudget,
|
||||
BudgetAllocator,
|
||||
TokenCalculator,
|
||||
)
|
||||
|
||||
# Get settings
|
||||
settings = get_context_settings()
|
||||
|
||||
# Create budget for a model
|
||||
allocator = BudgetAllocator(settings)
|
||||
budget = allocator.create_budget_for_model("claude-3-sonnet")
|
||||
|
||||
# Create context instances
|
||||
system_ctx = SystemContext.create_persona(
|
||||
name="Code Assistant",
|
||||
description="You are a helpful code assistant.",
|
||||
capabilities=["Write code", "Debug issues"],
|
||||
)
|
||||
"""
|
||||
|
||||
# Budget Management
|
||||
# Adapters
|
||||
from .adapters import (
|
||||
ClaudeAdapter,
|
||||
DefaultAdapter,
|
||||
ModelAdapter,
|
||||
OpenAIAdapter,
|
||||
get_adapter,
|
||||
)
|
||||
|
||||
# Assembly
|
||||
from .assembly import (
|
||||
ContextPipeline,
|
||||
PipelineMetrics,
|
||||
)
|
||||
from .budget import (
|
||||
BudgetAllocator,
|
||||
TokenBudget,
|
||||
TokenCalculator,
|
||||
)
|
||||
|
||||
# Cache
|
||||
from .cache import ContextCache
|
||||
|
||||
# Compression
|
||||
from .compression import (
|
||||
ContextCompressor,
|
||||
TruncationResult,
|
||||
TruncationStrategy,
|
||||
)
|
||||
|
||||
# Configuration
|
||||
from .config import (
|
||||
ContextSettings,
|
||||
get_context_settings,
|
||||
get_default_settings,
|
||||
reset_context_settings,
|
||||
)
|
||||
|
||||
# Engine
|
||||
from .engine import ContextEngine, create_context_engine
|
||||
|
||||
# Exceptions
|
||||
from .exceptions import (
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
CacheError,
|
||||
CompressionError,
|
||||
ContextError,
|
||||
ContextNotFoundError,
|
||||
FormattingError,
|
||||
InvalidContextError,
|
||||
ScoringError,
|
||||
TokenCountError,
|
||||
)
|
||||
|
||||
# Prioritization
|
||||
from .prioritization import (
|
||||
ContextRanker,
|
||||
RankingResult,
|
||||
)
|
||||
|
||||
# Scoring
|
||||
from .scoring import (
|
||||
BaseScorer,
|
||||
CompositeScorer,
|
||||
PriorityScorer,
|
||||
RecencyScorer,
|
||||
RelevanceScorer,
|
||||
ScoredContext,
|
||||
)
|
||||
|
||||
# Types
|
||||
from .types import (
|
||||
AssembledContext,
|
||||
BaseContext,
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MemoryContext,
|
||||
MemorySubtype,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskComplexity,
|
||||
TaskContext,
|
||||
TaskStatus,
|
||||
ToolContext,
|
||||
ToolResultStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AssembledContext",
|
||||
"AssemblyTimeoutError",
|
||||
"BaseContext",
|
||||
"BaseScorer",
|
||||
"BudgetAllocator",
|
||||
"BudgetExceededError",
|
||||
"CacheError",
|
||||
"ClaudeAdapter",
|
||||
"CompositeScorer",
|
||||
"CompressionError",
|
||||
"ContextCache",
|
||||
"ContextCompressor",
|
||||
"ContextEngine",
|
||||
"ContextError",
|
||||
"ContextNotFoundError",
|
||||
"ContextPipeline",
|
||||
"ContextPriority",
|
||||
"ContextRanker",
|
||||
"ContextSettings",
|
||||
"ContextType",
|
||||
"ConversationContext",
|
||||
"DefaultAdapter",
|
||||
"FormattingError",
|
||||
"InvalidContextError",
|
||||
"KnowledgeContext",
|
||||
"MemoryContext",
|
||||
"MemorySubtype",
|
||||
"MessageRole",
|
||||
"ModelAdapter",
|
||||
"OpenAIAdapter",
|
||||
"PipelineMetrics",
|
||||
"PriorityScorer",
|
||||
"RankingResult",
|
||||
"RecencyScorer",
|
||||
"RelevanceScorer",
|
||||
"ScoredContext",
|
||||
"ScoringError",
|
||||
"SystemContext",
|
||||
"TaskComplexity",
|
||||
"TaskContext",
|
||||
"TaskStatus",
|
||||
"TokenBudget",
|
||||
"TokenCalculator",
|
||||
"TokenCountError",
|
||||
"ToolContext",
|
||||
"ToolResultStatus",
|
||||
"TruncationResult",
|
||||
"TruncationStrategy",
|
||||
"create_context_engine",
|
||||
"get_adapter",
|
||||
"get_context_settings",
|
||||
"get_default_settings",
|
||||
"reset_context_settings",
|
||||
]
|
||||
35
backend/app/services/context/adapters/__init__.py
Normal file
35
backend/app/services/context/adapters/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Model Adapters Module.
|
||||
|
||||
Provides model-specific context formatting adapters.
|
||||
"""
|
||||
|
||||
from .base import DefaultAdapter, ModelAdapter
|
||||
from .claude import ClaudeAdapter
|
||||
from .openai import OpenAIAdapter
|
||||
|
||||
|
||||
def get_adapter(model: str) -> ModelAdapter:
|
||||
"""
|
||||
Get the appropriate adapter for a model.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Adapter instance for the model
|
||||
"""
|
||||
if ClaudeAdapter.matches_model(model):
|
||||
return ClaudeAdapter()
|
||||
elif OpenAIAdapter.matches_model(model):
|
||||
return OpenAIAdapter()
|
||||
return DefaultAdapter()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ClaudeAdapter",
|
||||
"DefaultAdapter",
|
||||
"ModelAdapter",
|
||||
"OpenAIAdapter",
|
||||
"get_adapter",
|
||||
]
|
||||
178
backend/app/services/context/adapters/base.py
Normal file
178
backend/app/services/context/adapters/base.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Base Model Adapter.
|
||||
|
||||
Abstract base class for model-specific context formatting.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
|
||||
|
||||
class ModelAdapter(ABC):
|
||||
"""
|
||||
Abstract base adapter for model-specific context formatting.
|
||||
|
||||
Each adapter knows how to format contexts for optimal
|
||||
understanding by a specific LLM family (Claude, OpenAI, etc.).
|
||||
"""
|
||||
|
||||
# Model name patterns this adapter handles
|
||||
MODEL_PATTERNS: ClassVar[list[str]] = []
|
||||
|
||||
@classmethod
|
||||
def matches_model(cls, model: str) -> bool:
|
||||
"""
|
||||
Check if this adapter handles the given model.
|
||||
|
||||
Args:
|
||||
model: Model name to check
|
||||
|
||||
Returns:
|
||||
True if this adapter handles the model
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
return any(pattern in model_lower for pattern in cls.MODEL_PATTERNS)
|
||||
|
||||
@abstractmethod
|
||||
def format(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts for the target model.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to format
|
||||
**kwargs: Additional formatting options
|
||||
|
||||
Returns:
|
||||
Formatted context string
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def format_type(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
context_type: ContextType,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts of a specific type.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts of the same type
|
||||
context_type: The type of contexts
|
||||
**kwargs: Additional formatting options
|
||||
|
||||
Returns:
|
||||
Formatted string for this context type
|
||||
"""
|
||||
...
|
||||
|
||||
def get_type_order(self) -> list[ContextType]:
|
||||
"""
|
||||
Get the preferred order of context types.
|
||||
|
||||
Returns:
|
||||
List of context types in preferred order
|
||||
"""
|
||||
return [
|
||||
ContextType.SYSTEM,
|
||||
ContextType.TASK,
|
||||
ContextType.KNOWLEDGE,
|
||||
ContextType.CONVERSATION,
|
||||
ContextType.TOOL,
|
||||
]
|
||||
|
||||
def group_by_type(
|
||||
self, contexts: list[BaseContext]
|
||||
) -> dict[ContextType, list[BaseContext]]:
|
||||
"""
|
||||
Group contexts by their type.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to group
|
||||
|
||||
Returns:
|
||||
Dictionary mapping context type to list of contexts
|
||||
"""
|
||||
by_type: dict[ContextType, list[BaseContext]] = {}
|
||||
for context in contexts:
|
||||
ct = context.get_type()
|
||||
if ct not in by_type:
|
||||
by_type[ct] = []
|
||||
by_type[ct].append(context)
|
||||
return by_type
|
||||
|
||||
def get_separator(self) -> str:
|
||||
"""
|
||||
Get the separator between context sections.
|
||||
|
||||
Returns:
|
||||
Separator string
|
||||
"""
|
||||
return "\n\n"
|
||||
|
||||
|
||||
class DefaultAdapter(ModelAdapter):
|
||||
"""
|
||||
Default adapter for unknown models.
|
||||
|
||||
Uses simple plain-text formatting with minimal structure.
|
||||
"""
|
||||
|
||||
MODEL_PATTERNS: ClassVar[list[str]] = [] # Fallback adapter
|
||||
|
||||
@classmethod
|
||||
def matches_model(cls, model: str) -> bool:
|
||||
"""Always returns True as fallback."""
|
||||
return True
|
||||
|
||||
def format(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Format contexts as plain text."""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
by_type = self.group_by_type(contexts)
|
||||
parts: list[str] = []
|
||||
|
||||
for ct in self.get_type_order():
|
||||
if ct in by_type:
|
||||
formatted = self.format_type(by_type[ct], ct, **kwargs)
|
||||
if formatted:
|
||||
parts.append(formatted)
|
||||
|
||||
return self.get_separator().join(parts)
|
||||
|
||||
def format_type(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
context_type: ContextType,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Format contexts of a type as plain text."""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
|
||||
if context_type == ContextType.SYSTEM:
|
||||
return content
|
||||
elif context_type == ContextType.TASK:
|
||||
return f"Task:\n{content}"
|
||||
elif context_type == ContextType.KNOWLEDGE:
|
||||
return f"Reference Information:\n{content}"
|
||||
elif context_type == ContextType.CONVERSATION:
|
||||
return f"Previous Conversation:\n{content}"
|
||||
elif context_type == ContextType.TOOL:
|
||||
return f"Tool Results:\n{content}"
|
||||
|
||||
return content
|
||||
212
backend/app/services/context/adapters/claude.py
Normal file
212
backend/app/services/context/adapters/claude.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Claude Model Adapter.
|
||||
|
||||
Provides Claude-specific context formatting using XML tags
|
||||
which Claude models understand natively.
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
from .base import ModelAdapter
|
||||
|
||||
|
||||
class ClaudeAdapter(ModelAdapter):
|
||||
"""
|
||||
Claude-specific context formatting adapter.
|
||||
|
||||
Claude models have native understanding of XML structure,
|
||||
so we use XML tags for clear delineation of context types.
|
||||
|
||||
Features:
|
||||
- XML tags for each context type
|
||||
- Document structure for knowledge contexts
|
||||
- Role-based message formatting for conversations
|
||||
- Tool result wrapping with tool names
|
||||
"""
|
||||
|
||||
MODEL_PATTERNS: ClassVar[list[str]] = ["claude", "anthropic"]
|
||||
|
||||
def format(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts for Claude models.
|
||||
|
||||
Uses XML tags for structured content that Claude
|
||||
understands natively.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to format
|
||||
**kwargs: Additional formatting options
|
||||
|
||||
Returns:
|
||||
XML-structured context string
|
||||
"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
by_type = self.group_by_type(contexts)
|
||||
parts: list[str] = []
|
||||
|
||||
for ct in self.get_type_order():
|
||||
if ct in by_type:
|
||||
formatted = self.format_type(by_type[ct], ct, **kwargs)
|
||||
if formatted:
|
||||
parts.append(formatted)
|
||||
|
||||
return self.get_separator().join(parts)
|
||||
|
||||
def format_type(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
context_type: ContextType,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts of a specific type for Claude.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts of the same type
|
||||
context_type: The type of contexts
|
||||
**kwargs: Additional formatting options
|
||||
|
||||
Returns:
|
||||
XML-formatted string for this context type
|
||||
"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
if context_type == ContextType.SYSTEM:
|
||||
return self._format_system(contexts)
|
||||
elif context_type == ContextType.TASK:
|
||||
return self._format_task(contexts)
|
||||
elif context_type == ContextType.KNOWLEDGE:
|
||||
return self._format_knowledge(contexts)
|
||||
elif context_type == ContextType.CONVERSATION:
|
||||
return self._format_conversation(contexts)
|
||||
elif context_type == ContextType.TOOL:
|
||||
return self._format_tool(contexts)
|
||||
|
||||
# Fallback for any unhandled context types - still escape content
|
||||
# to prevent XML injection if new types are added without updating adapter
|
||||
return "\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||
|
||||
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format system contexts."""
|
||||
# System prompts are typically admin-controlled, but escape for safety
|
||||
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||
return f"<system_instructions>\n{content}\n</system_instructions>"
|
||||
|
||||
def _format_task(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format task contexts."""
|
||||
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||
return f"<current_task>\n{content}\n</current_task>"
|
||||
|
||||
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
|
||||
"""
|
||||
Format knowledge contexts as structured documents.
|
||||
|
||||
Each knowledge context becomes a document with source attribution.
|
||||
All content is XML-escaped to prevent injection attacks.
|
||||
"""
|
||||
parts = ["<reference_documents>"]
|
||||
|
||||
for ctx in contexts:
|
||||
source = self._escape_xml(ctx.source)
|
||||
# Escape content to prevent XML injection
|
||||
content = self._escape_xml_content(ctx.content)
|
||||
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||
|
||||
if score:
|
||||
# Escape score to prevent XML injection via metadata
|
||||
escaped_score = self._escape_xml(str(score))
|
||||
parts.append(
|
||||
f'<document source="{source}" relevance="{escaped_score}">'
|
||||
)
|
||||
else:
|
||||
parts.append(f'<document source="{source}">')
|
||||
|
||||
parts.append(content)
|
||||
parts.append("</document>")
|
||||
|
||||
parts.append("</reference_documents>")
|
||||
return "\n".join(parts)
|
||||
|
||||
def _format_conversation(self, contexts: list[BaseContext]) -> str:
|
||||
"""
|
||||
Format conversation contexts as message history.
|
||||
|
||||
Uses role-based message tags for clear turn delineation.
|
||||
All content is XML-escaped to prevent prompt injection.
|
||||
"""
|
||||
parts = ["<conversation_history>"]
|
||||
|
||||
for ctx in contexts:
|
||||
role = self._escape_xml(ctx.metadata.get("role", "user"))
|
||||
# Escape content to prevent prompt injection via fake XML tags
|
||||
content = self._escape_xml_content(ctx.content)
|
||||
parts.append(f'<message role="{role}">')
|
||||
parts.append(content)
|
||||
parts.append("</message>")
|
||||
|
||||
parts.append("</conversation_history>")
|
||||
return "\n".join(parts)
|
||||
|
||||
def _format_tool(self, contexts: list[BaseContext]) -> str:
|
||||
"""
|
||||
Format tool contexts as tool results.
|
||||
|
||||
Each tool result is wrapped with the tool name.
|
||||
All content is XML-escaped to prevent injection.
|
||||
"""
|
||||
parts = ["<tool_results>"]
|
||||
|
||||
for ctx in contexts:
|
||||
tool_name = self._escape_xml(ctx.metadata.get("tool_name", "unknown"))
|
||||
status = ctx.metadata.get("status", "")
|
||||
|
||||
if status:
|
||||
parts.append(
|
||||
f'<tool_result name="{tool_name}" status="{self._escape_xml(status)}">'
|
||||
)
|
||||
else:
|
||||
parts.append(f'<tool_result name="{tool_name}">')
|
||||
|
||||
# Escape content to prevent injection
|
||||
parts.append(self._escape_xml_content(ctx.content))
|
||||
parts.append("</tool_result>")
|
||||
|
||||
parts.append("</tool_results>")
|
||||
return "\n".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _escape_xml(text: str) -> str:
|
||||
"""Escape XML special characters in attribute values."""
|
||||
return (
|
||||
text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
.replace('"', """)
|
||||
.replace("'", "'")
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _escape_xml_content(text: str) -> str:
|
||||
"""
|
||||
Escape XML special characters in element content.
|
||||
|
||||
This prevents XML injection attacks where malicious content
|
||||
could break out of XML tags or inject fake tags for prompt injection.
|
||||
|
||||
Only escapes &, <, > since quotes don't need escaping in content.
|
||||
|
||||
Args:
|
||||
text: Content text to escape
|
||||
|
||||
Returns:
|
||||
XML-safe content string
|
||||
"""
|
||||
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
160
backend/app/services/context/adapters/openai.py
Normal file
160
backend/app/services/context/adapters/openai.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
OpenAI Model Adapter.
|
||||
|
||||
Provides OpenAI-specific context formatting using markdown
|
||||
which GPT models understand well.
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
from .base import ModelAdapter
|
||||
|
||||
|
||||
class OpenAIAdapter(ModelAdapter):
|
||||
"""
|
||||
OpenAI-specific context formatting adapter.
|
||||
|
||||
GPT models work well with markdown formatting,
|
||||
so we use headers and structured markdown for clarity.
|
||||
|
||||
Features:
|
||||
- Markdown headers for each context type
|
||||
- Bulleted lists for document sources
|
||||
- Bold role labels for conversations
|
||||
- Code blocks for tool outputs
|
||||
"""
|
||||
|
||||
MODEL_PATTERNS: ClassVar[list[str]] = ["gpt", "openai", "o1", "o3"]
|
||||
|
||||
def format(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts for OpenAI models.
|
||||
|
||||
Uses markdown formatting for structured content.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to format
|
||||
**kwargs: Additional formatting options
|
||||
|
||||
Returns:
|
||||
Markdown-structured context string
|
||||
"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
by_type = self.group_by_type(contexts)
|
||||
parts: list[str] = []
|
||||
|
||||
for ct in self.get_type_order():
|
||||
if ct in by_type:
|
||||
formatted = self.format_type(by_type[ct], ct, **kwargs)
|
||||
if formatted:
|
||||
parts.append(formatted)
|
||||
|
||||
return self.get_separator().join(parts)
|
||||
|
||||
def format_type(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
context_type: ContextType,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts of a specific type for OpenAI.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts of the same type
|
||||
context_type: The type of contexts
|
||||
**kwargs: Additional formatting options
|
||||
|
||||
Returns:
|
||||
Markdown-formatted string for this context type
|
||||
"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
if context_type == ContextType.SYSTEM:
|
||||
return self._format_system(contexts)
|
||||
elif context_type == ContextType.TASK:
|
||||
return self._format_task(contexts)
|
||||
elif context_type == ContextType.KNOWLEDGE:
|
||||
return self._format_knowledge(contexts)
|
||||
elif context_type == ContextType.CONVERSATION:
|
||||
return self._format_conversation(contexts)
|
||||
elif context_type == ContextType.TOOL:
|
||||
return self._format_tool(contexts)
|
||||
|
||||
return "\n".join(c.content for c in contexts)
|
||||
|
||||
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format system contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
return content
|
||||
|
||||
def _format_task(self, contexts: list[BaseContext]) -> str:
|
||||
"""Format task contexts."""
|
||||
content = "\n\n".join(c.content for c in contexts)
|
||||
return f"## Current Task\n\n{content}"
|
||||
|
||||
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
|
||||
"""
|
||||
Format knowledge contexts as structured documents.
|
||||
|
||||
Each knowledge context becomes a section with source attribution.
|
||||
"""
|
||||
parts = ["## Reference Documents\n"]
|
||||
|
||||
for ctx in contexts:
|
||||
source = ctx.source
|
||||
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||
|
||||
if score:
|
||||
parts.append(f"### Source: {source} (relevance: {score})\n")
|
||||
else:
|
||||
parts.append(f"### Source: {source}\n")
|
||||
|
||||
parts.append(ctx.content)
|
||||
parts.append("")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _format_conversation(self, contexts: list[BaseContext]) -> str:
|
||||
"""
|
||||
Format conversation contexts as message history.
|
||||
|
||||
Uses bold role labels for clear turn delineation.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
for ctx in contexts:
|
||||
role = ctx.metadata.get("role", "user").upper()
|
||||
parts.append(f"**{role}**: {ctx.content}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _format_tool(self, contexts: list[BaseContext]) -> str:
|
||||
"""
|
||||
Format tool contexts as tool results.
|
||||
|
||||
Each tool result is in a code block with the tool name.
|
||||
"""
|
||||
parts = ["## Recent Tool Results\n"]
|
||||
|
||||
for ctx in contexts:
|
||||
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||
status = ctx.metadata.get("status", "")
|
||||
|
||||
if status:
|
||||
parts.append(f"### Tool: {tool_name} ({status})\n")
|
||||
else:
|
||||
parts.append(f"### Tool: {tool_name}\n")
|
||||
|
||||
parts.append(f"```\n{ctx.content}\n```")
|
||||
parts.append("")
|
||||
|
||||
return "\n".join(parts)
|
||||
12
backend/app/services/context/assembly/__init__.py
Normal file
12
backend/app/services/context/assembly/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Context Assembly Module.
|
||||
|
||||
Provides the assembly pipeline and formatting.
|
||||
"""
|
||||
|
||||
from .pipeline import ContextPipeline, PipelineMetrics
|
||||
|
||||
__all__ = [
|
||||
"ContextPipeline",
|
||||
"PipelineMetrics",
|
||||
]
|
||||
362
backend/app/services/context/assembly/pipeline.py
Normal file
362
backend/app/services/context/assembly/pipeline.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
Context Assembly Pipeline.
|
||||
|
||||
Orchestrates the full context assembly workflow:
|
||||
Gather → Count → Score → Rank → Compress → Format
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..adapters import get_adapter
|
||||
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
from ..compression.truncation import ContextCompressor
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import AssemblyTimeoutError
|
||||
from ..prioritization import ContextRanker
|
||||
from ..scoring import CompositeScorer
|
||||
from ..types import AssembledContext, BaseContext, ContextType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineMetrics:
|
||||
"""Metrics from pipeline execution."""
|
||||
|
||||
start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
end_time: datetime | None = None
|
||||
total_contexts: int = 0
|
||||
selected_contexts: int = 0
|
||||
excluded_contexts: int = 0
|
||||
compressed_contexts: int = 0
|
||||
total_tokens: int = 0
|
||||
assembly_time_ms: float = 0.0
|
||||
scoring_time_ms: float = 0.0
|
||||
compression_time_ms: float = 0.0
|
||||
formatting_time_ms: float = 0.0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"start_time": self.start_time.isoformat(),
|
||||
"end_time": self.end_time.isoformat() if self.end_time else None,
|
||||
"total_contexts": self.total_contexts,
|
||||
"selected_contexts": self.selected_contexts,
|
||||
"excluded_contexts": self.excluded_contexts,
|
||||
"compressed_contexts": self.compressed_contexts,
|
||||
"total_tokens": self.total_tokens,
|
||||
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||
"scoring_time_ms": round(self.scoring_time_ms, 2),
|
||||
"compression_time_ms": round(self.compression_time_ms, 2),
|
||||
"formatting_time_ms": round(self.formatting_time_ms, 2),
|
||||
}
|
||||
|
||||
|
||||
class ContextPipeline:
|
||||
"""
|
||||
Context assembly pipeline.
|
||||
|
||||
Orchestrates the full workflow of context assembly:
|
||||
1. Validate and count tokens for all contexts
|
||||
2. Score contexts based on relevance, recency, and priority
|
||||
3. Rank and select contexts within budget
|
||||
4. Compress if needed to fit remaining budget
|
||||
5. Format for the target model
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
calculator: TokenCalculator | None = None,
|
||||
scorer: CompositeScorer | None = None,
|
||||
ranker: ContextRanker | None = None,
|
||||
compressor: ContextCompressor | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the context pipeline.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP client manager for LLM Gateway integration
|
||||
settings: Context settings
|
||||
calculator: Token calculator
|
||||
scorer: Context scorer
|
||||
ranker: Context ranker
|
||||
compressor: Context compressor
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
self._mcp = mcp_manager
|
||||
|
||||
# Initialize components
|
||||
self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager)
|
||||
self._scorer = scorer or CompositeScorer(
|
||||
mcp_manager=mcp_manager, settings=self._settings
|
||||
)
|
||||
self._ranker = ranker or ContextRanker(
|
||||
scorer=self._scorer, calculator=self._calculator
|
||||
)
|
||||
self._compressor = compressor or ContextCompressor(calculator=self._calculator)
|
||||
self._allocator = BudgetAllocator(self._settings)
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""Set MCP manager for all components."""
|
||||
self._mcp = mcp_manager
|
||||
self._calculator.set_mcp_manager(mcp_manager)
|
||||
self._scorer.set_mcp_manager(mcp_manager)
|
||||
|
||||
async def assemble(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
model: str,
|
||||
max_tokens: int | None = None,
|
||||
custom_budget: TokenBudget | None = None,
|
||||
compress: bool = True,
|
||||
format_output: bool = True,
|
||||
timeout_ms: int | None = None,
|
||||
) -> AssembledContext:
|
||||
"""
|
||||
Assemble context for an LLM request.
|
||||
|
||||
This is the main entry point for context assembly.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to assemble
|
||||
query: Query to optimize for
|
||||
model: Target model name
|
||||
max_tokens: Maximum total tokens (uses model default if None)
|
||||
custom_budget: Optional pre-configured budget
|
||||
compress: Whether to compress oversized contexts
|
||||
format_output: Whether to format the final output
|
||||
timeout_ms: Maximum assembly time in milliseconds
|
||||
|
||||
Returns:
|
||||
AssembledContext with optimized content
|
||||
|
||||
Raises:
|
||||
AssemblyTimeoutError: If assembly exceeds timeout
|
||||
"""
|
||||
timeout = timeout_ms or self._settings.max_assembly_time_ms
|
||||
start = time.perf_counter()
|
||||
metrics = PipelineMetrics(total_contexts=len(contexts))
|
||||
|
||||
try:
|
||||
# Create or use budget
|
||||
if custom_budget:
|
||||
budget = custom_budget
|
||||
elif max_tokens:
|
||||
budget = self._allocator.create_budget(max_tokens)
|
||||
else:
|
||||
budget = self._allocator.create_budget_for_model(model)
|
||||
|
||||
# 1. Count tokens for all contexts (with timeout enforcement)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._ensure_token_counts(contexts, model),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during token counting",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
|
||||
# Check timeout (handles edge case where operation finished just at limit)
|
||||
self._check_timeout(start, timeout, "token counting")
|
||||
|
||||
# 2. Score and rank contexts (with timeout enforcement)
|
||||
scoring_start = time.perf_counter()
|
||||
try:
|
||||
ranking_result = await asyncio.wait_for(
|
||||
self._ranker.rank(
|
||||
contexts=contexts,
|
||||
query=query,
|
||||
budget=budget,
|
||||
model=model,
|
||||
),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during scoring/ranking",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
|
||||
|
||||
selected_contexts = ranking_result.selected_contexts
|
||||
metrics.selected_contexts = len(selected_contexts)
|
||||
metrics.excluded_contexts = len(ranking_result.excluded)
|
||||
|
||||
# Check timeout
|
||||
self._check_timeout(start, timeout, "scoring")
|
||||
|
||||
# 3. Compress if needed and enabled (with timeout enforcement)
|
||||
if compress and self._needs_compression(selected_contexts, budget):
|
||||
compression_start = time.perf_counter()
|
||||
try:
|
||||
selected_contexts = await asyncio.wait_for(
|
||||
self._compressor.compress_contexts(
|
||||
selected_contexts, budget, model
|
||||
),
|
||||
timeout=self._remaining_timeout(start, timeout),
|
||||
)
|
||||
except TimeoutError:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
raise AssemblyTimeoutError(
|
||||
message="Context assembly timed out during compression",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout,
|
||||
)
|
||||
metrics.compression_time_ms = (
|
||||
time.perf_counter() - compression_start
|
||||
) * 1000
|
||||
metrics.compressed_contexts = sum(
|
||||
1 for c in selected_contexts if c.metadata.get("truncated", False)
|
||||
)
|
||||
|
||||
# Check timeout
|
||||
self._check_timeout(start, timeout, "compression")
|
||||
|
||||
# 4. Format output
|
||||
formatting_start = time.perf_counter()
|
||||
if format_output:
|
||||
formatted_content = self._format_contexts(selected_contexts, model)
|
||||
else:
|
||||
formatted_content = "\n\n".join(c.content for c in selected_contexts)
|
||||
metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000
|
||||
|
||||
# Calculate final metrics
|
||||
total_tokens = sum(c.token_count or 0 for c in selected_contexts)
|
||||
metrics.total_tokens = total_tokens
|
||||
metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
|
||||
metrics.end_time = datetime.now(UTC)
|
||||
|
||||
return AssembledContext(
|
||||
content=formatted_content,
|
||||
total_tokens=total_tokens,
|
||||
context_count=len(selected_contexts),
|
||||
assembly_time_ms=metrics.assembly_time_ms,
|
||||
model=model,
|
||||
contexts=selected_contexts,
|
||||
excluded_count=metrics.excluded_contexts,
|
||||
metadata={
|
||||
"metrics": metrics.to_dict(),
|
||||
"query": query,
|
||||
"budget": budget.to_dict(),
|
||||
},
|
||||
)
|
||||
|
||||
except AssemblyTimeoutError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Context assembly failed: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def _ensure_token_counts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
"""Ensure all contexts have token counts."""
|
||||
tasks = []
|
||||
for context in contexts:
|
||||
if context.token_count is None:
|
||||
tasks.append(self._count_and_set(context, model))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _count_and_set(
|
||||
self,
|
||||
context: BaseContext,
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
"""Count tokens and set on context."""
|
||||
count = await self._calculator.count_tokens(context.content, model)
|
||||
context.token_count = count
|
||||
|
||||
def _needs_compression(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
budget: TokenBudget,
|
||||
) -> bool:
|
||||
"""Check if any contexts exceed their type budget."""
|
||||
# Group by type and check totals
|
||||
by_type: dict[ContextType, int] = {}
|
||||
for context in contexts:
|
||||
ct = context.get_type()
|
||||
by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)
|
||||
|
||||
for ct, total in by_type.items():
|
||||
if total > budget.get_allocation(ct):
|
||||
return True
|
||||
|
||||
# Also check if utilization exceeds threshold
|
||||
return budget.utilization() > self._settings.compression_threshold
|
||||
|
||||
def _format_contexts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
model: str,
|
||||
) -> str:
|
||||
"""
|
||||
Format contexts for the target model.
|
||||
|
||||
Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
|
||||
to format contexts optimally for each model family.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to format
|
||||
model: Target model name
|
||||
|
||||
Returns:
|
||||
Formatted context string
|
||||
"""
|
||||
adapter = get_adapter(model)
|
||||
return adapter.format(contexts)
|
||||
|
||||
def _check_timeout(
|
||||
self,
|
||||
start: float,
|
||||
timeout_ms: int,
|
||||
phase: str,
|
||||
) -> None:
|
||||
"""Check if timeout exceeded and raise if so."""
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
if elapsed_ms >= timeout_ms:
|
||||
raise AssemblyTimeoutError(
|
||||
message=f"Context assembly timed out during {phase}",
|
||||
elapsed_ms=elapsed_ms,
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
|
||||
def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
|
||||
"""
|
||||
Calculate remaining timeout in seconds for asyncio.wait_for.
|
||||
|
||||
Returns at least a small positive value to avoid immediate timeout
|
||||
edge cases with wait_for.
|
||||
|
||||
Args:
|
||||
start: Start time from time.perf_counter()
|
||||
timeout_ms: Total timeout in milliseconds
|
||||
|
||||
Returns:
|
||||
Remaining timeout in seconds (minimum 0.001)
|
||||
"""
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
remaining_ms = timeout_ms - elapsed_ms
|
||||
# Return at least 1ms to avoid zero/negative timeout edge cases
|
||||
return max(remaining_ms / 1000.0, 0.001)
|
||||
14
backend/app/services/context/budget/__init__.py
Normal file
14
backend/app/services/context/budget/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Token Budget Management Module.
|
||||
|
||||
Provides token counting and budget allocation.
|
||||
"""
|
||||
|
||||
from .allocator import BudgetAllocator, TokenBudget
|
||||
from .calculator import TokenCalculator
|
||||
|
||||
__all__ = [
|
||||
"BudgetAllocator",
|
||||
"TokenBudget",
|
||||
"TokenCalculator",
|
||||
]
|
||||
444
backend/app/services/context/budget/allocator.py
Normal file
444
backend/app/services/context/budget/allocator.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
Token Budget Allocator for Context Management.
|
||||
|
||||
Manages token budget allocation across context types.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import BudgetExceededError
|
||||
from ..types import ContextType
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenBudget:
|
||||
"""
|
||||
Token budget allocation and tracking.
|
||||
|
||||
Tracks allocated tokens per context type and
|
||||
monitors usage to prevent overflows.
|
||||
"""
|
||||
|
||||
# Total budget
|
||||
total: int
|
||||
|
||||
# Allocated per type
|
||||
system: int = 0
|
||||
task: int = 0
|
||||
knowledge: int = 0
|
||||
conversation: int = 0
|
||||
tools: int = 0
|
||||
memory: int = 0 # Agent memory (working, episodic, semantic, procedural)
|
||||
response_reserve: int = 0
|
||||
buffer: int = 0
|
||||
|
||||
# Usage tracking
|
||||
used: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Initialize usage tracking."""
|
||||
if not self.used:
|
||||
self.used = {ct.value: 0 for ct in ContextType}
|
||||
|
||||
def get_allocation(self, context_type: ContextType | str) -> int:
|
||||
"""
|
||||
Get allocated tokens for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to get allocation for
|
||||
|
||||
Returns:
|
||||
Allocated token count
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
allocation_map = {
|
||||
"system": self.system,
|
||||
"task": self.task,
|
||||
"knowledge": self.knowledge,
|
||||
"conversation": self.conversation,
|
||||
"tool": self.tools,
|
||||
"memory": self.memory,
|
||||
}
|
||||
return allocation_map.get(context_type, 0)
|
||||
|
||||
def get_used(self, context_type: ContextType | str) -> int:
|
||||
"""
|
||||
Get used tokens for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to check
|
||||
|
||||
Returns:
|
||||
Used token count
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
return self.used.get(context_type, 0)
|
||||
|
||||
def remaining(self, context_type: ContextType | str) -> int:
|
||||
"""
|
||||
Get remaining tokens for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to check
|
||||
|
||||
Returns:
|
||||
Remaining token count
|
||||
"""
|
||||
allocated = self.get_allocation(context_type)
|
||||
used = self.get_used(context_type)
|
||||
return max(0, allocated - used)
|
||||
|
||||
def total_remaining(self) -> int:
|
||||
"""
|
||||
Get total remaining tokens across all types.
|
||||
|
||||
Returns:
|
||||
Total remaining tokens
|
||||
"""
|
||||
total_used = sum(self.used.values())
|
||||
usable = self.total - self.response_reserve - self.buffer
|
||||
return max(0, usable - total_used)
|
||||
|
||||
def total_used(self) -> int:
|
||||
"""
|
||||
Get total used tokens.
|
||||
|
||||
Returns:
|
||||
Total used tokens
|
||||
"""
|
||||
return sum(self.used.values())
|
||||
|
||||
def can_fit(self, context_type: ContextType | str, tokens: int) -> bool:
|
||||
"""
|
||||
Check if tokens fit within budget for a type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to check
|
||||
tokens: Number of tokens to fit
|
||||
|
||||
Returns:
|
||||
True if tokens fit within remaining budget
|
||||
"""
|
||||
return tokens <= self.remaining(context_type)
|
||||
|
||||
def allocate(
|
||||
self,
|
||||
context_type: ContextType | str,
|
||||
tokens: int,
|
||||
force: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Allocate (use) tokens from a context type's budget.
|
||||
|
||||
Args:
|
||||
context_type: Context type to allocate from
|
||||
tokens: Number of tokens to allocate
|
||||
force: If True, allow exceeding budget
|
||||
|
||||
Returns:
|
||||
True if allocation succeeded
|
||||
|
||||
Raises:
|
||||
BudgetExceededError: If tokens exceed budget and force=False
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
if not force and not self.can_fit(context_type, tokens):
|
||||
raise BudgetExceededError(
|
||||
message=f"Token budget exceeded for {context_type}",
|
||||
allocated=self.get_allocation(context_type),
|
||||
requested=self.get_used(context_type) + tokens,
|
||||
context_type=context_type,
|
||||
)
|
||||
|
||||
self.used[context_type] = self.used.get(context_type, 0) + tokens
|
||||
return True
|
||||
|
||||
def deallocate(
|
||||
self,
|
||||
context_type: ContextType | str,
|
||||
tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Deallocate (return) tokens to a context type's budget.
|
||||
|
||||
Args:
|
||||
context_type: Context type to return to
|
||||
tokens: Number of tokens to return
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
current = self.used.get(context_type, 0)
|
||||
self.used[context_type] = max(0, current - tokens)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all usage tracking."""
|
||||
self.used = {ct.value: 0 for ct in ContextType}
|
||||
|
||||
def utilization(self, context_type: ContextType | str | None = None) -> float:
|
||||
"""
|
||||
Get budget utilization percentage.
|
||||
|
||||
Args:
|
||||
context_type: Specific type or None for total
|
||||
|
||||
Returns:
|
||||
Utilization as a fraction (0.0 to 1.0+)
|
||||
"""
|
||||
if context_type is None:
|
||||
usable = self.total - self.response_reserve - self.buffer
|
||||
if usable <= 0:
|
||||
return 0.0
|
||||
return self.total_used() / usable
|
||||
|
||||
allocated = self.get_allocation(context_type)
|
||||
if allocated <= 0:
|
||||
return 0.0
|
||||
return self.get_used(context_type) / allocated
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert budget to dictionary."""
|
||||
return {
|
||||
"total": self.total,
|
||||
"allocations": {
|
||||
"system": self.system,
|
||||
"task": self.task,
|
||||
"knowledge": self.knowledge,
|
||||
"conversation": self.conversation,
|
||||
"tools": self.tools,
|
||||
"memory": self.memory,
|
||||
"response_reserve": self.response_reserve,
|
||||
"buffer": self.buffer,
|
||||
},
|
||||
"used": dict(self.used),
|
||||
"remaining": {ct.value: self.remaining(ct) for ct in ContextType},
|
||||
"total_used": self.total_used(),
|
||||
"total_remaining": self.total_remaining(),
|
||||
"utilization": round(self.utilization(), 3),
|
||||
}
|
||||
|
||||
|
||||
class BudgetAllocator:
|
||||
"""
|
||||
Budget allocator for context management.
|
||||
|
||||
Creates token budgets based on configuration and
|
||||
model context window sizes.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: ContextSettings | None = None) -> None:
|
||||
"""
|
||||
Initialize budget allocator.
|
||||
|
||||
Args:
|
||||
settings: Context settings (uses default if None)
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
|
||||
def create_budget(
|
||||
self,
|
||||
total_tokens: int,
|
||||
custom_allocations: dict[str, float] | None = None,
|
||||
) -> TokenBudget:
|
||||
"""
|
||||
Create a token budget with allocations.
|
||||
|
||||
Args:
|
||||
total_tokens: Total available tokens
|
||||
custom_allocations: Optional custom allocation percentages
|
||||
|
||||
Returns:
|
||||
TokenBudget with allocations set
|
||||
"""
|
||||
# Use custom or default allocations
|
||||
if custom_allocations:
|
||||
alloc = custom_allocations
|
||||
else:
|
||||
alloc = self._settings.get_budget_allocation()
|
||||
|
||||
return TokenBudget(
|
||||
total=total_tokens,
|
||||
system=int(total_tokens * alloc.get("system", 0.05)),
|
||||
task=int(total_tokens * alloc.get("task", 0.10)),
|
||||
knowledge=int(total_tokens * alloc.get("knowledge", 0.30)),
|
||||
conversation=int(total_tokens * alloc.get("conversation", 0.15)),
|
||||
tools=int(total_tokens * alloc.get("tools", 0.05)),
|
||||
memory=int(total_tokens * alloc.get("memory", 0.15)),
|
||||
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
|
||||
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
|
||||
)
|
||||
|
||||
def adjust_budget(
|
||||
self,
|
||||
budget: TokenBudget,
|
||||
context_type: ContextType | str,
|
||||
adjustment: int,
|
||||
) -> TokenBudget:
|
||||
"""
|
||||
Adjust a specific allocation in a budget.
|
||||
|
||||
Takes tokens from buffer and adds to specified type.
|
||||
|
||||
Args:
|
||||
budget: Budget to adjust
|
||||
context_type: Type to adjust
|
||||
adjustment: Positive to increase, negative to decrease
|
||||
|
||||
Returns:
|
||||
Adjusted budget
|
||||
"""
|
||||
if isinstance(context_type, ContextType):
|
||||
context_type = context_type.value
|
||||
|
||||
# Calculate adjustment (limited by buffer for increases, by current allocation for decreases)
|
||||
if adjustment > 0:
|
||||
# Taking from buffer - limited by available buffer
|
||||
actual_adjustment = min(adjustment, budget.buffer)
|
||||
budget.buffer -= actual_adjustment
|
||||
else:
|
||||
# Returning to buffer - limited by current allocation of target type
|
||||
current_allocation = budget.get_allocation(context_type)
|
||||
# Can't return more than current allocation
|
||||
actual_adjustment = max(adjustment, -current_allocation)
|
||||
# Add returned tokens back to buffer (adjustment is negative, so subtract)
|
||||
budget.buffer -= actual_adjustment
|
||||
|
||||
# Apply to target type
|
||||
if context_type == "system":
|
||||
budget.system = max(0, budget.system + actual_adjustment)
|
||||
elif context_type == "task":
|
||||
budget.task = max(0, budget.task + actual_adjustment)
|
||||
elif context_type == "knowledge":
|
||||
budget.knowledge = max(0, budget.knowledge + actual_adjustment)
|
||||
elif context_type == "conversation":
|
||||
budget.conversation = max(0, budget.conversation + actual_adjustment)
|
||||
elif context_type == "tool":
|
||||
budget.tools = max(0, budget.tools + actual_adjustment)
|
||||
elif context_type == "memory":
|
||||
budget.memory = max(0, budget.memory + actual_adjustment)
|
||||
|
||||
return budget
|
||||
|
||||
def rebalance_budget(
|
||||
self,
|
||||
budget: TokenBudget,
|
||||
prioritize: list[ContextType] | None = None,
|
||||
) -> TokenBudget:
|
||||
"""
|
||||
Rebalance budget based on actual usage.
|
||||
|
||||
Moves unused allocations to prioritized types.
|
||||
|
||||
Args:
|
||||
budget: Budget to rebalance
|
||||
prioritize: Types to prioritize (in order)
|
||||
|
||||
Returns:
|
||||
Rebalanced budget
|
||||
"""
|
||||
if prioritize is None:
|
||||
prioritize = [
|
||||
ContextType.KNOWLEDGE,
|
||||
ContextType.MEMORY,
|
||||
ContextType.TASK,
|
||||
ContextType.SYSTEM,
|
||||
]
|
||||
|
||||
# Calculate unused tokens per type
|
||||
unused: dict[str, int] = {}
|
||||
for ct in ContextType:
|
||||
remaining = budget.remaining(ct)
|
||||
if remaining > 0:
|
||||
unused[ct.value] = remaining
|
||||
|
||||
# Calculate total reclaimable (excluding prioritized types)
|
||||
prioritize_values = {ct.value for ct in prioritize}
|
||||
reclaimable = sum(
|
||||
tokens for ct, tokens in unused.items() if ct not in prioritize_values
|
||||
)
|
||||
|
||||
# Redistribute to prioritized types that are near capacity
|
||||
for ct in prioritize:
|
||||
utilization = budget.utilization(ct)
|
||||
|
||||
if utilization > 0.8: # Near capacity
|
||||
# Give more tokens from reclaimable pool
|
||||
bonus = min(reclaimable, budget.get_allocation(ct) // 2)
|
||||
self.adjust_budget(budget, ct, bonus)
|
||||
reclaimable -= bonus
|
||||
|
||||
if reclaimable <= 0:
|
||||
break
|
||||
|
||||
return budget
|
||||
|
||||
def get_model_context_size(self, model: str) -> int:
|
||||
"""
|
||||
Get context window size for a model.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Context window size in tokens
|
||||
"""
|
||||
# Common model context sizes
|
||||
context_sizes = {
|
||||
"claude-3-opus": 200000,
|
||||
"claude-3-sonnet": 200000,
|
||||
"claude-3-haiku": 200000,
|
||||
"claude-3-5-sonnet": 200000,
|
||||
"claude-3-5-haiku": 200000,
|
||||
"claude-opus-4": 200000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4o-mini": 128000,
|
||||
"gpt-3.5-turbo": 16385,
|
||||
"gemini-1.5-pro": 2000000,
|
||||
"gemini-1.5-flash": 1000000,
|
||||
"gemini-2.0-flash": 1000000,
|
||||
"qwen-plus": 32000,
|
||||
"qwen-turbo": 8000,
|
||||
"deepseek-chat": 64000,
|
||||
"deepseek-reasoner": 64000,
|
||||
}
|
||||
|
||||
# Check exact match first
|
||||
model_lower = model.lower()
|
||||
if model_lower in context_sizes:
|
||||
return context_sizes[model_lower]
|
||||
|
||||
# Check prefix match
|
||||
for model_name, size in context_sizes.items():
|
||||
if model_lower.startswith(model_name):
|
||||
return size
|
||||
|
||||
# Default fallback
|
||||
return 8192
|
||||
|
||||
def create_budget_for_model(
|
||||
self,
|
||||
model: str,
|
||||
custom_allocations: dict[str, float] | None = None,
|
||||
) -> TokenBudget:
|
||||
"""
|
||||
Create a budget based on model's context window.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
custom_allocations: Optional custom allocation percentages
|
||||
|
||||
Returns:
|
||||
TokenBudget sized for the model
|
||||
"""
|
||||
context_size = self.get_model_context_size(model)
|
||||
return self.create_budget(context_size, custom_allocations)
|
||||
285
backend/app/services/context/budget/calculator.py
Normal file
285
backend/app/services/context/budget/calculator.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Token Calculator for Context Management.
|
||||
|
||||
Provides token counting with caching and fallback estimation.
|
||||
Integrates with LLM Gateway for accurate counts.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenCounterProtocol(Protocol):
|
||||
"""Protocol for token counting implementations."""
|
||||
|
||||
async def count_tokens(
|
||||
self,
|
||||
text: str,
|
||||
model: str | None = None,
|
||||
) -> int:
|
||||
"""Count tokens in text."""
|
||||
...
|
||||
|
||||
|
||||
class TokenCalculator:
|
||||
"""
|
||||
Token calculator with LLM Gateway integration.
|
||||
|
||||
Features:
|
||||
- In-memory caching for repeated text
|
||||
- Fallback to character-based estimation
|
||||
- Model-specific counting when possible
|
||||
|
||||
The calculator uses the LLM Gateway's count_tokens tool
|
||||
for accurate counting, with a local cache to avoid
|
||||
repeated calls for the same content.
|
||||
"""
|
||||
|
||||
# Default characters per token ratio for estimation
|
||||
DEFAULT_CHARS_PER_TOKEN: ClassVar[float] = 4.0
|
||||
|
||||
# Model-specific ratios (more accurate estimation)
|
||||
MODEL_CHAR_RATIOS: ClassVar[dict[str, float]] = {
|
||||
"claude": 3.5,
|
||||
"gpt-4": 4.0,
|
||||
"gpt-3.5": 4.0,
|
||||
"gemini": 4.0,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
project_id: str = "system",
|
||||
agent_id: str = "context-engine",
|
||||
cache_enabled: bool = True,
|
||||
cache_max_size: int = 10000,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize token calculator.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP client manager for LLM Gateway calls
|
||||
project_id: Project ID for LLM Gateway calls
|
||||
agent_id: Agent ID for LLM Gateway calls
|
||||
cache_enabled: Whether to enable in-memory caching
|
||||
cache_max_size: Maximum cache entries
|
||||
"""
|
||||
self._mcp = mcp_manager
|
||||
self._project_id = project_id
|
||||
self._agent_id = agent_id
|
||||
self._cache_enabled = cache_enabled
|
||||
self._cache_max_size = cache_max_size
|
||||
|
||||
# In-memory cache: hash(model:text) -> token_count
|
||||
self._cache: dict[str, int] = {}
|
||||
self._cache_hits = 0
|
||||
self._cache_misses = 0
|
||||
|
||||
def _get_cache_key(self, text: str, model: str | None) -> str:
|
||||
"""Generate cache key from text and model."""
|
||||
# Use hash for efficient storage
|
||||
content = f"{model or 'default'}:{text}"
|
||||
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||
|
||||
def _check_cache(self, cache_key: str) -> int | None:
|
||||
"""Check cache for existing count."""
|
||||
if not self._cache_enabled:
|
||||
return None
|
||||
|
||||
if cache_key in self._cache:
|
||||
self._cache_hits += 1
|
||||
return self._cache[cache_key]
|
||||
|
||||
self._cache_misses += 1
|
||||
return None
|
||||
|
||||
def _store_cache(self, cache_key: str, count: int) -> None:
|
||||
"""Store count in cache."""
|
||||
if not self._cache_enabled:
|
||||
return
|
||||
|
||||
# Simple LRU-like eviction: remove oldest entries when full
|
||||
if len(self._cache) >= self._cache_max_size:
|
||||
# Remove first 10% of entries
|
||||
entries_to_remove = self._cache_max_size // 10
|
||||
keys_to_remove = list(self._cache.keys())[:entries_to_remove]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._cache[cache_key] = count
|
||||
|
||||
def estimate_tokens(self, text: str, model: str | None = None) -> int:
|
||||
"""
|
||||
Estimate token count based on character count.
|
||||
|
||||
This is a fast fallback when LLM Gateway is unavailable.
|
||||
|
||||
Args:
|
||||
text: Text to count
|
||||
model: Optional model for more accurate ratio
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
# Get model-specific ratio
|
||||
ratio = self.DEFAULT_CHARS_PER_TOKEN
|
||||
if model:
|
||||
model_lower = model.lower()
|
||||
for model_prefix, model_ratio in self.MODEL_CHAR_RATIOS.items():
|
||||
if model_prefix in model_lower:
|
||||
ratio = model_ratio
|
||||
break
|
||||
|
||||
return max(1, int(len(text) / ratio))
|
||||
|
||||
async def count_tokens(
|
||||
self,
|
||||
text: str,
|
||||
model: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Count tokens in text.
|
||||
|
||||
Uses LLM Gateway for accurate counts with fallback to estimation.
|
||||
|
||||
Args:
|
||||
text: Text to count
|
||||
model: Optional model for accurate counting
|
||||
|
||||
Returns:
|
||||
Token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
# Check cache first
|
||||
cache_key = self._get_cache_key(text, model)
|
||||
cached = self._check_cache(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
# Try LLM Gateway
|
||||
if self._mcp is not None:
|
||||
try:
|
||||
result = await self._mcp.call_tool(
|
||||
server="llm-gateway",
|
||||
tool="count_tokens",
|
||||
args={
|
||||
"project_id": self._project_id,
|
||||
"agent_id": self._agent_id,
|
||||
"text": text,
|
||||
"model": model,
|
||||
},
|
||||
)
|
||||
|
||||
# Parse result
|
||||
if result.success and result.data:
|
||||
count = self._parse_token_count(result.data)
|
||||
if count is not None:
|
||||
self._store_cache(cache_key, count)
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM Gateway token count failed, using estimation: {e}")
|
||||
|
||||
# Fallback to estimation
|
||||
count = self.estimate_tokens(text, model)
|
||||
self._store_cache(cache_key, count)
|
||||
return count
|
||||
|
||||
def _parse_token_count(self, data: Any) -> int | None:
|
||||
"""Parse token count from LLM Gateway response."""
|
||||
if isinstance(data, dict):
|
||||
if "token_count" in data:
|
||||
return int(data["token_count"])
|
||||
if "tokens" in data:
|
||||
return int(data["tokens"])
|
||||
if "count" in data:
|
||||
return int(data["count"])
|
||||
|
||||
if isinstance(data, int):
|
||||
return data
|
||||
|
||||
if isinstance(data, str):
|
||||
# Try to parse from text content
|
||||
try:
|
||||
# Handle {"token_count": 123} or just "123"
|
||||
import json
|
||||
|
||||
parsed = json.loads(data)
|
||||
if isinstance(parsed, dict) and "token_count" in parsed:
|
||||
return int(parsed["token_count"])
|
||||
if isinstance(parsed, int):
|
||||
return parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Try direct int conversion
|
||||
try:
|
||||
return int(data)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
async def count_tokens_batch(
|
||||
self,
|
||||
texts: list[str],
|
||||
model: str | None = None,
|
||||
) -> list[int]:
|
||||
"""
|
||||
Count tokens for multiple texts.
|
||||
|
||||
Efficient batch counting with caching and parallel execution.
|
||||
|
||||
Args:
|
||||
texts: List of texts to count
|
||||
model: Optional model for accurate counting
|
||||
|
||||
Returns:
|
||||
List of token counts (same order as input)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# Execute all token counts in parallel for better performance
|
||||
tasks = [self.count_tokens(text, model) for text in texts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the token count cache."""
|
||||
self._cache.clear()
|
||||
self._cache_hits = 0
|
||||
self._cache_misses = 0
|
||||
|
||||
def get_cache_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
total = self._cache_hits + self._cache_misses
|
||||
hit_rate = self._cache_hits / total if total > 0 else 0.0
|
||||
|
||||
return {
|
||||
"enabled": self._cache_enabled,
|
||||
"size": len(self._cache),
|
||||
"max_size": self._cache_max_size,
|
||||
"hits": self._cache_hits,
|
||||
"misses": self._cache_misses,
|
||||
"hit_rate": round(hit_rate, 3),
|
||||
}
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""
|
||||
Set the MCP manager (for lazy initialization).
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP client manager instance
|
||||
"""
|
||||
self._mcp = mcp_manager
|
||||
11
backend/app/services/context/cache/__init__.py
vendored
Normal file
11
backend/app/services/context/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Context Cache Module.
|
||||
|
||||
Provides Redis-based caching for assembled contexts.
|
||||
"""
|
||||
|
||||
from .context_cache import ContextCache
|
||||
|
||||
__all__ = [
|
||||
"ContextCache",
|
||||
]
|
||||
434
backend/app/services/context/cache/context_cache.py
vendored
Normal file
434
backend/app/services/context/cache/context_cache.py
vendored
Normal file
@@ -0,0 +1,434 @@
|
||||
"""
|
||||
Context Cache Implementation.
|
||||
|
||||
Provides Redis-based caching for context operations including
|
||||
assembled contexts, token counts, and scoring results.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import CacheError
|
||||
from ..types import AssembledContext, BaseContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextCache:
|
||||
"""
|
||||
Redis-based caching for context operations.
|
||||
|
||||
Provides caching for:
|
||||
- Assembled contexts (fingerprint-based)
|
||||
- Token counts (content hash-based)
|
||||
- Scoring results (context + query hash-based)
|
||||
|
||||
Cache keys use a hierarchical structure:
|
||||
- ctx:assembled:{fingerprint}
|
||||
- ctx:tokens:{model}:{content_hash}
|
||||
- ctx:score:{scorer}:{context_hash}:{query_hash}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the context cache.
|
||||
|
||||
Args:
|
||||
redis: Redis connection (optional for testing)
|
||||
settings: Cache settings
|
||||
"""
|
||||
self._redis = redis
|
||||
self._settings = settings or get_context_settings()
|
||||
self._prefix = self._settings.cache_prefix
|
||||
self._ttl = self._settings.cache_ttl_seconds
|
||||
|
||||
# In-memory fallback cache when Redis unavailable
|
||||
self._memory_cache: dict[str, tuple[str, float]] = {}
|
||||
self._max_memory_items = self._settings.cache_memory_max_items
|
||||
|
||||
def set_redis(self, redis: "Redis") -> None:
|
||||
"""Set Redis connection."""
|
||||
self._redis = redis
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if caching is enabled and available."""
|
||||
return self._settings.cache_enabled and self._redis is not None
|
||||
|
||||
def _cache_key(self, *parts: str) -> str:
|
||||
"""
|
||||
Build a cache key from parts.
|
||||
|
||||
Args:
|
||||
*parts: Key components
|
||||
|
||||
Returns:
|
||||
Colon-separated cache key
|
||||
"""
|
||||
return f"{self._prefix}:{':'.join(parts)}"
|
||||
|
||||
@staticmethod
|
||||
def _hash_content(content: str) -> str:
|
||||
"""
|
||||
Compute hash of content for cache key.
|
||||
|
||||
Args:
|
||||
content: Content to hash
|
||||
|
||||
Returns:
|
||||
32-character hex hash
|
||||
"""
|
||||
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||
|
||||
def compute_fingerprint(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
model: str,
|
||||
project_id: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Compute a fingerprint for a context assembly request.
|
||||
|
||||
The fingerprint is based on:
|
||||
- Project and agent IDs (for tenant isolation)
|
||||
- Context content hash and metadata (not full content for performance)
|
||||
- Query string
|
||||
- Target model
|
||||
|
||||
SECURITY: project_id and agent_id MUST be included to prevent
|
||||
cross-tenant cache pollution. Without these, one tenant could
|
||||
receive cached contexts from another tenant with the same query.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts
|
||||
query: Query string
|
||||
model: Model name
|
||||
project_id: Project ID for tenant isolation
|
||||
agent_id: Agent ID for tenant isolation
|
||||
|
||||
Returns:
|
||||
32-character hex fingerprint
|
||||
"""
|
||||
# Build a deterministic representation using content hashes for performance
|
||||
# This avoids JSON serializing potentially large content strings
|
||||
context_data = []
|
||||
for ctx in contexts:
|
||||
context_data.append(
|
||||
{
|
||||
"type": ctx.get_type().value,
|
||||
"content_hash": self._hash_content(
|
||||
ctx.content
|
||||
), # Hash instead of full content
|
||||
"source": ctx.source,
|
||||
"priority": ctx.priority, # Already an int
|
||||
}
|
||||
)
|
||||
|
||||
data = {
|
||||
# CRITICAL: Include tenant identifiers for cache isolation
|
||||
"project_id": project_id or "",
|
||||
"agent_id": agent_id or "",
|
||||
"contexts": context_data,
|
||||
"query": query,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
content = json.dumps(data, sort_keys=True)
|
||||
return self._hash_content(content)
|
||||
|
||||
async def get_assembled(
|
||||
self,
|
||||
fingerprint: str,
|
||||
) -> AssembledContext | None:
|
||||
"""
|
||||
Get cached assembled context by fingerprint.
|
||||
|
||||
Args:
|
||||
fingerprint: Assembly fingerprint
|
||||
|
||||
Returns:
|
||||
Cached AssembledContext or None if not found
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return None
|
||||
|
||||
key = self._cache_key("assembled", fingerprint)
|
||||
|
||||
try:
|
||||
data = await self._redis.get(key) # type: ignore
|
||||
if data:
|
||||
logger.debug(f"Cache hit for assembled context: {fingerprint}")
|
||||
result = AssembledContext.from_json(data)
|
||||
result.cache_hit = True
|
||||
result.cache_key = fingerprint
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache get error: {e}")
|
||||
raise CacheError(f"Failed to get assembled context: {e}") from e
|
||||
|
||||
return None
|
||||
|
||||
async def set_assembled(
|
||||
self,
|
||||
fingerprint: str,
|
||||
context: AssembledContext,
|
||||
ttl: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache an assembled context.
|
||||
|
||||
Args:
|
||||
fingerprint: Assembly fingerprint
|
||||
context: Assembled context to cache
|
||||
ttl: Optional TTL override in seconds
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return
|
||||
|
||||
key = self._cache_key("assembled", fingerprint)
|
||||
expire = ttl or self._ttl
|
||||
|
||||
try:
|
||||
await self._redis.setex(key, expire, context.to_json()) # type: ignore
|
||||
logger.debug(f"Cached assembled context: {fingerprint}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache set error: {e}")
|
||||
raise CacheError(f"Failed to cache assembled context: {e}") from e
|
||||
|
||||
async def get_token_count(
|
||||
self,
|
||||
content: str,
|
||||
model: str | None = None,
|
||||
) -> int | None:
|
||||
"""
|
||||
Get cached token count.
|
||||
|
||||
Args:
|
||||
content: Content to look up
|
||||
model: Model name for model-specific tokenization
|
||||
|
||||
Returns:
|
||||
Cached token count or None if not found
|
||||
"""
|
||||
model_key = model or "default"
|
||||
content_hash = self._hash_content(content)
|
||||
key = self._cache_key("tokens", model_key, content_hash)
|
||||
|
||||
# Try in-memory first
|
||||
if key in self._memory_cache:
|
||||
return int(self._memory_cache[key][0])
|
||||
|
||||
if not self.is_enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await self._redis.get(key) # type: ignore
|
||||
if data:
|
||||
count = int(data)
|
||||
# Store in memory for faster subsequent access
|
||||
self._set_memory(key, str(count))
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache get error for tokens: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def set_token_count(
|
||||
self,
|
||||
content: str,
|
||||
count: int,
|
||||
model: str | None = None,
|
||||
ttl: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache a token count.
|
||||
|
||||
Args:
|
||||
content: Content that was counted
|
||||
count: Token count
|
||||
model: Model name
|
||||
ttl: Optional TTL override in seconds
|
||||
"""
|
||||
model_key = model or "default"
|
||||
content_hash = self._hash_content(content)
|
||||
key = self._cache_key("tokens", model_key, content_hash)
|
||||
expire = ttl or self._ttl
|
||||
|
||||
# Always store in memory
|
||||
self._set_memory(key, str(count))
|
||||
|
||||
if not self.is_enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._redis.setex(key, expire, str(count)) # type: ignore
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache set error for tokens: {e}")
|
||||
|
||||
async def get_score(
|
||||
self,
|
||||
scorer_name: str,
|
||||
context_id: str,
|
||||
query: str,
|
||||
) -> float | None:
|
||||
"""
|
||||
Get cached score.
|
||||
|
||||
Args:
|
||||
scorer_name: Name of the scorer
|
||||
context_id: Context identifier
|
||||
query: Query string
|
||||
|
||||
Returns:
|
||||
Cached score or None if not found
|
||||
"""
|
||||
query_hash = self._hash_content(query)[:16]
|
||||
key = self._cache_key("score", scorer_name, context_id, query_hash)
|
||||
|
||||
# Try in-memory first
|
||||
if key in self._memory_cache:
|
||||
return float(self._memory_cache[key][0])
|
||||
|
||||
if not self.is_enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await self._redis.get(key) # type: ignore
|
||||
if data:
|
||||
score = float(data)
|
||||
self._set_memory(key, str(score))
|
||||
return score
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache get error for score: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def set_score(
|
||||
self,
|
||||
scorer_name: str,
|
||||
context_id: str,
|
||||
query: str,
|
||||
score: float,
|
||||
ttl: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache a score.
|
||||
|
||||
Args:
|
||||
scorer_name: Name of the scorer
|
||||
context_id: Context identifier
|
||||
query: Query string
|
||||
score: Score value
|
||||
ttl: Optional TTL override in seconds
|
||||
"""
|
||||
query_hash = self._hash_content(query)[:16]
|
||||
key = self._cache_key("score", scorer_name, context_id, query_hash)
|
||||
expire = ttl or self._ttl
|
||||
|
||||
# Always store in memory
|
||||
self._set_memory(key, str(score))
|
||||
|
||||
if not self.is_enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._redis.setex(key, expire, str(score)) # type: ignore
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache set error for score: {e}")
|
||||
|
||||
async def invalidate(self, pattern: str) -> int:
|
||||
"""
|
||||
Invalidate cache entries matching a pattern.
|
||||
|
||||
Args:
|
||||
pattern: Key pattern (supports * wildcard)
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return 0
|
||||
|
||||
full_pattern = self._cache_key(pattern)
|
||||
deleted = 0
|
||||
|
||||
try:
|
||||
async for key in self._redis.scan_iter(match=full_pattern): # type: ignore
|
||||
await self._redis.delete(key) # type: ignore
|
||||
deleted += 1
|
||||
|
||||
logger.info(f"Invalidated {deleted} cache entries matching {pattern}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache invalidation error: {e}")
|
||||
raise CacheError(f"Failed to invalidate cache: {e}") from e
|
||||
|
||||
return deleted
|
||||
|
||||
async def clear_all(self) -> int:
|
||||
"""
|
||||
Clear all context cache entries.
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
"""
|
||||
self._memory_cache.clear()
|
||||
return await self.invalidate("*")
|
||||
|
||||
def _set_memory(self, key: str, value: str) -> None:
|
||||
"""
|
||||
Set a value in the memory cache.
|
||||
|
||||
Uses LRU-style eviction when max items reached.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to store
|
||||
"""
|
||||
import time
|
||||
|
||||
if len(self._memory_cache) >= self._max_memory_items:
|
||||
# Evict oldest entries
|
||||
sorted_keys = sorted(
|
||||
self._memory_cache.keys(),
|
||||
key=lambda k: self._memory_cache[k][1],
|
||||
)
|
||||
for k in sorted_keys[: len(sorted_keys) // 2]:
|
||||
del self._memory_cache[k]
|
||||
|
||||
self._memory_cache[key] = (value, time.time())
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with cache stats
|
||||
"""
|
||||
stats = {
|
||||
"enabled": self._settings.cache_enabled,
|
||||
"redis_available": self._redis is not None,
|
||||
"memory_items": len(self._memory_cache),
|
||||
"ttl_seconds": self._ttl,
|
||||
}
|
||||
|
||||
if self.is_enabled:
|
||||
try:
|
||||
# Get Redis info
|
||||
info = await self._redis.info("memory") # type: ignore
|
||||
stats["redis_memory_used"] = info.get("used_memory_human", "unknown")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get Redis stats: {e}")
|
||||
|
||||
return stats
|
||||
13
backend/app/services/context/compression/__init__.py
Normal file
13
backend/app/services/context/compression/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Context Compression Module.
|
||||
|
||||
Provides truncation and compression strategies.
|
||||
"""
|
||||
|
||||
from .truncation import ContextCompressor, TruncationResult, TruncationStrategy
|
||||
|
||||
__all__ = [
|
||||
"ContextCompressor",
|
||||
"TruncationResult",
|
||||
"TruncationStrategy",
|
||||
]
|
||||
453
backend/app/services/context/compression/truncation.py
Normal file
453
backend/app/services/context/compression/truncation.py
Normal file
@@ -0,0 +1,453 @@
|
||||
"""
|
||||
Smart Truncation for Context Compression.
|
||||
|
||||
Provides intelligent truncation strategies to reduce context size
|
||||
while preserving the most important information.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext, ContextType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..budget import TokenBudget, TokenCalculator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _estimate_tokens(text: str, model: str | None = None) -> int:
|
||||
"""
|
||||
Estimate token count using model-specific character ratios.
|
||||
|
||||
Module-level function for reuse across classes. Uses the same ratios
|
||||
as TokenCalculator for consistency.
|
||||
|
||||
Args:
|
||||
text: Text to estimate tokens for
|
||||
model: Optional model name for model-specific ratios
|
||||
|
||||
Returns:
|
||||
Estimated token count (minimum 1)
|
||||
"""
|
||||
# Model-specific character ratios (chars per token)
|
||||
model_ratios = {
|
||||
"claude": 3.5,
|
||||
"gpt-4": 4.0,
|
||||
"gpt-3.5": 4.0,
|
||||
"gemini": 4.0,
|
||||
}
|
||||
default_ratio = 4.0
|
||||
|
||||
ratio = default_ratio
|
||||
if model:
|
||||
model_lower = model.lower()
|
||||
for model_prefix, model_ratio in model_ratios.items():
|
||||
if model_prefix in model_lower:
|
||||
ratio = model_ratio
|
||||
break
|
||||
|
||||
return max(1, int(len(text) / ratio))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TruncationResult:
|
||||
"""Result of truncation operation."""
|
||||
|
||||
original_tokens: int
|
||||
truncated_tokens: int
|
||||
content: str
|
||||
truncated: bool
|
||||
truncation_ratio: float # 0.0 = no truncation, 1.0 = completely removed
|
||||
|
||||
@property
|
||||
def tokens_saved(self) -> int:
|
||||
"""Calculate tokens saved by truncation."""
|
||||
return self.original_tokens - self.truncated_tokens
|
||||
|
||||
|
||||
class TruncationStrategy:
|
||||
"""
|
||||
Smart truncation strategies for context compression.
|
||||
|
||||
Strategies:
|
||||
1. End truncation: Cut from end (for knowledge/docs)
|
||||
2. Middle truncation: Keep start and end (for code)
|
||||
3. Sentence-aware: Truncate at sentence boundaries
|
||||
4. Semantic chunking: Keep most relevant chunks
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
calculator: "TokenCalculator | None" = None,
|
||||
preserve_ratio_start: float | None = None,
|
||||
min_content_length: int | None = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize truncation strategy.
|
||||
|
||||
Args:
|
||||
calculator: Token calculator for accurate counting
|
||||
preserve_ratio_start: Ratio of content to keep from start (overrides settings)
|
||||
min_content_length: Minimum characters to preserve (overrides settings)
|
||||
settings: Context settings (uses global if None)
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
self._calculator = calculator
|
||||
|
||||
# Use provided values or fall back to settings
|
||||
self._preserve_ratio_start = (
|
||||
preserve_ratio_start
|
||||
if preserve_ratio_start is not None
|
||||
else self._settings.truncation_preserve_ratio
|
||||
)
|
||||
self._min_content_length = (
|
||||
min_content_length
|
||||
if min_content_length is not None
|
||||
else self._settings.truncation_min_content_length
|
||||
)
|
||||
|
||||
@property
|
||||
def truncation_marker(self) -> str:
|
||||
"""Get truncation marker from settings."""
|
||||
return self._settings.truncation_marker
|
||||
|
||||
def set_calculator(self, calculator: "TokenCalculator") -> None:
|
||||
"""Set token calculator."""
|
||||
self._calculator = calculator
|
||||
|
||||
async def truncate_to_tokens(
|
||||
self,
|
||||
content: str,
|
||||
max_tokens: int,
|
||||
strategy: str = "end",
|
||||
model: str | None = None,
|
||||
) -> TruncationResult:
|
||||
"""
|
||||
Truncate content to fit within token limit.
|
||||
|
||||
Args:
|
||||
content: Content to truncate
|
||||
max_tokens: Maximum tokens allowed
|
||||
strategy: Truncation strategy ('end', 'middle', 'sentence')
|
||||
model: Model for token counting
|
||||
|
||||
Returns:
|
||||
TruncationResult with truncated content
|
||||
"""
|
||||
if not content:
|
||||
return TruncationResult(
|
||||
original_tokens=0,
|
||||
truncated_tokens=0,
|
||||
content="",
|
||||
truncated=False,
|
||||
truncation_ratio=0.0,
|
||||
)
|
||||
|
||||
# Get original token count
|
||||
original_tokens = await self._count_tokens(content, model)
|
||||
|
||||
if original_tokens <= max_tokens:
|
||||
return TruncationResult(
|
||||
original_tokens=original_tokens,
|
||||
truncated_tokens=original_tokens,
|
||||
content=content,
|
||||
truncated=False,
|
||||
truncation_ratio=0.0,
|
||||
)
|
||||
|
||||
# Apply truncation strategy
|
||||
if strategy == "middle":
|
||||
truncated = await self._truncate_middle(content, max_tokens, model)
|
||||
elif strategy == "sentence":
|
||||
truncated = await self._truncate_sentence(content, max_tokens, model)
|
||||
else: # "end"
|
||||
truncated = await self._truncate_end(content, max_tokens, model)
|
||||
|
||||
truncated_tokens = await self._count_tokens(truncated, model)
|
||||
|
||||
return TruncationResult(
|
||||
original_tokens=original_tokens,
|
||||
truncated_tokens=truncated_tokens,
|
||||
content=truncated,
|
||||
truncated=True,
|
||||
truncation_ratio=0.0
|
||||
if original_tokens == 0
|
||||
else 1 - (truncated_tokens / original_tokens),
|
||||
)
|
||||
|
||||
async def _truncate_end(
|
||||
self,
|
||||
content: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Truncate from end of content.
|
||||
|
||||
Simple but effective for most content types.
|
||||
"""
|
||||
# Binary search for optimal truncation point
|
||||
marker_tokens = await self._count_tokens(self.truncation_marker, model)
|
||||
available_tokens = max(0, max_tokens - marker_tokens)
|
||||
|
||||
# Edge case: if no tokens available for content, return just the marker
|
||||
if available_tokens <= 0:
|
||||
return self.truncation_marker
|
||||
|
||||
# Estimate characters per token (guard against division by zero)
|
||||
content_tokens = await self._count_tokens(content, model)
|
||||
if content_tokens == 0:
|
||||
return content + self.truncation_marker
|
||||
chars_per_token = len(content) / content_tokens
|
||||
|
||||
# Start with estimated position
|
||||
estimated_chars = int(available_tokens * chars_per_token)
|
||||
truncated = content[:estimated_chars]
|
||||
|
||||
# Refine with binary search
|
||||
low, high = len(truncated) // 2, len(truncated)
|
||||
best = truncated
|
||||
|
||||
for _ in range(5): # Max 5 iterations
|
||||
mid = (low + high) // 2
|
||||
candidate = content[:mid]
|
||||
tokens = await self._count_tokens(candidate, model)
|
||||
|
||||
if tokens <= available_tokens:
|
||||
best = candidate
|
||||
low = mid + 1
|
||||
else:
|
||||
high = mid - 1
|
||||
|
||||
return best + self.truncation_marker
|
||||
|
||||
async def _truncate_middle(
|
||||
self,
|
||||
content: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Truncate from middle, keeping start and end.
|
||||
|
||||
Good for code or content where context at boundaries matters.
|
||||
"""
|
||||
marker_tokens = await self._count_tokens(self.truncation_marker, model)
|
||||
available_tokens = max_tokens - marker_tokens
|
||||
|
||||
# Split between start and end
|
||||
start_tokens = int(available_tokens * self._preserve_ratio_start)
|
||||
end_tokens = available_tokens - start_tokens
|
||||
|
||||
# Get start portion
|
||||
start_content = await self._get_content_for_tokens(
|
||||
content, start_tokens, from_start=True, model=model
|
||||
)
|
||||
|
||||
# Get end portion
|
||||
end_content = await self._get_content_for_tokens(
|
||||
content, end_tokens, from_start=False, model=model
|
||||
)
|
||||
|
||||
return start_content + self.truncation_marker + end_content
|
||||
|
||||
async def _truncate_sentence(
|
||||
self,
|
||||
content: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Truncate at sentence boundaries.
|
||||
|
||||
Produces cleaner output by not cutting mid-sentence.
|
||||
"""
|
||||
# Split into sentences
|
||||
sentences = re.split(r"(?<=[.!?])\s+", content)
|
||||
|
||||
result: list[str] = []
|
||||
total_tokens = 0
|
||||
marker_tokens = await self._count_tokens(self.truncation_marker, model)
|
||||
available = max_tokens - marker_tokens
|
||||
|
||||
for sentence in sentences:
|
||||
sentence_tokens = await self._count_tokens(sentence, model)
|
||||
if total_tokens + sentence_tokens <= available:
|
||||
result.append(sentence)
|
||||
total_tokens += sentence_tokens
|
||||
else:
|
||||
break
|
||||
|
||||
if len(result) < len(sentences):
|
||||
return " ".join(result) + self.truncation_marker
|
||||
return " ".join(result)
|
||||
|
||||
async def _get_content_for_tokens(
|
||||
self,
|
||||
content: str,
|
||||
target_tokens: int,
|
||||
from_start: bool = True,
|
||||
model: str | None = None,
|
||||
) -> str:
|
||||
"""Get portion of content fitting within token limit."""
|
||||
if target_tokens <= 0:
|
||||
return ""
|
||||
|
||||
current_tokens = await self._count_tokens(content, model)
|
||||
if current_tokens <= target_tokens:
|
||||
return content
|
||||
|
||||
# Estimate characters (guard against division by zero)
|
||||
if current_tokens == 0:
|
||||
return content
|
||||
chars_per_token = len(content) / current_tokens
|
||||
estimated_chars = int(target_tokens * chars_per_token)
|
||||
|
||||
if from_start:
|
||||
return content[:estimated_chars]
|
||||
else:
|
||||
return content[-estimated_chars:]
|
||||
|
||||
async def _count_tokens(self, text: str, model: str | None = None) -> int:
|
||||
"""Count tokens using calculator or estimation."""
|
||||
if self._calculator is not None:
|
||||
return await self._calculator.count_tokens(text, model)
|
||||
|
||||
# Fallback estimation with model-specific ratios
|
||||
return _estimate_tokens(text, model)
|
||||
|
||||
|
||||
class ContextCompressor:
|
||||
"""
|
||||
Compresses contexts to fit within budget constraints.
|
||||
|
||||
Uses truncation strategies to reduce context size while
|
||||
preserving the most important information.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
truncation: TruncationStrategy | None = None,
|
||||
calculator: "TokenCalculator | None" = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize context compressor.
|
||||
|
||||
Args:
|
||||
truncation: Truncation strategy to use
|
||||
calculator: Token calculator for counting
|
||||
"""
|
||||
self._truncation = truncation or TruncationStrategy(calculator)
|
||||
self._calculator = calculator
|
||||
|
||||
if calculator:
|
||||
self._truncation.set_calculator(calculator)
|
||||
|
||||
def set_calculator(self, calculator: "TokenCalculator") -> None:
|
||||
"""Set token calculator."""
|
||||
self._calculator = calculator
|
||||
self._truncation.set_calculator(calculator)
|
||||
|
||||
async def compress_context(
|
||||
self,
|
||||
context: BaseContext,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
) -> BaseContext:
|
||||
"""
|
||||
Compress a single context to fit token limit.
|
||||
|
||||
Args:
|
||||
context: Context to compress
|
||||
max_tokens: Maximum tokens allowed
|
||||
model: Model for token counting
|
||||
|
||||
Returns:
|
||||
Compressed context (may be same object if no compression needed)
|
||||
"""
|
||||
current_tokens = context.token_count or await self._count_tokens(
|
||||
context.content, model
|
||||
)
|
||||
|
||||
if current_tokens <= max_tokens:
|
||||
return context
|
||||
|
||||
# Choose strategy based on context type
|
||||
strategy = self._get_strategy_for_type(context.get_type())
|
||||
|
||||
result = await self._truncation.truncate_to_tokens(
|
||||
content=context.content,
|
||||
max_tokens=max_tokens,
|
||||
strategy=strategy,
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Update context with truncated content
|
||||
context.content = result.content
|
||||
context.token_count = result.truncated_tokens
|
||||
context.metadata["truncated"] = True
|
||||
context.metadata["original_tokens"] = result.original_tokens
|
||||
|
||||
return context
|
||||
|
||||
async def compress_contexts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
budget: "TokenBudget",
|
||||
model: str | None = None,
|
||||
) -> list[BaseContext]:
|
||||
"""
|
||||
Compress multiple contexts to fit within budget.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to potentially compress
|
||||
budget: Token budget constraints
|
||||
model: Model for token counting
|
||||
|
||||
Returns:
|
||||
List of contexts (compressed as needed)
|
||||
"""
|
||||
result: list[BaseContext] = []
|
||||
|
||||
for context in contexts:
|
||||
context_type = context.get_type()
|
||||
remaining = budget.remaining(context_type)
|
||||
current_tokens = context.token_count or await self._count_tokens(
|
||||
context.content, model
|
||||
)
|
||||
|
||||
if current_tokens > remaining:
|
||||
# Need to compress
|
||||
compressed = await self.compress_context(context, remaining, model)
|
||||
result.append(compressed)
|
||||
logger.debug(
|
||||
f"Compressed {context_type.value} context from "
|
||||
f"{current_tokens} to {compressed.token_count} tokens"
|
||||
)
|
||||
else:
|
||||
result.append(context)
|
||||
|
||||
return result
|
||||
|
||||
def _get_strategy_for_type(self, context_type: ContextType) -> str:
|
||||
"""Get optimal truncation strategy for context type."""
|
||||
strategies = {
|
||||
ContextType.SYSTEM: "end", # Keep instructions at start
|
||||
ContextType.TASK: "end", # Keep task description start
|
||||
ContextType.KNOWLEDGE: "sentence", # Clean sentence boundaries
|
||||
ContextType.CONVERSATION: "end", # Keep recent conversation
|
||||
ContextType.TOOL: "middle", # Keep command and result summary
|
||||
}
|
||||
return strategies.get(context_type, "end")
|
||||
|
||||
async def _count_tokens(self, text: str, model: str | None = None) -> int:
|
||||
"""Count tokens using calculator or estimation."""
|
||||
if self._calculator is not None:
|
||||
return await self._calculator.count_tokens(text, model)
|
||||
# Use model-specific estimation for consistency
|
||||
return _estimate_tokens(text, model)
|
||||
380
backend/app/services/context/config.py
Normal file
380
backend/app/services/context/config.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
Context Management Engine Configuration.
|
||||
|
||||
Provides Pydantic settings for context assembly,
|
||||
token budget allocation, and caching.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ContextSettings(BaseSettings):
|
||||
"""
|
||||
Configuration for the Context Management Engine.
|
||||
|
||||
All settings can be overridden via environment variables
|
||||
with the CTX_ prefix.
|
||||
"""
|
||||
|
||||
# Budget allocation percentages (must sum to 1.0)
|
||||
budget_system: float = Field(
|
||||
default=0.05,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage of budget for system prompts (5%)",
|
||||
)
|
||||
budget_task: float = Field(
|
||||
default=0.10,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage of budget for task context (10%)",
|
||||
)
|
||||
budget_knowledge: float = Field(
|
||||
default=0.40,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage of budget for RAG/knowledge (40%)",
|
||||
)
|
||||
budget_conversation: float = Field(
|
||||
default=0.20,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage of budget for conversation history (20%)",
|
||||
)
|
||||
budget_tools: float = Field(
|
||||
default=0.05,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage of budget for tool descriptions (5%)",
|
||||
)
|
||||
budget_response: float = Field(
|
||||
default=0.15,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage reserved for response (15%)",
|
||||
)
|
||||
budget_buffer: float = Field(
|
||||
default=0.05,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Percentage buffer for safety margin (5%)",
|
||||
)
|
||||
|
||||
# Scoring weights
|
||||
scoring_relevance_weight: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Weight for relevance scoring",
|
||||
)
|
||||
scoring_recency_weight: float = Field(
|
||||
default=0.3,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Weight for recency scoring",
|
||||
)
|
||||
scoring_priority_weight: float = Field(
|
||||
default=0.2,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Weight for priority scoring",
|
||||
)
|
||||
|
||||
# Recency decay settings
|
||||
recency_decay_hours: float = Field(
|
||||
default=24.0,
|
||||
gt=0.0,
|
||||
description="Hours until recency score decays to 50%",
|
||||
)
|
||||
recency_max_age_hours: float = Field(
|
||||
default=168.0,
|
||||
gt=0.0,
|
||||
description="Hours until context is considered stale (7 days)",
|
||||
)
|
||||
|
||||
# Compression settings
|
||||
compression_threshold: float = Field(
|
||||
default=0.8,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Compress when budget usage exceeds this percentage",
|
||||
)
|
||||
truncation_marker: str = Field(
|
||||
default="\n\n[...content truncated...]\n\n",
|
||||
description="Marker text to insert where content was truncated",
|
||||
)
|
||||
truncation_preserve_ratio: float = Field(
|
||||
default=0.7,
|
||||
ge=0.1,
|
||||
le=0.9,
|
||||
description="Ratio of content to preserve from start in middle truncation (0.7 = 70% start, 30% end)",
|
||||
)
|
||||
truncation_min_content_length: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=1000,
|
||||
description="Minimum content length in characters before truncation applies",
|
||||
)
|
||||
summary_model_group: str = Field(
|
||||
default="fast",
|
||||
description="Model group to use for summarization",
|
||||
)
|
||||
|
||||
# Caching settings
|
||||
cache_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable Redis caching for assembled contexts",
|
||||
)
|
||||
cache_ttl_seconds: int = Field(
|
||||
default=3600,
|
||||
ge=60,
|
||||
le=86400,
|
||||
description="Cache TTL in seconds (1 hour default, max 24 hours)",
|
||||
)
|
||||
cache_prefix: str = Field(
|
||||
default="ctx",
|
||||
description="Redis key prefix for context cache",
|
||||
)
|
||||
cache_memory_max_items: int = Field(
|
||||
default=1000,
|
||||
ge=100,
|
||||
le=100000,
|
||||
description="Maximum items in memory fallback cache when Redis unavailable",
|
||||
)
|
||||
|
||||
# Performance settings
|
||||
max_assembly_time_ms: int = Field(
|
||||
default=2000,
|
||||
ge=10,
|
||||
le=30000,
|
||||
description="Maximum time for context assembly in milliseconds. "
|
||||
"Should be high enough to accommodate MCP calls for knowledge retrieval.",
|
||||
)
|
||||
parallel_scoring: bool = Field(
|
||||
default=True,
|
||||
description="Score contexts in parallel for better performance",
|
||||
)
|
||||
max_parallel_scores: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=50,
|
||||
description="Maximum number of contexts to score in parallel",
|
||||
)
|
||||
|
||||
# Knowledge retrieval settings
|
||||
knowledge_search_type: str = Field(
|
||||
default="hybrid",
|
||||
description="Default search type for knowledge retrieval",
|
||||
)
|
||||
knowledge_max_results: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=50,
|
||||
description="Maximum knowledge chunks to retrieve",
|
||||
)
|
||||
knowledge_min_score: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum relevance score for knowledge",
|
||||
)
|
||||
|
||||
# Relevance scoring settings
|
||||
relevance_keyword_fallback_weight: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Maximum score for keyword-based fallback scoring (when semantic unavailable)",
|
||||
)
|
||||
relevance_semantic_max_chars: int = Field(
|
||||
default=2000,
|
||||
ge=100,
|
||||
le=10000,
|
||||
description="Maximum content length in chars for semantic similarity computation",
|
||||
)
|
||||
|
||||
# Diversity/ranking settings
|
||||
diversity_max_per_source: int = Field(
|
||||
default=3,
|
||||
ge=1,
|
||||
le=20,
|
||||
description="Maximum contexts from the same source in diversity reranking",
|
||||
)
|
||||
|
||||
# Conversation history settings
|
||||
conversation_max_turns: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Maximum conversation turns to include",
|
||||
)
|
||||
conversation_recent_priority: bool = Field(
|
||||
default=True,
|
||||
description="Prioritize recent conversation turns",
|
||||
)
|
||||
|
||||
@field_validator("knowledge_search_type")
|
||||
@classmethod
|
||||
def validate_search_type(cls, v: str) -> str:
|
||||
"""Validate search type is valid."""
|
||||
valid_types = {"semantic", "keyword", "hybrid"}
|
||||
if v not in valid_types:
|
||||
raise ValueError(f"search_type must be one of: {valid_types}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_budget_allocation(self) -> "ContextSettings":
|
||||
"""Validate that budget percentages sum to 1.0."""
|
||||
total = (
|
||||
self.budget_system
|
||||
+ self.budget_task
|
||||
+ self.budget_knowledge
|
||||
+ self.budget_conversation
|
||||
+ self.budget_tools
|
||||
+ self.budget_response
|
||||
+ self.budget_buffer
|
||||
)
|
||||
# Allow small floating point error
|
||||
if abs(total - 1.0) > 0.001:
|
||||
raise ValueError(
|
||||
f"Budget percentages must sum to 1.0, got {total:.3f}. "
|
||||
f"Current allocation: system={self.budget_system}, task={self.budget_task}, "
|
||||
f"knowledge={self.budget_knowledge}, conversation={self.budget_conversation}, "
|
||||
f"tools={self.budget_tools}, response={self.budget_response}, buffer={self.budget_buffer}"
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_scoring_weights(self) -> "ContextSettings":
|
||||
"""Validate that scoring weights sum to 1.0."""
|
||||
total = (
|
||||
self.scoring_relevance_weight
|
||||
+ self.scoring_recency_weight
|
||||
+ self.scoring_priority_weight
|
||||
)
|
||||
# Allow small floating point error
|
||||
if abs(total - 1.0) > 0.001:
|
||||
raise ValueError(
|
||||
f"Scoring weights must sum to 1.0, got {total:.3f}. "
|
||||
f"Current weights: relevance={self.scoring_relevance_weight}, "
|
||||
f"recency={self.scoring_recency_weight}, priority={self.scoring_priority_weight}"
|
||||
)
|
||||
return self
|
||||
|
||||
def get_budget_allocation(self) -> dict[str, float]:
|
||||
"""Get budget allocation as a dictionary."""
|
||||
return {
|
||||
"system": self.budget_system,
|
||||
"task": self.budget_task,
|
||||
"knowledge": self.budget_knowledge,
|
||||
"conversation": self.budget_conversation,
|
||||
"tools": self.budget_tools,
|
||||
"response": self.budget_response,
|
||||
"buffer": self.budget_buffer,
|
||||
}
|
||||
|
||||
def get_scoring_weights(self) -> dict[str, float]:
|
||||
"""Get scoring weights as a dictionary."""
|
||||
return {
|
||||
"relevance": self.scoring_relevance_weight,
|
||||
"recency": self.scoring_recency_weight,
|
||||
"priority": self.scoring_priority_weight,
|
||||
}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert settings to dictionary for logging/debugging."""
|
||||
return {
|
||||
"budget": self.get_budget_allocation(),
|
||||
"scoring": self.get_scoring_weights(),
|
||||
"compression": {
|
||||
"threshold": self.compression_threshold,
|
||||
"summary_model_group": self.summary_model_group,
|
||||
"truncation_marker": self.truncation_marker,
|
||||
"truncation_preserve_ratio": self.truncation_preserve_ratio,
|
||||
"truncation_min_content_length": self.truncation_min_content_length,
|
||||
},
|
||||
"cache": {
|
||||
"enabled": self.cache_enabled,
|
||||
"ttl_seconds": self.cache_ttl_seconds,
|
||||
"prefix": self.cache_prefix,
|
||||
"memory_max_items": self.cache_memory_max_items,
|
||||
},
|
||||
"performance": {
|
||||
"max_assembly_time_ms": self.max_assembly_time_ms,
|
||||
"parallel_scoring": self.parallel_scoring,
|
||||
"max_parallel_scores": self.max_parallel_scores,
|
||||
},
|
||||
"knowledge": {
|
||||
"search_type": self.knowledge_search_type,
|
||||
"max_results": self.knowledge_max_results,
|
||||
"min_score": self.knowledge_min_score,
|
||||
},
|
||||
"relevance": {
|
||||
"keyword_fallback_weight": self.relevance_keyword_fallback_weight,
|
||||
"semantic_max_chars": self.relevance_semantic_max_chars,
|
||||
},
|
||||
"diversity": {
|
||||
"max_per_source": self.diversity_max_per_source,
|
||||
},
|
||||
"conversation": {
|
||||
"max_turns": self.conversation_max_turns,
|
||||
"recent_priority": self.conversation_recent_priority,
|
||||
},
|
||||
}
|
||||
|
||||
model_config = {
|
||||
"env_prefix": "CTX_",
|
||||
"env_file": "../.env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": False,
|
||||
"extra": "ignore",
|
||||
}
|
||||
|
||||
|
||||
# Thread-safe singleton pattern
|
||||
_settings: ContextSettings | None = None
|
||||
_settings_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_context_settings() -> ContextSettings:
|
||||
"""
|
||||
Get the global ContextSettings instance.
|
||||
|
||||
Thread-safe with double-checked locking pattern.
|
||||
|
||||
Returns:
|
||||
ContextSettings instance
|
||||
"""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
with _settings_lock:
|
||||
if _settings is None:
|
||||
_settings = ContextSettings()
|
||||
return _settings
|
||||
|
||||
|
||||
def reset_context_settings() -> None:
|
||||
"""
|
||||
Reset the global settings instance.
|
||||
|
||||
Primarily used for testing.
|
||||
"""
|
||||
global _settings
|
||||
with _settings_lock:
|
||||
_settings = None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_default_settings() -> ContextSettings:
|
||||
"""
|
||||
Get default settings (cached).
|
||||
|
||||
Use this for read-only access to defaults.
|
||||
For mutable access, use get_context_settings().
|
||||
"""
|
||||
return ContextSettings()
|
||||
582
backend/app/services/context/engine.py
Normal file
582
backend/app/services/context/engine.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""
|
||||
Context Management Engine.
|
||||
|
||||
Main orchestration layer for context assembly and optimization.
|
||||
Provides a high-level API for assembling optimized context for LLM requests.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from .assembly import ContextPipeline
|
||||
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
from .cache import ContextCache
|
||||
from .compression import ContextCompressor
|
||||
from .config import ContextSettings, get_context_settings
|
||||
from .prioritization import ContextRanker
|
||||
from .scoring import CompositeScorer
|
||||
from .types import (
|
||||
AssembledContext,
|
||||
BaseContext,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MemoryContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
ToolContext,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
from app.services.memory.integration import MemoryContextSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextEngine:
|
||||
"""
|
||||
Main context management engine.
|
||||
|
||||
Provides high-level API for context assembly and optimization.
|
||||
Integrates all components: scoring, ranking, compression, formatting, and caching.
|
||||
|
||||
Usage:
|
||||
engine = ContextEngine(mcp_manager=mcp, redis=redis)
|
||||
|
||||
# Assemble context for an LLM request
|
||||
result = await engine.assemble_context(
|
||||
project_id="proj-123",
|
||||
agent_id="agent-456",
|
||||
query="implement user authentication",
|
||||
model="claude-3-sonnet",
|
||||
system_prompt="You are an expert developer.",
|
||||
knowledge_query="authentication best practices",
|
||||
)
|
||||
|
||||
# Use the assembled context
|
||||
print(result.content)
|
||||
print(f"Tokens: {result.total_tokens}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
memory_source: "MemoryContextSource | None" = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the context engine.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
|
||||
redis: Redis connection for caching
|
||||
settings: Context settings
|
||||
memory_source: Optional memory context source for agent memory
|
||||
"""
|
||||
self._mcp = mcp_manager
|
||||
self._settings = settings or get_context_settings()
|
||||
self._memory_source = memory_source
|
||||
|
||||
# Initialize components
|
||||
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
|
||||
self._scorer = CompositeScorer(mcp_manager=mcp_manager, settings=self._settings)
|
||||
self._ranker = ContextRanker(scorer=self._scorer, calculator=self._calculator)
|
||||
self._compressor = ContextCompressor(calculator=self._calculator)
|
||||
self._allocator = BudgetAllocator(self._settings)
|
||||
self._cache = ContextCache(redis=redis, settings=self._settings)
|
||||
|
||||
# Pipeline for assembly
|
||||
self._pipeline = ContextPipeline(
|
||||
mcp_manager=mcp_manager,
|
||||
settings=self._settings,
|
||||
calculator=self._calculator,
|
||||
scorer=self._scorer,
|
||||
ranker=self._ranker,
|
||||
compressor=self._compressor,
|
||||
)
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""
|
||||
Set MCP manager for all components.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP client manager
|
||||
"""
|
||||
self._mcp = mcp_manager
|
||||
self._calculator.set_mcp_manager(mcp_manager)
|
||||
self._scorer.set_mcp_manager(mcp_manager)
|
||||
self._pipeline.set_mcp_manager(mcp_manager)
|
||||
|
||||
def set_redis(self, redis: "Redis") -> None:
|
||||
"""
|
||||
Set Redis connection for caching.
|
||||
|
||||
Args:
|
||||
redis: Redis connection
|
||||
"""
|
||||
self._cache.set_redis(redis)
|
||||
|
||||
def set_memory_source(self, memory_source: "MemoryContextSource") -> None:
|
||||
"""
|
||||
Set memory context source for agent memory integration.
|
||||
|
||||
Args:
|
||||
memory_source: Memory context source
|
||||
"""
|
||||
self._memory_source = memory_source
|
||||
|
||||
async def assemble_context(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
model: str,
|
||||
max_tokens: int | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_description: str | None = None,
|
||||
knowledge_query: str | None = None,
|
||||
knowledge_limit: int = 10,
|
||||
memory_query: str | None = None,
|
||||
memory_limit: int = 20,
|
||||
session_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
conversation_history: list[dict[str, str]] | None = None,
|
||||
tool_results: list[dict[str, Any]] | None = None,
|
||||
custom_contexts: list[BaseContext] | None = None,
|
||||
custom_budget: TokenBudget | None = None,
|
||||
compress: bool = True,
|
||||
format_output: bool = True,
|
||||
use_cache: bool = True,
|
||||
) -> AssembledContext:
|
||||
"""
|
||||
Assemble optimized context for an LLM request.
|
||||
|
||||
This is the main entry point for context management.
|
||||
It gathers context from various sources, scores and ranks them,
|
||||
compresses if needed, and formats for the target model.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
agent_id: Agent identifier
|
||||
query: User's query or current request
|
||||
model: Target model name
|
||||
max_tokens: Maximum context tokens (uses model default if None)
|
||||
system_prompt: System prompt/instructions
|
||||
task_description: Current task description
|
||||
knowledge_query: Query for knowledge base search
|
||||
knowledge_limit: Max number of knowledge results
|
||||
memory_query: Query for agent memory search
|
||||
memory_limit: Max number of memory results
|
||||
session_id: Session ID for working memory access
|
||||
agent_type_id: Agent type ID for procedural memory
|
||||
conversation_history: List of {"role": str, "content": str}
|
||||
tool_results: List of tool results to include
|
||||
custom_contexts: Additional custom contexts
|
||||
custom_budget: Custom token budget
|
||||
compress: Whether to apply compression
|
||||
format_output: Whether to format for the model
|
||||
use_cache: Whether to use caching
|
||||
|
||||
Returns:
|
||||
AssembledContext with optimized content
|
||||
|
||||
Raises:
|
||||
AssemblyTimeoutError: If assembly exceeds timeout
|
||||
BudgetExceededError: If context exceeds budget
|
||||
"""
|
||||
# Gather all contexts
|
||||
contexts: list[BaseContext] = []
|
||||
|
||||
# 1. System context
|
||||
if system_prompt:
|
||||
contexts.append(
|
||||
SystemContext(
|
||||
content=system_prompt,
|
||||
source="system_prompt",
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Task context
|
||||
if task_description:
|
||||
contexts.append(
|
||||
TaskContext(
|
||||
content=task_description,
|
||||
source=f"task:{project_id}:{agent_id}",
|
||||
)
|
||||
)
|
||||
|
||||
# 3. Knowledge context from Knowledge Base
|
||||
if knowledge_query and self._mcp:
|
||||
knowledge_contexts = await self._fetch_knowledge(
|
||||
project_id=project_id,
|
||||
agent_id=agent_id,
|
||||
query=knowledge_query,
|
||||
limit=knowledge_limit,
|
||||
)
|
||||
contexts.extend(knowledge_contexts)
|
||||
|
||||
# 4. Memory context from Agent Memory System
|
||||
if memory_query and self._memory_source:
|
||||
memory_contexts = await self._fetch_memory(
|
||||
project_id=project_id,
|
||||
agent_id=agent_id,
|
||||
query=memory_query,
|
||||
limit=memory_limit,
|
||||
session_id=session_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
contexts.extend(memory_contexts)
|
||||
|
||||
# 5. Conversation history
|
||||
if conversation_history:
|
||||
contexts.extend(self._convert_conversation(conversation_history))
|
||||
|
||||
# 6. Tool results
|
||||
if tool_results:
|
||||
contexts.extend(self._convert_tool_results(tool_results))
|
||||
|
||||
# 7. Custom contexts
|
||||
if custom_contexts:
|
||||
contexts.extend(custom_contexts)
|
||||
|
||||
# Check cache if enabled
|
||||
fingerprint: str | None = None
|
||||
if use_cache and self._cache.is_enabled:
|
||||
# Include project_id and agent_id for tenant isolation
|
||||
fingerprint = self._cache.compute_fingerprint(
|
||||
contexts, query, model, project_id=project_id, agent_id=agent_id
|
||||
)
|
||||
cached = await self._cache.get_assembled(fingerprint)
|
||||
if cached:
|
||||
logger.debug(f"Cache hit for context assembly: {fingerprint}")
|
||||
return cached
|
||||
|
||||
# Run assembly pipeline
|
||||
result = await self._pipeline.assemble(
|
||||
contexts=contexts,
|
||||
query=query,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
custom_budget=custom_budget,
|
||||
compress=compress,
|
||||
format_output=format_output,
|
||||
)
|
||||
|
||||
# Cache result if enabled (reuse fingerprint computed above)
|
||||
if use_cache and self._cache.is_enabled and fingerprint is not None:
|
||||
await self._cache.set_assembled(fingerprint, result)
|
||||
|
||||
return result
|
||||
|
||||
async def _fetch_knowledge(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
) -> list[KnowledgeContext]:
|
||||
"""
|
||||
Fetch relevant knowledge from Knowledge Base via MCP.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
agent_id: Agent identifier
|
||||
query: Search query
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of KnowledgeContext instances
|
||||
"""
|
||||
if not self._mcp:
|
||||
return []
|
||||
|
||||
try:
|
||||
result = await self._mcp.call_tool(
|
||||
"knowledge-base",
|
||||
"search_knowledge",
|
||||
{
|
||||
"project_id": project_id,
|
||||
"agent_id": agent_id,
|
||||
"query": query,
|
||||
"search_type": "hybrid",
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
|
||||
# Check both ToolResult.success AND response success
|
||||
if not result.success:
|
||||
logger.warning(f"Knowledge search failed: {result.error}")
|
||||
return []
|
||||
|
||||
if not isinstance(result.data, dict) or not result.data.get(
|
||||
"success", True
|
||||
):
|
||||
logger.warning("Knowledge search returned unsuccessful response")
|
||||
return []
|
||||
|
||||
contexts = []
|
||||
results = result.data.get("results", [])
|
||||
for chunk in results:
|
||||
contexts.append(
|
||||
KnowledgeContext(
|
||||
content=chunk.get("content", ""),
|
||||
source=chunk.get("source_path", "unknown"),
|
||||
relevance_score=chunk.get("score", 0.0),
|
||||
metadata={
|
||||
"chunk_id": chunk.get(
|
||||
"id"
|
||||
), # Server returns 'id' not 'chunk_id'
|
||||
"document_id": chunk.get("document_id"),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Fetched {len(contexts)} knowledge chunks for query: {query}")
|
||||
return contexts
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch knowledge: {e}")
|
||||
return []
|
||||
|
||||
async def _fetch_memory(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
limit: int = 20,
|
||||
session_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
) -> list[MemoryContext]:
|
||||
"""
|
||||
Fetch relevant memories from Agent Memory System.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
agent_id: Agent identifier
|
||||
query: Search query
|
||||
limit: Maximum results
|
||||
session_id: Session ID for working memory
|
||||
agent_type_id: Agent type ID for procedural memory
|
||||
|
||||
Returns:
|
||||
List of MemoryContext instances
|
||||
"""
|
||||
if not self._memory_source:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
|
||||
# Configure fetch limits
|
||||
from app.services.memory.integration.context_source import MemoryFetchConfig
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
working_limit=min(limit // 4, 5),
|
||||
episodic_limit=min(limit // 2, 10),
|
||||
semantic_limit=min(limit // 2, 10),
|
||||
procedural_limit=min(limit // 4, 5),
|
||||
include_working=session_id is not None,
|
||||
)
|
||||
|
||||
result = await self._memory_source.fetch_context(
|
||||
query=query,
|
||||
project_id=UUID(project_id),
|
||||
agent_instance_id=UUID(agent_id) if agent_id else None,
|
||||
agent_type_id=UUID(agent_type_id) if agent_type_id else None,
|
||||
session_id=session_id,
|
||||
config=config,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Fetched {len(result.contexts)} memory contexts for query: {query}, "
|
||||
f"by_type: {result.by_type}"
|
||||
)
|
||||
return result.contexts[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch memory: {e}")
|
||||
return []
|
||||
|
||||
def _convert_conversation(
|
||||
self,
|
||||
history: list[dict[str, str]],
|
||||
) -> list[ConversationContext]:
|
||||
"""
|
||||
Convert conversation history to ConversationContext instances.
|
||||
|
||||
Args:
|
||||
history: List of {"role": str, "content": str}
|
||||
|
||||
Returns:
|
||||
List of ConversationContext instances
|
||||
"""
|
||||
contexts = []
|
||||
for i, turn in enumerate(history):
|
||||
role_str = turn.get("role", "user").lower()
|
||||
role = (
|
||||
MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
|
||||
)
|
||||
|
||||
contexts.append(
|
||||
ConversationContext(
|
||||
content=turn.get("content", ""),
|
||||
source=f"conversation:{i}",
|
||||
role=role,
|
||||
metadata={"role": role_str, "turn": i},
|
||||
)
|
||||
)
|
||||
|
||||
return contexts
|
||||
|
||||
def _convert_tool_results(
|
||||
self,
|
||||
results: list[dict[str, Any]],
|
||||
) -> list[ToolContext]:
|
||||
"""
|
||||
Convert tool results to ToolContext instances.
|
||||
|
||||
Args:
|
||||
results: List of tool result dictionaries
|
||||
|
||||
Returns:
|
||||
List of ToolContext instances
|
||||
"""
|
||||
contexts = []
|
||||
for result in results:
|
||||
tool_name = result.get("tool_name", "unknown")
|
||||
content = result.get("content", result.get("result", ""))
|
||||
|
||||
# Handle dict content
|
||||
if isinstance(content, dict):
|
||||
import json
|
||||
|
||||
content = json.dumps(content, indent=2)
|
||||
|
||||
contexts.append(
|
||||
ToolContext(
|
||||
content=str(content),
|
||||
source=f"tool:{tool_name}",
|
||||
metadata={
|
||||
"tool_name": tool_name,
|
||||
"status": result.get("status", "success"),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return contexts
|
||||
|
||||
async def get_budget_for_model(
|
||||
self,
|
||||
model: str,
|
||||
max_tokens: int | None = None,
|
||||
) -> TokenBudget:
|
||||
"""
|
||||
Get the token budget for a specific model.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
max_tokens: Optional max tokens override
|
||||
|
||||
Returns:
|
||||
TokenBudget instance
|
||||
"""
|
||||
if max_tokens:
|
||||
return self._allocator.create_budget(max_tokens)
|
||||
return self._allocator.create_budget_for_model(model)
|
||||
|
||||
async def count_tokens(
|
||||
self,
|
||||
content: str,
|
||||
model: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Count tokens in content.
|
||||
|
||||
Args:
|
||||
content: Content to count
|
||||
model: Model for model-specific tokenization
|
||||
|
||||
Returns:
|
||||
Token count
|
||||
"""
|
||||
# Check cache first
|
||||
cached = await self._cache.get_token_count(content, model)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
count = await self._calculator.count_tokens(content, model)
|
||||
|
||||
# Cache the result
|
||||
await self._cache.set_token_count(content, count, model)
|
||||
|
||||
return count
|
||||
|
||||
async def invalidate_cache(
|
||||
self,
|
||||
project_id: str | None = None,
|
||||
pattern: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Invalidate cache entries.
|
||||
|
||||
Args:
|
||||
project_id: Invalidate all cache for a project
|
||||
pattern: Custom pattern to match
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if pattern:
|
||||
return await self._cache.invalidate(pattern)
|
||||
elif project_id:
|
||||
return await self._cache.invalidate(f"*{project_id}*")
|
||||
else:
|
||||
return await self._cache.clear_all()
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get engine statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with engine stats
|
||||
"""
|
||||
return {
|
||||
"cache": await self._cache.get_stats(),
|
||||
"settings": {
|
||||
"compression_threshold": self._settings.compression_threshold,
|
||||
"max_assembly_time_ms": self._settings.max_assembly_time_ms,
|
||||
"cache_enabled": self._settings.cache_enabled,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Convenience factory function
|
||||
def create_context_engine(
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
memory_source: "MemoryContextSource | None" = None,
|
||||
) -> ContextEngine:
|
||||
"""
|
||||
Create a context engine instance.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP client manager
|
||||
redis: Redis connection
|
||||
settings: Context settings
|
||||
memory_source: Optional memory context source
|
||||
|
||||
Returns:
|
||||
Configured ContextEngine instance
|
||||
"""
|
||||
return ContextEngine(
|
||||
mcp_manager=mcp_manager,
|
||||
redis=redis,
|
||||
settings=settings,
|
||||
memory_source=memory_source,
|
||||
)
|
||||
354
backend/app/services/context/exceptions.py
Normal file
354
backend/app/services/context/exceptions.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
Context Management Engine Exceptions.
|
||||
|
||||
Provides a hierarchy of exceptions for context assembly,
|
||||
token budget management, and related operations.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ContextError(Exception):
|
||||
"""
|
||||
Base exception for all context management errors.
|
||||
|
||||
All context-related exceptions should inherit from this class
|
||||
to allow for catch-all handling when needed.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, details: dict[str, Any] | None = None) -> None:
|
||||
"""
|
||||
Initialize context error.
|
||||
|
||||
Args:
|
||||
message: Human-readable error message
|
||||
details: Optional dict with additional error context
|
||||
"""
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
super().__init__(message)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert exception to dictionary for logging/serialization."""
|
||||
return {
|
||||
"error_type": self.__class__.__name__,
|
||||
"message": self.message,
|
||||
"details": self.details,
|
||||
}
|
||||
|
||||
|
||||
class BudgetExceededError(ContextError):
|
||||
"""
|
||||
Raised when token budget is exceeded.
|
||||
|
||||
This occurs when the assembled context would exceed the
|
||||
allocated token budget for a specific context type or total.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Token budget exceeded",
|
||||
allocated: int = 0,
|
||||
requested: int = 0,
|
||||
context_type: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize budget exceeded error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
allocated: Tokens allocated for this context type
|
||||
requested: Tokens requested
|
||||
context_type: Type of context that exceeded budget
|
||||
"""
|
||||
details: dict[str, Any] = {
|
||||
"allocated": allocated,
|
||||
"requested": requested,
|
||||
"overage": requested - allocated,
|
||||
}
|
||||
if context_type:
|
||||
details["context_type"] = context_type
|
||||
|
||||
super().__init__(message, details)
|
||||
self.allocated = allocated
|
||||
self.requested = requested
|
||||
self.context_type = context_type
|
||||
|
||||
|
||||
class TokenCountError(ContextError):
|
||||
"""
|
||||
Raised when token counting fails.
|
||||
|
||||
This typically occurs when the LLM Gateway token counting
|
||||
service is unavailable or returns an error.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Failed to count tokens",
|
||||
model: str | None = None,
|
||||
text_length: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize token count error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
model: Model for which counting was attempted
|
||||
text_length: Length of text that failed to count
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if model:
|
||||
details["model"] = model
|
||||
if text_length is not None:
|
||||
details["text_length"] = text_length
|
||||
|
||||
super().__init__(message, details)
|
||||
self.model = model
|
||||
self.text_length = text_length
|
||||
|
||||
|
||||
class CompressionError(ContextError):
|
||||
"""
|
||||
Raised when context compression fails.
|
||||
|
||||
This can occur when summarization or truncation cannot
|
||||
reduce content to fit within the budget.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Failed to compress context",
|
||||
original_tokens: int | None = None,
|
||||
target_tokens: int | None = None,
|
||||
achieved_tokens: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize compression error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
original_tokens: Tokens before compression
|
||||
target_tokens: Target token count
|
||||
achieved_tokens: Tokens achieved after compression attempt
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if original_tokens is not None:
|
||||
details["original_tokens"] = original_tokens
|
||||
if target_tokens is not None:
|
||||
details["target_tokens"] = target_tokens
|
||||
if achieved_tokens is not None:
|
||||
details["achieved_tokens"] = achieved_tokens
|
||||
|
||||
super().__init__(message, details)
|
||||
self.original_tokens = original_tokens
|
||||
self.target_tokens = target_tokens
|
||||
self.achieved_tokens = achieved_tokens
|
||||
|
||||
|
||||
class AssemblyTimeoutError(ContextError):
|
||||
"""
|
||||
Raised when context assembly exceeds time limit.
|
||||
|
||||
Context assembly must complete within a configurable
|
||||
time limit to maintain responsiveness.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Context assembly timed out",
|
||||
timeout_ms: int = 0,
|
||||
elapsed_ms: float = 0.0,
|
||||
stage: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize assembly timeout error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
timeout_ms: Configured timeout in milliseconds
|
||||
elapsed_ms: Actual elapsed time in milliseconds
|
||||
stage: Pipeline stage where timeout occurred
|
||||
"""
|
||||
details: dict[str, Any] = {
|
||||
"timeout_ms": timeout_ms,
|
||||
"elapsed_ms": round(elapsed_ms, 2),
|
||||
}
|
||||
if stage:
|
||||
details["stage"] = stage
|
||||
|
||||
super().__init__(message, details)
|
||||
self.timeout_ms = timeout_ms
|
||||
self.elapsed_ms = elapsed_ms
|
||||
self.stage = stage
|
||||
|
||||
|
||||
class ScoringError(ContextError):
|
||||
"""
|
||||
Raised when context scoring fails.
|
||||
|
||||
This occurs when relevance, recency, or priority scoring
|
||||
encounters an error.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Failed to score context",
|
||||
scorer_type: str | None = None,
|
||||
context_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize scoring error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
scorer_type: Type of scorer that failed
|
||||
context_id: ID of context being scored
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if scorer_type:
|
||||
details["scorer_type"] = scorer_type
|
||||
if context_id:
|
||||
details["context_id"] = context_id
|
||||
|
||||
super().__init__(message, details)
|
||||
self.scorer_type = scorer_type
|
||||
self.context_id = context_id
|
||||
|
||||
|
||||
class FormattingError(ContextError):
|
||||
"""
|
||||
Raised when context formatting fails.
|
||||
|
||||
This occurs when converting assembled context to
|
||||
model-specific format fails.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Failed to format context",
|
||||
model: str | None = None,
|
||||
adapter: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize formatting error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
model: Target model
|
||||
adapter: Adapter that failed
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if model:
|
||||
details["model"] = model
|
||||
if adapter:
|
||||
details["adapter"] = adapter
|
||||
|
||||
super().__init__(message, details)
|
||||
self.model = model
|
||||
self.adapter = adapter
|
||||
|
||||
|
||||
class CacheError(ContextError):
|
||||
"""
|
||||
Raised when cache operations fail.
|
||||
|
||||
This is typically non-fatal and should be handled
|
||||
gracefully by falling back to recomputation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Cache operation failed",
|
||||
operation: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize cache error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
operation: Cache operation that failed (get, set, delete)
|
||||
cache_key: Key involved in the failed operation
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if operation:
|
||||
details["operation"] = operation
|
||||
if cache_key:
|
||||
details["cache_key"] = cache_key
|
||||
|
||||
super().__init__(message, details)
|
||||
self.operation = operation
|
||||
self.cache_key = cache_key
|
||||
|
||||
|
||||
class ContextNotFoundError(ContextError):
|
||||
"""
|
||||
Raised when expected context is not found.
|
||||
|
||||
This occurs when required context sources return
|
||||
no results or are unavailable.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Required context not found",
|
||||
source: str | None = None,
|
||||
query: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize context not found error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
source: Source that returned no results
|
||||
query: Query used to search
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if source:
|
||||
details["source"] = source
|
||||
if query:
|
||||
details["query"] = query
|
||||
|
||||
super().__init__(message, details)
|
||||
self.source = source
|
||||
self.query = query
|
||||
|
||||
|
||||
class InvalidContextError(ContextError):
|
||||
"""
|
||||
Raised when context data is invalid.
|
||||
|
||||
This occurs when context content or metadata
|
||||
fails validation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Invalid context data",
|
||||
field: str | None = None,
|
||||
value: Any | None = None,
|
||||
reason: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize invalid context error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
field: Field that is invalid
|
||||
value: Invalid value (may be redacted for security)
|
||||
reason: Reason for invalidity
|
||||
"""
|
||||
details: dict[str, Any] = {}
|
||||
if field:
|
||||
details["field"] = field
|
||||
if value is not None:
|
||||
# Avoid logging potentially sensitive values
|
||||
details["value_type"] = type(value).__name__
|
||||
if reason:
|
||||
details["reason"] = reason
|
||||
|
||||
super().__init__(message, details)
|
||||
self.field = field
|
||||
self.value = value
|
||||
self.reason = reason
|
||||
12
backend/app/services/context/prioritization/__init__.py
Normal file
12
backend/app/services/context/prioritization/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Context Prioritization Module.
|
||||
|
||||
Provides context ranking and selection.
|
||||
"""
|
||||
|
||||
from .ranker import ContextRanker, RankingResult
|
||||
|
||||
__all__ = [
|
||||
"ContextRanker",
|
||||
"RankingResult",
|
||||
]
|
||||
374
backend/app/services/context/prioritization/ranker.py
Normal file
374
backend/app/services/context/prioritization/ranker.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""
|
||||
Context Ranker for Context Management.
|
||||
|
||||
Ranks and selects contexts based on scores and budget constraints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..budget import TokenBudget, TokenCalculator
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..exceptions import BudgetExceededError
|
||||
from ..scoring.composite import CompositeScorer, ScoredContext
|
||||
from ..types import BaseContext, ContextPriority
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankingResult:
|
||||
"""Result of context ranking and selection."""
|
||||
|
||||
selected: list[ScoredContext]
|
||||
excluded: list[ScoredContext]
|
||||
total_tokens: int
|
||||
selection_stats: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def selected_contexts(self) -> list[BaseContext]:
|
||||
"""Get just the context objects (not scored wrappers)."""
|
||||
return [s.context for s in self.selected]
|
||||
|
||||
|
||||
class ContextRanker:
|
||||
"""
|
||||
Ranks and selects contexts within budget constraints.
|
||||
|
||||
Uses greedy selection to maximize total score
|
||||
while respecting token budgets per context type.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scorer: CompositeScorer | None = None,
|
||||
calculator: TokenCalculator | None = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize context ranker.
|
||||
|
||||
Args:
|
||||
scorer: Composite scorer for scoring contexts
|
||||
calculator: Token calculator for counting tokens
|
||||
settings: Context settings (uses global if None)
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
self._scorer = scorer or CompositeScorer()
|
||||
self._calculator = calculator or TokenCalculator()
|
||||
|
||||
def set_scorer(self, scorer: CompositeScorer) -> None:
|
||||
"""Set the scorer."""
|
||||
self._scorer = scorer
|
||||
|
||||
def set_calculator(self, calculator: TokenCalculator) -> None:
|
||||
"""Set the token calculator."""
|
||||
self._calculator = calculator
|
||||
|
||||
async def rank(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
budget: TokenBudget,
|
||||
model: str | None = None,
|
||||
ensure_required: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> RankingResult:
|
||||
"""
|
||||
Rank and select contexts within budget.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to rank
|
||||
query: Query to rank against
|
||||
budget: Token budget constraints
|
||||
model: Model for token counting
|
||||
ensure_required: If True, always include CRITICAL priority contexts
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
RankingResult with selected and excluded contexts
|
||||
"""
|
||||
if not contexts:
|
||||
return RankingResult(
|
||||
selected=[],
|
||||
excluded=[],
|
||||
total_tokens=0,
|
||||
selection_stats={"total_contexts": 0},
|
||||
)
|
||||
|
||||
# 1. Ensure all contexts have token counts
|
||||
await self._ensure_token_counts(contexts, model)
|
||||
|
||||
# 2. Score all contexts
|
||||
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
|
||||
|
||||
# 3. Separate required (CRITICAL priority) from optional
|
||||
required: list[ScoredContext] = []
|
||||
optional: list[ScoredContext] = []
|
||||
|
||||
if ensure_required:
|
||||
for sc in scored_contexts:
|
||||
# CRITICAL priority (150) contexts are always included
|
||||
if sc.context.priority >= ContextPriority.CRITICAL.value:
|
||||
required.append(sc)
|
||||
else:
|
||||
optional.append(sc)
|
||||
else:
|
||||
optional = list(scored_contexts)
|
||||
|
||||
# 4. Sort optional by score (highest first)
|
||||
optional.sort(reverse=True)
|
||||
|
||||
# 5. Greedy selection
|
||||
selected: list[ScoredContext] = []
|
||||
excluded: list[ScoredContext] = []
|
||||
total_tokens = 0
|
||||
|
||||
# Calculate the usable budget (total minus reserved portions)
|
||||
usable_budget = budget.total - budget.response_reserve - budget.buffer
|
||||
|
||||
# Guard against invalid budget configuration
|
||||
if usable_budget <= 0:
|
||||
raise BudgetExceededError(
|
||||
message=(
|
||||
f"Invalid budget configuration: no usable tokens available. "
|
||||
f"total={budget.total}, response_reserve={budget.response_reserve}, "
|
||||
f"buffer={budget.buffer}"
|
||||
),
|
||||
allocated=budget.total,
|
||||
requested=0,
|
||||
context_type="CONFIGURATION_ERROR",
|
||||
)
|
||||
|
||||
# First, try to fit required contexts
|
||||
for sc in required:
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
context_type = sc.context.get_type()
|
||||
|
||||
if budget.can_fit(context_type, token_count):
|
||||
budget.allocate(context_type, token_count)
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
else:
|
||||
# Force-fit CRITICAL contexts if needed, but check total budget first
|
||||
if total_tokens + token_count > usable_budget:
|
||||
# Even CRITICAL contexts cannot exceed total model context window
|
||||
raise BudgetExceededError(
|
||||
message=(
|
||||
f"CRITICAL contexts exceed total budget. "
|
||||
f"Context '{sc.context.source}' ({token_count} tokens) "
|
||||
f"would exceed usable budget of {usable_budget} tokens."
|
||||
),
|
||||
allocated=usable_budget,
|
||||
requested=total_tokens + token_count,
|
||||
context_type="CRITICAL_OVERFLOW",
|
||||
)
|
||||
|
||||
budget.allocate(context_type, token_count, force=True)
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
logger.warning(
|
||||
f"Force-fitted CRITICAL context: {sc.context.source} "
|
||||
f"({token_count} tokens)"
|
||||
)
|
||||
|
||||
# Then, greedily add optional contexts
|
||||
for sc in optional:
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
context_type = sc.context.get_type()
|
||||
|
||||
if budget.can_fit(context_type, token_count):
|
||||
budget.allocate(context_type, token_count)
|
||||
selected.append(sc)
|
||||
total_tokens += token_count
|
||||
else:
|
||||
excluded.append(sc)
|
||||
|
||||
# Build stats
|
||||
stats = {
|
||||
"total_contexts": len(contexts),
|
||||
"required_count": len(required),
|
||||
"selected_count": len(selected),
|
||||
"excluded_count": len(excluded),
|
||||
"total_tokens": total_tokens,
|
||||
"by_type": self._count_by_type(selected),
|
||||
}
|
||||
|
||||
return RankingResult(
|
||||
selected=selected,
|
||||
excluded=excluded,
|
||||
total_tokens=total_tokens,
|
||||
selection_stats=stats,
|
||||
)
|
||||
|
||||
async def rank_simple(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
max_tokens: int,
|
||||
model: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[BaseContext]:
|
||||
"""
|
||||
Simple ranking without budget per type.
|
||||
|
||||
Selects top contexts by score until max tokens reached.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to rank
|
||||
query: Query to rank against
|
||||
max_tokens: Maximum total tokens
|
||||
model: Model for token counting
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Selected contexts (in score order)
|
||||
"""
|
||||
if not contexts:
|
||||
return []
|
||||
|
||||
# Ensure token counts
|
||||
await self._ensure_token_counts(contexts, model)
|
||||
|
||||
# Score all contexts
|
||||
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
|
||||
|
||||
# Sort by score (highest first)
|
||||
scored_contexts.sort(reverse=True)
|
||||
|
||||
# Greedy selection
|
||||
selected: list[BaseContext] = []
|
||||
total_tokens = 0
|
||||
|
||||
for sc in scored_contexts:
|
||||
token_count = self._get_valid_token_count(sc.context)
|
||||
if total_tokens + token_count <= max_tokens:
|
||||
selected.append(sc.context)
|
||||
total_tokens += token_count
|
||||
|
||||
return selected
|
||||
|
||||
def _get_valid_token_count(self, context: BaseContext) -> int:
|
||||
"""
|
||||
Get validated token count from a context.
|
||||
|
||||
Ensures token_count is set (not None) and non-negative to prevent
|
||||
budget bypass attacks where:
|
||||
- None would be treated as 0 (allowing huge contexts to slip through)
|
||||
- Negative values would corrupt budget tracking
|
||||
|
||||
Args:
|
||||
context: Context to get token count from
|
||||
|
||||
Returns:
|
||||
Valid non-negative token count
|
||||
|
||||
Raises:
|
||||
ValueError: If token_count is None or negative
|
||||
"""
|
||||
if context.token_count is None:
|
||||
raise ValueError(
|
||||
f"Context '{context.source}' has no token count. "
|
||||
"Ensure _ensure_token_counts() is called before ranking."
|
||||
)
|
||||
if context.token_count < 0:
|
||||
raise ValueError(
|
||||
f"Context '{context.source}' has invalid negative token count: "
|
||||
f"{context.token_count}"
|
||||
)
|
||||
return context.token_count
|
||||
|
||||
async def _ensure_token_counts(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Ensure all contexts have token counts.
|
||||
|
||||
Counts tokens in parallel for contexts that don't have counts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to check
|
||||
model: Model for token counting
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Find contexts needing counts
|
||||
contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]
|
||||
|
||||
if not contexts_needing_counts:
|
||||
return
|
||||
|
||||
# Count all in parallel
|
||||
tasks = [
|
||||
self._calculator.count_tokens(ctx.content, model)
|
||||
for ctx in contexts_needing_counts
|
||||
]
|
||||
counts = await asyncio.gather(*tasks)
|
||||
|
||||
# Assign counts back
|
||||
for ctx, count in zip(contexts_needing_counts, counts, strict=True):
|
||||
ctx.token_count = count
|
||||
|
||||
def _count_by_type(
|
||||
self, scored_contexts: list[ScoredContext]
|
||||
) -> dict[str, dict[str, int]]:
|
||||
"""Count selected contexts by type."""
|
||||
by_type: dict[str, dict[str, int]] = {}
|
||||
|
||||
for sc in scored_contexts:
|
||||
type_name = sc.context.get_type().value
|
||||
if type_name not in by_type:
|
||||
by_type[type_name] = {"count": 0, "tokens": 0}
|
||||
by_type[type_name]["count"] += 1
|
||||
# Use validated token count (already validated during ranking)
|
||||
by_type[type_name]["tokens"] += sc.context.token_count or 0
|
||||
|
||||
return by_type
|
||||
|
||||
async def rerank_for_diversity(
|
||||
self,
|
||||
scored_contexts: list[ScoredContext],
|
||||
max_per_source: int | None = None,
|
||||
) -> list[ScoredContext]:
|
||||
"""
|
||||
Rerank to ensure source diversity.
|
||||
|
||||
Prevents too many items from the same source.
|
||||
|
||||
Args:
|
||||
scored_contexts: Already scored contexts
|
||||
max_per_source: Maximum items per source (uses settings if None)
|
||||
|
||||
Returns:
|
||||
Reranked contexts
|
||||
"""
|
||||
# Use provided value or fall back to settings
|
||||
effective_max = (
|
||||
max_per_source
|
||||
if max_per_source is not None
|
||||
else self._settings.diversity_max_per_source
|
||||
)
|
||||
|
||||
source_counts: dict[str, int] = {}
|
||||
result: list[ScoredContext] = []
|
||||
deferred: list[ScoredContext] = []
|
||||
|
||||
for sc in scored_contexts:
|
||||
source = sc.context.source
|
||||
current_count = source_counts.get(source, 0)
|
||||
|
||||
if current_count < effective_max:
|
||||
result.append(sc)
|
||||
source_counts[source] = current_count + 1
|
||||
else:
|
||||
deferred.append(sc)
|
||||
|
||||
# Add deferred items at the end
|
||||
result.extend(deferred)
|
||||
return result
|
||||
21
backend/app/services/context/scoring/__init__.py
Normal file
21
backend/app/services/context/scoring/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Context Scoring Module.
|
||||
|
||||
Provides scoring strategies for context prioritization.
|
||||
"""
|
||||
|
||||
from .base import BaseScorer, ScorerProtocol
|
||||
from .composite import CompositeScorer, ScoredContext
|
||||
from .priority import PriorityScorer
|
||||
from .recency import RecencyScorer
|
||||
from .relevance import RelevanceScorer
|
||||
|
||||
__all__ = [
|
||||
"BaseScorer",
|
||||
"CompositeScorer",
|
||||
"PriorityScorer",
|
||||
"RecencyScorer",
|
||||
"RelevanceScorer",
|
||||
"ScoredContext",
|
||||
"ScorerProtocol",
|
||||
]
|
||||
99
backend/app/services/context/scoring/base.py
Normal file
99
backend/app/services/context/scoring/base.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Base Scorer Protocol and Types.
|
||||
|
||||
Defines the interface for context scoring implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
from ..types import BaseContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ScorerProtocol(Protocol):
|
||||
"""Protocol for context scorers."""
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score a context item.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Score between 0.0 and 1.0
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class BaseScorer(ABC):
|
||||
"""
|
||||
Abstract base class for context scorers.
|
||||
|
||||
Provides common functionality and interface for
|
||||
different scoring strategies.
|
||||
"""
|
||||
|
||||
def __init__(self, weight: float = 1.0) -> None:
|
||||
"""
|
||||
Initialize scorer.
|
||||
|
||||
Args:
|
||||
weight: Weight for this scorer in composite scoring
|
||||
"""
|
||||
self._weight = weight
|
||||
|
||||
@property
|
||||
def weight(self) -> float:
|
||||
"""Get scorer weight."""
|
||||
return self._weight
|
||||
|
||||
@weight.setter
|
||||
def weight(self, value: float) -> None:
|
||||
"""Set scorer weight."""
|
||||
if not 0.0 <= value <= 1.0:
|
||||
raise ValueError("Weight must be between 0.0 and 1.0")
|
||||
self._weight = value
|
||||
|
||||
@abstractmethod
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score a context item.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Score between 0.0 and 1.0
|
||||
"""
|
||||
...
|
||||
|
||||
def normalize_score(self, score: float) -> float:
|
||||
"""
|
||||
Normalize score to [0.0, 1.0] range.
|
||||
|
||||
Args:
|
||||
score: Raw score
|
||||
|
||||
Returns:
|
||||
Normalized score
|
||||
"""
|
||||
return max(0.0, min(1.0, score))
|
||||
368
backend/app/services/context/scoring/composite.py
Normal file
368
backend/app/services/context/scoring/composite.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
Composite Scorer for Context Management.
|
||||
|
||||
Combines multiple scoring strategies with configurable weights.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext
|
||||
from .priority import PriorityScorer
|
||||
from .recency import RecencyScorer
|
||||
from .relevance import RelevanceScorer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoredContext:
|
||||
"""Context with computed scores."""
|
||||
|
||||
context: BaseContext
|
||||
composite_score: float
|
||||
relevance_score: float = 0.0
|
||||
recency_score: float = 0.0
|
||||
priority_score: float = 0.0
|
||||
|
||||
def __lt__(self, other: "ScoredContext") -> bool:
|
||||
"""Enable sorting by composite score."""
|
||||
return self.composite_score < other.composite_score
|
||||
|
||||
def __gt__(self, other: "ScoredContext") -> bool:
|
||||
"""Enable sorting by composite score."""
|
||||
return self.composite_score > other.composite_score
|
||||
|
||||
|
||||
class CompositeScorer:
|
||||
"""
|
||||
Combines multiple scoring strategies.
|
||||
|
||||
Weights:
|
||||
- relevance: How well content matches the query
|
||||
- recency: How recent the content is
|
||||
- priority: Explicit priority assignments
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
relevance_weight: float | None = None,
|
||||
recency_weight: float | None = None,
|
||||
priority_weight: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize composite scorer.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP manager for semantic scoring
|
||||
settings: Context settings (uses default if None)
|
||||
relevance_weight: Override relevance weight
|
||||
recency_weight: Override recency weight
|
||||
priority_weight: Override priority weight
|
||||
"""
|
||||
self._settings = settings or get_context_settings()
|
||||
weights = self._settings.get_scoring_weights()
|
||||
|
||||
self._relevance_weight = (
|
||||
relevance_weight if relevance_weight is not None else weights["relevance"]
|
||||
)
|
||||
self._recency_weight = (
|
||||
recency_weight if recency_weight is not None else weights["recency"]
|
||||
)
|
||||
self._priority_weight = (
|
||||
priority_weight if priority_weight is not None else weights["priority"]
|
||||
)
|
||||
|
||||
# Initialize scorers
|
||||
self._relevance_scorer = RelevanceScorer(
|
||||
mcp_manager=mcp_manager,
|
||||
weight=self._relevance_weight,
|
||||
)
|
||||
self._recency_scorer = RecencyScorer(weight=self._recency_weight)
|
||||
self._priority_scorer = PriorityScorer(weight=self._priority_weight)
|
||||
|
||||
# Per-context locks to prevent race conditions during parallel scoring
|
||||
# Uses dict with (lock, last_used_time) tuples for cleanup
|
||||
self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {}
|
||||
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
|
||||
self._max_locks = 1000 # Maximum locks to keep (prevent memory growth)
|
||||
self._lock_ttl = 60.0 # Seconds before a lock can be cleaned up
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""Set MCP manager for semantic scoring."""
|
||||
self._relevance_scorer.set_mcp_manager(mcp_manager)
|
||||
|
||||
@property
|
||||
def weights(self) -> dict[str, float]:
|
||||
"""Get current scoring weights."""
|
||||
return {
|
||||
"relevance": self._relevance_weight,
|
||||
"recency": self._recency_weight,
|
||||
"priority": self._priority_weight,
|
||||
}
|
||||
|
||||
def update_weights(
|
||||
self,
|
||||
relevance: float | None = None,
|
||||
recency: float | None = None,
|
||||
priority: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update scoring weights.
|
||||
|
||||
Args:
|
||||
relevance: New relevance weight
|
||||
recency: New recency weight
|
||||
priority: New priority weight
|
||||
"""
|
||||
if relevance is not None:
|
||||
self._relevance_weight = max(0.0, min(1.0, relevance))
|
||||
self._relevance_scorer.weight = self._relevance_weight
|
||||
|
||||
if recency is not None:
|
||||
self._recency_weight = max(0.0, min(1.0, recency))
|
||||
self._recency_scorer.weight = self._recency_weight
|
||||
|
||||
if priority is not None:
|
||||
self._priority_weight = max(0.0, min(1.0, priority))
|
||||
self._priority_scorer.weight = self._priority_weight
|
||||
|
||||
async def _get_context_lock(self, context_id: str) -> asyncio.Lock:
|
||||
"""
|
||||
Get or create a lock for a specific context.
|
||||
|
||||
Thread-safe access to per-context locks prevents race conditions
|
||||
when the same context is scored concurrently. Includes automatic
|
||||
cleanup of old locks to prevent memory growth.
|
||||
|
||||
Args:
|
||||
context_id: The context ID to get a lock for
|
||||
|
||||
Returns:
|
||||
asyncio.Lock for the context
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
# Fast path: check if lock exists without acquiring main lock
|
||||
# NOTE: We only READ here - no writes to avoid race conditions
|
||||
# with cleanup. The timestamp will be updated in the slow path
|
||||
# if the lock is still valid.
|
||||
lock_entry = self._context_locks.get(context_id)
|
||||
if lock_entry is not None:
|
||||
lock, _ = lock_entry
|
||||
# Return the lock but defer timestamp update to avoid race
|
||||
# The lock is still valid; timestamp update is best-effort
|
||||
return lock
|
||||
|
||||
# Slow path: create lock or update timestamp while holding main lock
|
||||
async with self._locks_lock:
|
||||
# Double-check after acquiring lock - entry may have been
|
||||
# created by another coroutine or deleted by cleanup
|
||||
lock_entry = self._context_locks.get(context_id)
|
||||
if lock_entry is not None:
|
||||
lock, _ = lock_entry
|
||||
# Safe to update timestamp here since we hold the lock
|
||||
self._context_locks[context_id] = (lock, now)
|
||||
return lock
|
||||
|
||||
# Cleanup old locks if we have too many
|
||||
if len(self._context_locks) >= self._max_locks:
|
||||
self._cleanup_old_locks(now)
|
||||
|
||||
# Create new lock
|
||||
new_lock = asyncio.Lock()
|
||||
self._context_locks[context_id] = (new_lock, now)
|
||||
return new_lock
|
||||
|
||||
def _cleanup_old_locks(self, now: float) -> None:
|
||||
"""
|
||||
Remove old locks that haven't been used recently.
|
||||
|
||||
Called while holding _locks_lock. Removes locks older than _lock_ttl,
|
||||
but only if they're not currently held.
|
||||
|
||||
Args:
|
||||
now: Current timestamp for age calculation
|
||||
"""
|
||||
cutoff = now - self._lock_ttl
|
||||
to_remove = []
|
||||
|
||||
for context_id, (lock, last_used) in self._context_locks.items():
|
||||
# Only remove if old AND not currently held
|
||||
if last_used < cutoff and not lock.locked():
|
||||
to_remove.append(context_id)
|
||||
|
||||
# Remove oldest 50% if still over limit after TTL filtering
|
||||
if len(self._context_locks) - len(to_remove) >= self._max_locks:
|
||||
# Sort by last used time and mark oldest for removal
|
||||
sorted_entries = sorted(
|
||||
self._context_locks.items(),
|
||||
key=lambda x: x[1][1], # Sort by last_used time
|
||||
)
|
||||
# Remove oldest 50% that aren't locked
|
||||
target_remove = len(self._context_locks) // 2
|
||||
for context_id, (lock, _) in sorted_entries:
|
||||
if len(to_remove) >= target_remove:
|
||||
break
|
||||
if context_id not in to_remove and not lock.locked():
|
||||
to_remove.append(context_id)
|
||||
|
||||
for context_id in to_remove:
|
||||
del self._context_locks[context_id]
|
||||
|
||||
if to_remove:
|
||||
logger.debug(f"Cleaned up {len(to_remove)} context locks")
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Compute composite score for a context.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Composite score between 0.0 and 1.0
|
||||
"""
|
||||
scored = await self.score_with_details(context, query, **kwargs)
|
||||
return scored.composite_score
|
||||
|
||||
async def score_with_details(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> ScoredContext:
|
||||
"""
|
||||
Compute composite score with individual scores.
|
||||
|
||||
Uses per-context locking to prevent race conditions when the same
|
||||
context is scored concurrently in parallel scoring operations.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
ScoredContext with all scores
|
||||
"""
|
||||
# Get lock for this specific context to prevent race conditions
|
||||
# within concurrent scoring operations for the same query
|
||||
context_lock = await self._get_context_lock(context.id)
|
||||
|
||||
async with context_lock:
|
||||
# Compute individual scores in parallel
|
||||
# Note: We do NOT cache scores on the context because scores are
|
||||
# query-dependent. Caching without considering the query would
|
||||
# return incorrect scores for different queries.
|
||||
relevance_task = self._relevance_scorer.score(context, query, **kwargs)
|
||||
recency_task = self._recency_scorer.score(context, query, **kwargs)
|
||||
priority_task = self._priority_scorer.score(context, query, **kwargs)
|
||||
|
||||
relevance_score, recency_score, priority_score = await asyncio.gather(
|
||||
relevance_task, recency_task, priority_task
|
||||
)
|
||||
|
||||
# Compute weighted composite
|
||||
total_weight = (
|
||||
self._relevance_weight + self._recency_weight + self._priority_weight
|
||||
)
|
||||
|
||||
if total_weight > 0:
|
||||
composite = (
|
||||
relevance_score * self._relevance_weight
|
||||
+ recency_score * self._recency_weight
|
||||
+ priority_score * self._priority_weight
|
||||
) / total_weight
|
||||
else:
|
||||
composite = 0.0
|
||||
|
||||
return ScoredContext(
|
||||
context=context,
|
||||
composite_score=composite,
|
||||
relevance_score=relevance_score,
|
||||
recency_score=recency_score,
|
||||
priority_score=priority_score,
|
||||
)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
parallel: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> list[ScoredContext]:
|
||||
"""
|
||||
Score multiple contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
query: Query to score against
|
||||
parallel: Whether to score in parallel
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
List of ScoredContext (same order as input)
|
||||
"""
|
||||
if parallel:
|
||||
tasks = [self.score_with_details(ctx, query, **kwargs) for ctx in contexts]
|
||||
return await asyncio.gather(*tasks)
|
||||
else:
|
||||
results = []
|
||||
for ctx in contexts:
|
||||
scored = await self.score_with_details(ctx, query, **kwargs)
|
||||
results.append(scored)
|
||||
return results
|
||||
|
||||
async def rank(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
limit: int | None = None,
|
||||
min_score: float = 0.0,
|
||||
**kwargs: Any,
|
||||
) -> list[ScoredContext]:
|
||||
"""
|
||||
Score and rank contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to rank
|
||||
query: Query to rank against
|
||||
limit: Maximum number of results
|
||||
min_score: Minimum score threshold
|
||||
**kwargs: Additional scoring parameters
|
||||
|
||||
Returns:
|
||||
Sorted list of ScoredContext (highest first)
|
||||
"""
|
||||
# Score all contexts
|
||||
scored = await self.score_batch(contexts, query, **kwargs)
|
||||
|
||||
# Filter by minimum score
|
||||
if min_score > 0:
|
||||
scored = [s for s in scored if s.composite_score >= min_score]
|
||||
|
||||
# Sort by score (highest first)
|
||||
scored.sort(reverse=True)
|
||||
|
||||
# Apply limit
|
||||
if limit is not None:
|
||||
scored = scored[:limit]
|
||||
|
||||
return scored
|
||||
135
backend/app/services/context/scoring/priority.py
Normal file
135
backend/app/services/context/scoring/priority.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Priority Scorer for Context Management.
|
||||
|
||||
Scores context based on assigned priority levels.
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
from .base import BaseScorer
|
||||
|
||||
|
||||
class PriorityScorer(BaseScorer):
|
||||
"""
|
||||
Scores context based on priority levels.
|
||||
|
||||
Converts priority enum values to normalized scores.
|
||||
Also applies type-based priority bonuses.
|
||||
"""
|
||||
|
||||
# Default priority bonuses by context type
|
||||
DEFAULT_TYPE_BONUSES: ClassVar[dict[ContextType, float]] = {
|
||||
ContextType.SYSTEM: 0.2, # System prompts get a boost
|
||||
ContextType.TASK: 0.15, # Current task is important
|
||||
ContextType.TOOL: 0.1, # Recent tool results matter
|
||||
ContextType.KNOWLEDGE: 0.0, # Knowledge scored by relevance
|
||||
ContextType.CONVERSATION: 0.0, # Conversation scored by recency
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: float = 1.0,
|
||||
type_bonuses: dict[ContextType, float] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize priority scorer.
|
||||
|
||||
Args:
|
||||
weight: Scorer weight for composite scoring
|
||||
type_bonuses: Optional context-type priority bonuses
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self._type_bonuses = type_bonuses or self.DEFAULT_TYPE_BONUSES.copy()
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score context based on priority.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query (not used for priority, kept for interface)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
Priority score between 0.0 and 1.0
|
||||
"""
|
||||
# Get base priority score
|
||||
priority_value = context.priority
|
||||
base_score = self._priority_to_score(priority_value)
|
||||
|
||||
# Apply type bonus
|
||||
context_type = context.get_type()
|
||||
bonus = self._type_bonuses.get(context_type, 0.0)
|
||||
|
||||
return self.normalize_score(base_score + bonus)
|
||||
|
||||
def _priority_to_score(self, priority: int) -> float:
|
||||
"""
|
||||
Convert priority value to normalized score.
|
||||
|
||||
Priority values (from ContextPriority):
|
||||
- CRITICAL (100) -> 1.0
|
||||
- HIGH (80) -> 0.8
|
||||
- NORMAL (50) -> 0.5
|
||||
- LOW (20) -> 0.2
|
||||
- MINIMAL (0) -> 0.0
|
||||
|
||||
Args:
|
||||
priority: Priority value (0-100)
|
||||
|
||||
Returns:
|
||||
Normalized score (0.0-1.0)
|
||||
"""
|
||||
# Clamp to valid range
|
||||
clamped = max(0, min(100, priority))
|
||||
return clamped / 100.0
|
||||
|
||||
def get_type_bonus(self, context_type: ContextType) -> float:
|
||||
"""
|
||||
Get priority bonus for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type
|
||||
|
||||
Returns:
|
||||
Bonus value
|
||||
"""
|
||||
return self._type_bonuses.get(context_type, 0.0)
|
||||
|
||||
def set_type_bonus(self, context_type: ContextType, bonus: float) -> None:
|
||||
"""
|
||||
Set priority bonus for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type
|
||||
bonus: Bonus value (0.0-1.0)
|
||||
"""
|
||||
if not 0.0 <= bonus <= 1.0:
|
||||
raise ValueError("Bonus must be between 0.0 and 1.0")
|
||||
self._type_bonuses[context_type] = bonus
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Score multiple contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
query: Query (not used)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
List of scores (same order as input)
|
||||
"""
|
||||
# Priority scoring is fast, no async needed
|
||||
return [await self.score(ctx, query, **kwargs) for ctx in contexts]
|
||||
141
backend/app/services/context/scoring/recency.py
Normal file
141
backend/app/services/context/scoring/recency.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Recency Scorer for Context Management.
|
||||
|
||||
Scores context based on how recent it is.
|
||||
More recent content gets higher scores.
|
||||
"""
|
||||
|
||||
import math
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from ..types import BaseContext, ContextType
|
||||
from .base import BaseScorer
|
||||
|
||||
|
||||
class RecencyScorer(BaseScorer):
|
||||
"""
|
||||
Scores context based on recency.
|
||||
|
||||
Uses exponential decay to score content based on age.
|
||||
More recent content scores higher.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: float = 1.0,
|
||||
half_life_hours: float = 24.0,
|
||||
type_half_lives: dict[ContextType, float] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize recency scorer.
|
||||
|
||||
Args:
|
||||
weight: Scorer weight for composite scoring
|
||||
half_life_hours: Default hours until score decays to 0.5
|
||||
type_half_lives: Optional context-type-specific half lives
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self._half_life_hours = half_life_hours
|
||||
self._type_half_lives = type_half_lives or {}
|
||||
|
||||
# Set sensible defaults for context types
|
||||
if ContextType.CONVERSATION not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.CONVERSATION] = 1.0 # 1 hour
|
||||
if ContextType.TOOL not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.TOOL] = 0.5 # 30 minutes
|
||||
if ContextType.KNOWLEDGE not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.KNOWLEDGE] = 168.0 # 1 week
|
||||
if ContextType.SYSTEM not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.SYSTEM] = 720.0 # 30 days
|
||||
if ContextType.TASK not in self._type_half_lives:
|
||||
self._type_half_lives[ContextType.TASK] = 24.0 # 1 day
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score context based on recency.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query (not used for recency, kept for interface)
|
||||
**kwargs: Additional parameters
|
||||
- reference_time: Time to measure recency from (default: now)
|
||||
|
||||
Returns:
|
||||
Recency score between 0.0 and 1.0
|
||||
"""
|
||||
reference_time = kwargs.get("reference_time")
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(UTC)
|
||||
elif reference_time.tzinfo is None:
|
||||
reference_time = reference_time.replace(tzinfo=UTC)
|
||||
|
||||
# Ensure context timestamp is timezone-aware
|
||||
context_time = context.timestamp
|
||||
if context_time.tzinfo is None:
|
||||
context_time = context_time.replace(tzinfo=UTC)
|
||||
|
||||
# Calculate age in hours
|
||||
age = reference_time - context_time
|
||||
age_hours = max(0, age.total_seconds() / 3600)
|
||||
|
||||
# Get half-life for this context type
|
||||
context_type = context.get_type()
|
||||
half_life = self._type_half_lives.get(context_type, self._half_life_hours)
|
||||
|
||||
# Exponential decay
|
||||
decay_factor = math.exp(-math.log(2) * age_hours / half_life)
|
||||
|
||||
return self.normalize_score(decay_factor)
|
||||
|
||||
def get_half_life(self, context_type: ContextType) -> float:
|
||||
"""
|
||||
Get half-life for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to get half-life for
|
||||
|
||||
Returns:
|
||||
Half-life in hours
|
||||
"""
|
||||
return self._type_half_lives.get(context_type, self._half_life_hours)
|
||||
|
||||
def set_half_life(self, context_type: ContextType, hours: float) -> None:
|
||||
"""
|
||||
Set half-life for a context type.
|
||||
|
||||
Args:
|
||||
context_type: Context type to set half-life for
|
||||
hours: Half-life in hours
|
||||
"""
|
||||
if hours <= 0:
|
||||
raise ValueError("Half-life must be positive")
|
||||
self._type_half_lives[context_type] = hours
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Score multiple contexts.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
query: Query (not used)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
List of scores (same order as input)
|
||||
"""
|
||||
scores = []
|
||||
for context in contexts:
|
||||
score = await self.score(context, query, **kwargs)
|
||||
scores.append(score)
|
||||
return scores
|
||||
220
backend/app/services/context/scoring/relevance.py
Normal file
220
backend/app/services/context/scoring/relevance.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Relevance Scorer for Context Management.
|
||||
|
||||
Scores context based on semantic similarity to the query.
|
||||
Uses Knowledge Base embeddings when available.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from ..config import ContextSettings, get_context_settings
|
||||
from ..types import BaseContext, KnowledgeContext
|
||||
from .base import BaseScorer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RelevanceScorer(BaseScorer):
|
||||
"""
|
||||
Scores context based on relevance to query.
|
||||
|
||||
Uses multiple strategies:
|
||||
1. Pre-computed scores (from RAG results)
|
||||
2. MCP-based semantic similarity (via Knowledge Base)
|
||||
3. Keyword matching fallback
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
weight: float = 1.0,
|
||||
keyword_fallback_weight: float | None = None,
|
||||
semantic_max_chars: int | None = None,
|
||||
settings: ContextSettings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize relevance scorer.
|
||||
|
||||
Args:
|
||||
mcp_manager: MCP manager for Knowledge Base calls
|
||||
weight: Scorer weight for composite scoring
|
||||
keyword_fallback_weight: Max score for keyword-based fallback (overrides settings)
|
||||
semantic_max_chars: Max content length for semantic similarity (overrides settings)
|
||||
settings: Context settings (uses global if None)
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self._settings = settings or get_context_settings()
|
||||
self._mcp = mcp_manager
|
||||
|
||||
# Use provided values or fall back to settings
|
||||
self._keyword_fallback_weight = (
|
||||
keyword_fallback_weight
|
||||
if keyword_fallback_weight is not None
|
||||
else self._settings.relevance_keyword_fallback_weight
|
||||
)
|
||||
self._semantic_max_chars = (
|
||||
semantic_max_chars
|
||||
if semantic_max_chars is not None
|
||||
else self._settings.relevance_semantic_max_chars
|
||||
)
|
||||
|
||||
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||
"""Set MCP manager for semantic scoring."""
|
||||
self._mcp = mcp_manager
|
||||
|
||||
async def score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""
|
||||
Score context relevance to query.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
Relevance score between 0.0 and 1.0
|
||||
"""
|
||||
# 1. Check for pre-computed relevance score
|
||||
if (
|
||||
isinstance(context, KnowledgeContext)
|
||||
and context.relevance_score is not None
|
||||
):
|
||||
return self.normalize_score(context.relevance_score)
|
||||
|
||||
# 2. Check metadata for score
|
||||
if "relevance_score" in context.metadata:
|
||||
return self.normalize_score(context.metadata["relevance_score"])
|
||||
|
||||
if "score" in context.metadata:
|
||||
return self.normalize_score(context.metadata["score"])
|
||||
|
||||
# 3. Try MCP-based semantic similarity (if compute_similarity tool is available)
|
||||
# Note: This requires the knowledge-base MCP server to implement compute_similarity
|
||||
if self._mcp is not None:
|
||||
try:
|
||||
score = await self._compute_semantic_similarity(context, query)
|
||||
if score is not None:
|
||||
return score
|
||||
except Exception as e:
|
||||
# Log at debug level since this is expected if compute_similarity
|
||||
# tool is not implemented in the Knowledge Base server
|
||||
logger.debug(
|
||||
f"Semantic scoring unavailable, using keyword fallback: {e}"
|
||||
)
|
||||
|
||||
# 4. Fall back to keyword matching
|
||||
return self._compute_keyword_score(context, query)
|
||||
|
||||
async def _compute_semantic_similarity(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
) -> float | None:
|
||||
"""
|
||||
Compute semantic similarity using Knowledge Base embeddings.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to compare
|
||||
|
||||
Returns:
|
||||
Similarity score or None if unavailable
|
||||
"""
|
||||
if self._mcp is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Use Knowledge Base's search capability to compute similarity
|
||||
result = await self._mcp.call_tool(
|
||||
server="knowledge-base",
|
||||
tool="compute_similarity",
|
||||
args={
|
||||
"text1": query,
|
||||
"text2": context.content[
|
||||
: self._semantic_max_chars
|
||||
], # Limit content length
|
||||
},
|
||||
)
|
||||
|
||||
if result.success and isinstance(result.data, dict):
|
||||
similarity = result.data.get("similarity")
|
||||
if similarity is not None:
|
||||
return self.normalize_score(float(similarity))
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Semantic similarity computation failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _compute_keyword_score(
|
||||
self,
|
||||
context: BaseContext,
|
||||
query: str,
|
||||
) -> float:
|
||||
"""
|
||||
Compute relevance score based on keyword matching.
|
||||
|
||||
Simple but fast fallback when semantic search is unavailable.
|
||||
|
||||
Args:
|
||||
context: Context to score
|
||||
query: Query to match
|
||||
|
||||
Returns:
|
||||
Keyword-based relevance score
|
||||
"""
|
||||
if not query or not context.content:
|
||||
return 0.0
|
||||
|
||||
# Extract keywords from query
|
||||
query_lower = query.lower()
|
||||
content_lower = context.content.lower()
|
||||
|
||||
# Simple word tokenization
|
||||
query_words = set(re.findall(r"\b\w{3,}\b", query_lower))
|
||||
content_words = set(re.findall(r"\b\w{3,}\b", content_lower))
|
||||
|
||||
if not query_words:
|
||||
return 0.0
|
||||
|
||||
# Calculate overlap
|
||||
common_words = query_words & content_words
|
||||
overlap_ratio = len(common_words) / len(query_words)
|
||||
|
||||
# Apply fallback weight ceiling
|
||||
return self.normalize_score(overlap_ratio * self._keyword_fallback_weight)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
contexts: list[BaseContext],
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Score multiple contexts in parallel.
|
||||
|
||||
Args:
|
||||
contexts: Contexts to score
|
||||
query: Query to score against
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
List of scores (same order as input)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
if not contexts:
|
||||
return []
|
||||
|
||||
tasks = [self.score(context, query, **kwargs) for context in contexts]
|
||||
return await asyncio.gather(*tasks)
|
||||
49
backend/app/services/context/types/__init__.py
Normal file
49
backend/app/services/context/types/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Context Types Module.
|
||||
|
||||
Provides all context types used in the Context Management Engine.
|
||||
"""
|
||||
|
||||
from .base import (
|
||||
AssembledContext,
|
||||
BaseContext,
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
)
|
||||
from .conversation import (
|
||||
ConversationContext,
|
||||
MessageRole,
|
||||
)
|
||||
from .knowledge import KnowledgeContext
|
||||
from .memory import (
|
||||
MemoryContext,
|
||||
MemorySubtype,
|
||||
)
|
||||
from .system import SystemContext
|
||||
from .task import (
|
||||
TaskComplexity,
|
||||
TaskContext,
|
||||
TaskStatus,
|
||||
)
|
||||
from .tool import (
|
||||
ToolContext,
|
||||
ToolResultStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AssembledContext",
|
||||
"BaseContext",
|
||||
"ContextPriority",
|
||||
"ContextType",
|
||||
"ConversationContext",
|
||||
"KnowledgeContext",
|
||||
"MemoryContext",
|
||||
"MemorySubtype",
|
||||
"MessageRole",
|
||||
"SystemContext",
|
||||
"TaskComplexity",
|
||||
"TaskContext",
|
||||
"TaskStatus",
|
||||
"ToolContext",
|
||||
"ToolResultStatus",
|
||||
]
|
||||
348
backend/app/services/context/types/base.py
Normal file
348
backend/app/services/context/types/base.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
Base Context Types and Enums.
|
||||
|
||||
Provides the foundation for all context types used in
|
||||
the Context Management Engine.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class ContextType(str, Enum):
|
||||
"""
|
||||
Types of context that can be assembled.
|
||||
|
||||
Each type has specific handling, formatting, and
|
||||
budget allocation rules.
|
||||
"""
|
||||
|
||||
SYSTEM = "system"
|
||||
TASK = "task"
|
||||
KNOWLEDGE = "knowledge"
|
||||
CONVERSATION = "conversation"
|
||||
TOOL = "tool"
|
||||
MEMORY = "memory" # Agent memory (working, episodic, semantic, procedural)
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str) -> "ContextType":
|
||||
"""
|
||||
Convert string to ContextType.
|
||||
|
||||
Args:
|
||||
value: String value
|
||||
|
||||
Returns:
|
||||
ContextType enum value
|
||||
|
||||
Raises:
|
||||
ValueError: If value is not a valid context type
|
||||
"""
|
||||
try:
|
||||
return cls(value.lower())
|
||||
except ValueError:
|
||||
valid = ", ".join(t.value for t in cls)
|
||||
raise ValueError(f"Invalid context type '{value}'. Valid types: {valid}")
|
||||
|
||||
|
||||
class ContextPriority(int, Enum):
|
||||
"""
|
||||
Priority levels for context ordering.
|
||||
|
||||
Higher values indicate higher priority.
|
||||
"""
|
||||
|
||||
LOWEST = 0
|
||||
LOW = 25
|
||||
NORMAL = 50
|
||||
HIGH = 75
|
||||
HIGHEST = 100
|
||||
CRITICAL = 150 # Never omit
|
||||
|
||||
@classmethod
|
||||
def from_int(cls, value: int) -> "ContextPriority":
|
||||
"""
|
||||
Get closest priority level for an integer.
|
||||
|
||||
Args:
|
||||
value: Integer priority value
|
||||
|
||||
Returns:
|
||||
Closest ContextPriority enum value
|
||||
"""
|
||||
priorities = sorted(cls, key=lambda p: p.value)
|
||||
for priority in reversed(priorities):
|
||||
if value >= priority.value:
|
||||
return priority
|
||||
return cls.LOWEST
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class BaseContext(ABC):
|
||||
"""
|
||||
Abstract base class for all context types.
|
||||
|
||||
Provides common fields and methods for context handling,
|
||||
scoring, and serialization.
|
||||
"""
|
||||
|
||||
# Required fields
|
||||
content: str
|
||||
source: str
|
||||
|
||||
# Optional fields with defaults
|
||||
id: str = field(default_factory=lambda: str(uuid4()))
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
priority: int = field(default=ContextPriority.NORMAL.value)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Computed/cached fields
|
||||
_token_count: int | None = field(default=None, repr=False)
|
||||
_score: float | None = field(default=None, repr=False)
|
||||
|
||||
@property
|
||||
def token_count(self) -> int | None:
|
||||
"""Get cached token count (None if not counted yet)."""
|
||||
return self._token_count
|
||||
|
||||
@token_count.setter
|
||||
def token_count(self, value: int) -> None:
|
||||
"""Set token count."""
|
||||
self._token_count = value
|
||||
|
||||
@property
|
||||
def score(self) -> float | None:
|
||||
"""Get cached score (None if not scored yet)."""
|
||||
return self._score
|
||||
|
||||
@score.setter
|
||||
def score(self, value: float) -> None:
|
||||
"""Set score (clamped to 0.0-1.0)."""
|
||||
self._score = max(0.0, min(1.0, value))
|
||||
|
||||
@abstractmethod
|
||||
def get_type(self) -> ContextType:
|
||||
"""
|
||||
Get the type of this context.
|
||||
|
||||
Returns:
|
||||
ContextType enum value
|
||||
"""
|
||||
...
|
||||
|
||||
def get_age_seconds(self) -> float:
|
||||
"""
|
||||
Get age of context in seconds.
|
||||
|
||||
Returns:
|
||||
Age in seconds since creation
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
delta = now - self.timestamp
|
||||
return delta.total_seconds()
|
||||
|
||||
def get_age_hours(self) -> float:
|
||||
"""
|
||||
Get age of context in hours.
|
||||
|
||||
Returns:
|
||||
Age in hours since creation
|
||||
"""
|
||||
return self.get_age_seconds() / 3600
|
||||
|
||||
def is_stale(self, max_age_hours: float = 168.0) -> bool:
|
||||
"""
|
||||
Check if context is stale.
|
||||
|
||||
Args:
|
||||
max_age_hours: Maximum age before considered stale (default 7 days)
|
||||
|
||||
Returns:
|
||||
True if context is older than max_age_hours
|
||||
"""
|
||||
return self.get_age_hours() > max_age_hours
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert context to dictionary for serialization.
|
||||
|
||||
Returns:
|
||||
Dictionary representation
|
||||
"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"type": self.get_type().value,
|
||||
"content": self.content,
|
||||
"source": self.source,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"priority": self.priority,
|
||||
"metadata": self.metadata,
|
||||
"token_count": self._token_count,
|
||||
"score": self._score,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "BaseContext":
|
||||
"""
|
||||
Create context from dictionary.
|
||||
|
||||
Note: Subclasses should override this to return correct type.
|
||||
|
||||
Args:
|
||||
data: Dictionary with context data
|
||||
|
||||
Returns:
|
||||
Context instance
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement from_dict")
|
||||
|
||||
def truncate(self, max_tokens: int, suffix: str = "... [truncated]") -> str:
|
||||
"""
|
||||
Truncate content to fit within token limit.
|
||||
|
||||
This is a rough estimation based on characters.
|
||||
For accurate truncation, use the TokenCalculator.
|
||||
|
||||
Args:
|
||||
max_tokens: Maximum tokens allowed
|
||||
suffix: Suffix to append when truncated
|
||||
|
||||
Returns:
|
||||
Truncated content
|
||||
"""
|
||||
if self._token_count is None or self._token_count <= max_tokens:
|
||||
return self.content
|
||||
|
||||
# Rough estimation: 4 chars per token on average
|
||||
estimated_chars = max_tokens * 4
|
||||
suffix_chars = len(suffix)
|
||||
|
||||
if len(self.content) <= estimated_chars:
|
||||
return self.content
|
||||
|
||||
truncated = self.content[: estimated_chars - suffix_chars]
|
||||
# Try to break at word boundary
|
||||
last_space = truncated.rfind(" ")
|
||||
if last_space > estimated_chars * 0.8:
|
||||
truncated = truncated[:last_space]
|
||||
|
||||
return truncated + suffix
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Hash based on ID for set/dict usage."""
|
||||
return hash(self.id)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Equality based on ID."""
|
||||
if not isinstance(other, BaseContext):
|
||||
return False
|
||||
return self.id == other.id
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssembledContext:
|
||||
"""
|
||||
Result of context assembly.
|
||||
|
||||
Contains the final formatted context ready for LLM consumption,
|
||||
along with metadata about the assembly process.
|
||||
"""
|
||||
|
||||
# Main content
|
||||
content: str
|
||||
total_tokens: int
|
||||
|
||||
# Assembly metadata
|
||||
context_count: int
|
||||
excluded_count: int = 0
|
||||
assembly_time_ms: float = 0.0
|
||||
model: str = ""
|
||||
|
||||
# Included contexts (optional - for inspection)
|
||||
contexts: list["BaseContext"] = field(default_factory=list)
|
||||
|
||||
# Additional metadata from assembly
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Budget tracking
|
||||
budget_total: int = 0
|
||||
budget_used: int = 0
|
||||
|
||||
# Context breakdown
|
||||
by_type: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
# Cache info
|
||||
cache_hit: bool = False
|
||||
cache_key: str | None = None
|
||||
|
||||
# Aliases for backward compatibility
|
||||
@property
|
||||
def token_count(self) -> int:
|
||||
"""Alias for total_tokens."""
|
||||
return self.total_tokens
|
||||
|
||||
@property
|
||||
def contexts_included(self) -> int:
|
||||
"""Alias for context_count."""
|
||||
return self.context_count
|
||||
|
||||
@property
|
||||
def contexts_excluded(self) -> int:
|
||||
"""Alias for excluded_count."""
|
||||
return self.excluded_count
|
||||
|
||||
@property
|
||||
def budget_utilization(self) -> float:
|
||||
"""Get budget utilization percentage."""
|
||||
if self.budget_total == 0:
|
||||
return 0.0
|
||||
return self.budget_used / self.budget_total
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"content": self.content,
|
||||
"total_tokens": self.total_tokens,
|
||||
"context_count": self.context_count,
|
||||
"excluded_count": self.excluded_count,
|
||||
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||
"model": self.model,
|
||||
"metadata": self.metadata,
|
||||
"budget_total": self.budget_total,
|
||||
"budget_used": self.budget_used,
|
||||
"budget_utilization": round(self.budget_utilization, 3),
|
||||
"by_type": self.by_type,
|
||||
"cache_hit": self.cache_hit,
|
||||
"cache_key": self.cache_key,
|
||||
}
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""Convert to JSON string."""
|
||||
import json
|
||||
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str) -> "AssembledContext":
|
||||
"""Create from JSON string."""
|
||||
import json
|
||||
|
||||
data = json.loads(json_str)
|
||||
return cls(
|
||||
content=data["content"],
|
||||
total_tokens=data["total_tokens"],
|
||||
context_count=data["context_count"],
|
||||
excluded_count=data.get("excluded_count", 0),
|
||||
assembly_time_ms=data.get("assembly_time_ms", 0.0),
|
||||
model=data.get("model", ""),
|
||||
metadata=data.get("metadata", {}),
|
||||
budget_total=data.get("budget_total", 0),
|
||||
budget_used=data.get("budget_used", 0),
|
||||
by_type=data.get("by_type", {}),
|
||||
cache_hit=data.get("cache_hit", False),
|
||||
cache_key=data.get("cache_key"),
|
||||
)
|
||||
182
backend/app/services/context/types/conversation.py
Normal file
182
backend/app/services/context/types/conversation.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Conversation Context Type.
|
||||
|
||||
Represents conversation history for context continuity.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
"""Roles for conversation messages."""
|
||||
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
TOOL = "tool"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str) -> "MessageRole":
|
||||
"""Convert string to MessageRole."""
|
||||
try:
|
||||
return cls(value.lower())
|
||||
except ValueError:
|
||||
# Default to user for unknown roles
|
||||
return cls.USER
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ConversationContext(BaseContext):
|
||||
"""
|
||||
Context from conversation history.
|
||||
|
||||
Represents a single turn in the conversation,
|
||||
including user messages, assistant responses,
|
||||
and tool results.
|
||||
"""
|
||||
|
||||
# Conversation-specific fields
|
||||
role: MessageRole = field(default=MessageRole.USER)
|
||||
turn_index: int = field(default=0)
|
||||
session_id: str | None = field(default=None)
|
||||
parent_message_id: str | None = field(default=None)
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return CONVERSATION context type."""
|
||||
return ContextType.CONVERSATION
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with conversation-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"role": self.role.value,
|
||||
"turn_index": self.turn_index,
|
||||
"session_id": self.session_id,
|
||||
"parent_message_id": self.parent_message_id,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ConversationContext":
|
||||
"""Create ConversationContext from dictionary."""
|
||||
role = data.get("role", "user")
|
||||
if isinstance(role, str):
|
||||
role = MessageRole.from_string(role)
|
||||
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data.get("source", "conversation"),
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
role=role,
|
||||
turn_index=data.get("turn_index", 0),
|
||||
session_id=data.get("session_id"),
|
||||
parent_message_id=data.get("parent_message_id"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_message(
|
||||
cls,
|
||||
content: str,
|
||||
role: str | MessageRole,
|
||||
turn_index: int = 0,
|
||||
session_id: str | None = None,
|
||||
timestamp: datetime | None = None,
|
||||
) -> "ConversationContext":
|
||||
"""
|
||||
Create ConversationContext from a message.
|
||||
|
||||
Args:
|
||||
content: Message content
|
||||
role: Message role (user, assistant, system, tool)
|
||||
turn_index: Position in conversation
|
||||
session_id: Session identifier
|
||||
timestamp: Message timestamp
|
||||
|
||||
Returns:
|
||||
ConversationContext instance
|
||||
"""
|
||||
if isinstance(role, str):
|
||||
role = MessageRole.from_string(role)
|
||||
|
||||
# Recent messages have higher priority
|
||||
priority = ContextPriority.NORMAL.value
|
||||
|
||||
return cls(
|
||||
content=content,
|
||||
source="conversation",
|
||||
role=role,
|
||||
turn_index=turn_index,
|
||||
session_id=session_id,
|
||||
timestamp=timestamp or datetime.now(UTC),
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_history(
|
||||
cls,
|
||||
messages: list[dict[str, Any]],
|
||||
session_id: str | None = None,
|
||||
) -> list["ConversationContext"]:
|
||||
"""
|
||||
Create multiple ConversationContexts from message history.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
List of ConversationContext instances
|
||||
"""
|
||||
contexts = []
|
||||
for i, msg in enumerate(messages):
|
||||
ctx = cls.from_message(
|
||||
content=msg.get("content", ""),
|
||||
role=msg.get("role", "user"),
|
||||
turn_index=i,
|
||||
session_id=session_id,
|
||||
timestamp=datetime.fromisoformat(msg["timestamp"])
|
||||
if "timestamp" in msg
|
||||
else None,
|
||||
)
|
||||
contexts.append(ctx)
|
||||
return contexts
|
||||
|
||||
def is_user_message(self) -> bool:
|
||||
"""Check if this is a user message."""
|
||||
return self.role == MessageRole.USER
|
||||
|
||||
def is_assistant_message(self) -> bool:
|
||||
"""Check if this is an assistant message."""
|
||||
return self.role == MessageRole.ASSISTANT
|
||||
|
||||
def is_tool_result(self) -> bool:
|
||||
"""Check if this is a tool result."""
|
||||
return self.role == MessageRole.TOOL
|
||||
|
||||
def format_for_prompt(self) -> str:
|
||||
"""
|
||||
Format message for inclusion in prompt.
|
||||
|
||||
Returns:
|
||||
Formatted message string
|
||||
"""
|
||||
role_labels = {
|
||||
MessageRole.USER: "User",
|
||||
MessageRole.ASSISTANT: "Assistant",
|
||||
MessageRole.SYSTEM: "System",
|
||||
MessageRole.TOOL: "Tool Result",
|
||||
}
|
||||
label = role_labels.get(self.role, "Unknown")
|
||||
return f"{label}: {self.content}"
|
||||
152
backend/app/services/context/types/knowledge.py
Normal file
152
backend/app/services/context/types/knowledge.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Knowledge Context Type.
|
||||
|
||||
Represents RAG results from the Knowledge Base MCP server.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class KnowledgeContext(BaseContext):
|
||||
"""
|
||||
Context from knowledge base / RAG retrieval.
|
||||
|
||||
Knowledge context represents chunks retrieved from the
|
||||
Knowledge Base MCP server, including:
|
||||
- Code snippets
|
||||
- Documentation
|
||||
- Previous conversations
|
||||
- External knowledge
|
||||
|
||||
Each chunk includes relevance scoring from the search.
|
||||
"""
|
||||
|
||||
# Knowledge-specific fields
|
||||
collection: str = field(default="default")
|
||||
file_type: str | None = field(default=None)
|
||||
chunk_index: int = field(default=0)
|
||||
relevance_score: float = field(default=0.0)
|
||||
search_query: str = field(default="")
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return KNOWLEDGE context type."""
|
||||
return ContextType.KNOWLEDGE
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with knowledge-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"collection": self.collection,
|
||||
"file_type": self.file_type,
|
||||
"chunk_index": self.chunk_index,
|
||||
"relevance_score": self.relevance_score,
|
||||
"search_query": self.search_query,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "KnowledgeContext":
|
||||
"""Create KnowledgeContext from dictionary."""
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data["source"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
collection=data.get("collection", "default"),
|
||||
file_type=data.get("file_type"),
|
||||
chunk_index=data.get("chunk_index", 0),
|
||||
relevance_score=data.get("relevance_score", 0.0),
|
||||
search_query=data.get("search_query", ""),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_search_result(
|
||||
cls,
|
||||
result: dict[str, Any],
|
||||
query: str,
|
||||
) -> "KnowledgeContext":
|
||||
"""
|
||||
Create KnowledgeContext from a Knowledge Base search result.
|
||||
|
||||
Args:
|
||||
result: Search result from Knowledge Base MCP
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
KnowledgeContext instance
|
||||
"""
|
||||
return cls(
|
||||
content=result.get("content", ""),
|
||||
source=result.get("source_path", "unknown"),
|
||||
collection=result.get("collection", "default"),
|
||||
file_type=result.get("file_type"),
|
||||
chunk_index=result.get("chunk_index", 0),
|
||||
relevance_score=result.get("score", 0.0),
|
||||
search_query=query,
|
||||
metadata={
|
||||
"chunk_id": result.get("id"),
|
||||
"content_hash": result.get("content_hash"),
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_search_results(
|
||||
cls,
|
||||
results: list[dict[str, Any]],
|
||||
query: str,
|
||||
) -> list["KnowledgeContext"]:
|
||||
"""
|
||||
Create multiple KnowledgeContexts from search results.
|
||||
|
||||
Args:
|
||||
results: List of search results
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
List of KnowledgeContext instances
|
||||
"""
|
||||
return [cls.from_search_result(r, query) for r in results]
|
||||
|
||||
def is_code(self) -> bool:
|
||||
"""Check if this is code content."""
|
||||
code_types = {
|
||||
"python",
|
||||
"javascript",
|
||||
"typescript",
|
||||
"go",
|
||||
"rust",
|
||||
"java",
|
||||
"c",
|
||||
"cpp",
|
||||
}
|
||||
return self.file_type is not None and self.file_type.lower() in code_types
|
||||
|
||||
def is_documentation(self) -> bool:
|
||||
"""Check if this is documentation content."""
|
||||
doc_types = {"markdown", "rst", "txt", "md"}
|
||||
return self.file_type is not None and self.file_type.lower() in doc_types
|
||||
|
||||
def get_formatted_source(self) -> str:
|
||||
"""
|
||||
Get a formatted source string for display.
|
||||
|
||||
Returns:
|
||||
Formatted source string
|
||||
"""
|
||||
parts = [self.source]
|
||||
if self.file_type:
|
||||
parts.append(f"({self.file_type})")
|
||||
if self.collection != "default":
|
||||
parts.insert(0, f"[{self.collection}]")
|
||||
return " ".join(parts)
|
||||
282
backend/app/services/context/types/memory.py
Normal file
282
backend/app/services/context/types/memory.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Memory Context Type.
|
||||
|
||||
Represents agent memory as context for LLM requests.
|
||||
Includes working, episodic, semantic, and procedural memories.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class MemorySubtype(str, Enum):
|
||||
"""Types of agent memory."""
|
||||
|
||||
WORKING = "working" # Session-scoped temporary data
|
||||
EPISODIC = "episodic" # Task history and outcomes
|
||||
SEMANTIC = "semantic" # Facts and knowledge
|
||||
PROCEDURAL = "procedural" # Learned procedures
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MemoryContext(BaseContext):
|
||||
"""
|
||||
Context from agent memory system.
|
||||
|
||||
Memory context represents data retrieved from the agent
|
||||
memory system, including:
|
||||
- Working memory: Current session state
|
||||
- Episodic memory: Past task experiences
|
||||
- Semantic memory: Learned facts and knowledge
|
||||
- Procedural memory: Known procedures and workflows
|
||||
|
||||
Each memory item includes relevance scoring from search.
|
||||
"""
|
||||
|
||||
# Memory-specific fields
|
||||
memory_subtype: MemorySubtype = field(default=MemorySubtype.EPISODIC)
|
||||
memory_id: str | None = field(default=None)
|
||||
relevance_score: float = field(default=0.0)
|
||||
importance: float = field(default=0.5)
|
||||
search_query: str = field(default="")
|
||||
|
||||
# Type-specific fields (populated based on memory_subtype)
|
||||
key: str | None = field(default=None) # For working memory
|
||||
task_type: str | None = field(default=None) # For episodic
|
||||
outcome: str | None = field(default=None) # For episodic
|
||||
subject: str | None = field(default=None) # For semantic
|
||||
predicate: str | None = field(default=None) # For semantic
|
||||
object_value: str | None = field(default=None) # For semantic
|
||||
trigger: str | None = field(default=None) # For procedural
|
||||
success_rate: float | None = field(default=None) # For procedural
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return MEMORY context type."""
|
||||
return ContextType.MEMORY
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with memory-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"memory_subtype": self.memory_subtype.value,
|
||||
"memory_id": self.memory_id,
|
||||
"relevance_score": self.relevance_score,
|
||||
"importance": self.importance,
|
||||
"search_query": self.search_query,
|
||||
"key": self.key,
|
||||
"task_type": self.task_type,
|
||||
"outcome": self.outcome,
|
||||
"subject": self.subject,
|
||||
"predicate": self.predicate,
|
||||
"object_value": self.object_value,
|
||||
"trigger": self.trigger,
|
||||
"success_rate": self.success_rate,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "MemoryContext":
|
||||
"""Create MemoryContext from dictionary."""
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data["source"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
memory_subtype=MemorySubtype(data.get("memory_subtype", "episodic")),
|
||||
memory_id=data.get("memory_id"),
|
||||
relevance_score=data.get("relevance_score", 0.0),
|
||||
importance=data.get("importance", 0.5),
|
||||
search_query=data.get("search_query", ""),
|
||||
key=data.get("key"),
|
||||
task_type=data.get("task_type"),
|
||||
outcome=data.get("outcome"),
|
||||
subject=data.get("subject"),
|
||||
predicate=data.get("predicate"),
|
||||
object_value=data.get("object_value"),
|
||||
trigger=data.get("trigger"),
|
||||
success_rate=data.get("success_rate"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_working_memory(
|
||||
cls,
|
||||
key: str,
|
||||
value: Any,
|
||||
source: str = "working_memory",
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from working memory entry.
|
||||
|
||||
Args:
|
||||
key: Working memory key
|
||||
value: Value stored at key
|
||||
source: Source identifier
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
return cls(
|
||||
content=str(value),
|
||||
source=source,
|
||||
memory_subtype=MemorySubtype.WORKING,
|
||||
key=key,
|
||||
relevance_score=1.0, # Working memory is always relevant
|
||||
importance=0.8, # Higher importance for current session state
|
||||
search_query=query,
|
||||
priority=ContextPriority.HIGH.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_episodic_memory(
|
||||
cls,
|
||||
episode: Any,
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from episodic memory episode.
|
||||
|
||||
Args:
|
||||
episode: Episode object from episodic memory
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
outcome_val = None
|
||||
if hasattr(episode, "outcome") and episode.outcome:
|
||||
outcome_val = (
|
||||
episode.outcome.value
|
||||
if hasattr(episode.outcome, "value")
|
||||
else str(episode.outcome)
|
||||
)
|
||||
|
||||
return cls(
|
||||
content=episode.task_description,
|
||||
source=f"episodic:{episode.id}",
|
||||
memory_subtype=MemorySubtype.EPISODIC,
|
||||
memory_id=str(episode.id),
|
||||
relevance_score=getattr(episode, "importance_score", 0.5),
|
||||
importance=getattr(episode, "importance_score", 0.5),
|
||||
search_query=query,
|
||||
task_type=getattr(episode, "task_type", None),
|
||||
outcome=outcome_val,
|
||||
metadata={
|
||||
"session_id": getattr(episode, "session_id", None),
|
||||
"occurred_at": episode.occurred_at.isoformat()
|
||||
if hasattr(episode, "occurred_at") and episode.occurred_at
|
||||
else None,
|
||||
"lessons_learned": getattr(episode, "lessons_learned", []),
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_semantic_memory(
|
||||
cls,
|
||||
fact: Any,
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from semantic memory fact.
|
||||
|
||||
Args:
|
||||
fact: Fact object from semantic memory
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
triple = f"{fact.subject} {fact.predicate} {fact.object}"
|
||||
return cls(
|
||||
content=triple,
|
||||
source=f"semantic:{fact.id}",
|
||||
memory_subtype=MemorySubtype.SEMANTIC,
|
||||
memory_id=str(fact.id),
|
||||
relevance_score=getattr(fact, "confidence", 0.5),
|
||||
importance=getattr(fact, "confidence", 0.5),
|
||||
search_query=query,
|
||||
subject=fact.subject,
|
||||
predicate=fact.predicate,
|
||||
object_value=fact.object,
|
||||
priority=ContextPriority.NORMAL.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_procedural_memory(
|
||||
cls,
|
||||
procedure: Any,
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from procedural memory procedure.
|
||||
|
||||
Args:
|
||||
procedure: Procedure object from procedural memory
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
# Format steps as content
|
||||
steps = getattr(procedure, "steps", [])
|
||||
steps_content = "\n".join(
|
||||
f" {i + 1}. {step.get('action', step) if isinstance(step, dict) else step}"
|
||||
for i, step in enumerate(steps)
|
||||
)
|
||||
content = f"Procedure: {procedure.name}\nTrigger: {procedure.trigger_pattern}\nSteps:\n{steps_content}"
|
||||
|
||||
return cls(
|
||||
content=content,
|
||||
source=f"procedural:{procedure.id}",
|
||||
memory_subtype=MemorySubtype.PROCEDURAL,
|
||||
memory_id=str(procedure.id),
|
||||
relevance_score=getattr(procedure, "success_rate", 0.5),
|
||||
importance=0.7, # Procedures are moderately important
|
||||
search_query=query,
|
||||
trigger=procedure.trigger_pattern,
|
||||
success_rate=getattr(procedure, "success_rate", None),
|
||||
metadata={
|
||||
"steps_count": len(steps),
|
||||
"execution_count": getattr(procedure, "success_count", 0)
|
||||
+ getattr(procedure, "failure_count", 0),
|
||||
},
|
||||
)
|
||||
|
||||
def is_working_memory(self) -> bool:
|
||||
"""Check if this is working memory."""
|
||||
return self.memory_subtype == MemorySubtype.WORKING
|
||||
|
||||
def is_episodic_memory(self) -> bool:
|
||||
"""Check if this is episodic memory."""
|
||||
return self.memory_subtype == MemorySubtype.EPISODIC
|
||||
|
||||
def is_semantic_memory(self) -> bool:
|
||||
"""Check if this is semantic memory."""
|
||||
return self.memory_subtype == MemorySubtype.SEMANTIC
|
||||
|
||||
def is_procedural_memory(self) -> bool:
|
||||
"""Check if this is procedural memory."""
|
||||
return self.memory_subtype == MemorySubtype.PROCEDURAL
|
||||
|
||||
def get_formatted_source(self) -> str:
|
||||
"""
|
||||
Get a formatted source string for display.
|
||||
|
||||
Returns:
|
||||
Formatted source string
|
||||
"""
|
||||
parts = [f"[{self.memory_subtype.value}]", self.source]
|
||||
if self.memory_id:
|
||||
parts.append(f"({self.memory_id[:8]}...)")
|
||||
return " ".join(parts)
|
||||
138
backend/app/services/context/types/system.py
Normal file
138
backend/app/services/context/types/system.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
System Context Type.
|
||||
|
||||
Represents system prompts, instructions, and agent personas.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class SystemContext(BaseContext):
|
||||
"""
|
||||
Context for system prompts and instructions.
|
||||
|
||||
System context typically includes:
|
||||
- Agent persona and role definitions
|
||||
- Behavioral instructions
|
||||
- Safety guidelines
|
||||
- Output format requirements
|
||||
|
||||
System context is usually high priority and should
|
||||
rarely be truncated or omitted.
|
||||
"""
|
||||
|
||||
# System context specific fields
|
||||
role: str = field(default="assistant")
|
||||
instructions_type: str = field(default="general")
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set high priority for system context."""
|
||||
# System context defaults to high priority
|
||||
if self.priority == ContextPriority.NORMAL.value:
|
||||
self.priority = ContextPriority.HIGH.value
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return SYSTEM context type."""
|
||||
return ContextType.SYSTEM
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with system-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"role": self.role,
|
||||
"instructions_type": self.instructions_type,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SystemContext":
|
||||
"""Create SystemContext from dictionary."""
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data["source"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.HIGH.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
role=data.get("role", "assistant"),
|
||||
instructions_type=data.get("instructions_type", "general"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_persona(
|
||||
cls,
|
||||
name: str,
|
||||
description: str,
|
||||
capabilities: list[str] | None = None,
|
||||
constraints: list[str] | None = None,
|
||||
) -> "SystemContext":
|
||||
"""
|
||||
Create a persona system context.
|
||||
|
||||
Args:
|
||||
name: Agent name/role
|
||||
description: Role description
|
||||
capabilities: List of things the agent can do
|
||||
constraints: List of limitations
|
||||
|
||||
Returns:
|
||||
SystemContext with formatted persona
|
||||
"""
|
||||
parts = [f"You are {name}.", "", description]
|
||||
|
||||
if capabilities:
|
||||
parts.append("")
|
||||
parts.append("You can:")
|
||||
for cap in capabilities:
|
||||
parts.append(f"- {cap}")
|
||||
|
||||
if constraints:
|
||||
parts.append("")
|
||||
parts.append("You must not:")
|
||||
for constraint in constraints:
|
||||
parts.append(f"- {constraint}")
|
||||
|
||||
return cls(
|
||||
content="\n".join(parts),
|
||||
source="persona_builder",
|
||||
role=name.lower().replace(" ", "_"),
|
||||
instructions_type="persona",
|
||||
priority=ContextPriority.HIGHEST.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_instructions(
|
||||
cls,
|
||||
instructions: str | list[str],
|
||||
source: str = "instructions",
|
||||
) -> "SystemContext":
|
||||
"""
|
||||
Create an instructions system context.
|
||||
|
||||
Args:
|
||||
instructions: Instructions string or list of instruction strings
|
||||
source: Source identifier
|
||||
|
||||
Returns:
|
||||
SystemContext with instructions
|
||||
"""
|
||||
if isinstance(instructions, list):
|
||||
content = "\n".join(f"- {inst}" for inst in instructions)
|
||||
else:
|
||||
content = instructions
|
||||
|
||||
return cls(
|
||||
content=content,
|
||||
source=source,
|
||||
instructions_type="instructions",
|
||||
priority=ContextPriority.HIGH.value,
|
||||
)
|
||||
193
backend/app/services/context/types/task.py
Normal file
193
backend/app/services/context/types/task.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Task Context Type.
|
||||
|
||||
Represents the current task or objective for the agent.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""Status of a task."""
|
||||
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
BLOCKED = "blocked"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class TaskComplexity(str, Enum):
|
||||
"""Complexity level of a task."""
|
||||
|
||||
TRIVIAL = "trivial"
|
||||
SIMPLE = "simple"
|
||||
MODERATE = "moderate"
|
||||
COMPLEX = "complex"
|
||||
VERY_COMPLEX = "very_complex"
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TaskContext(BaseContext):
|
||||
"""
|
||||
Context for the current task or objective.
|
||||
|
||||
Task context provides information about what the agent
|
||||
should accomplish, including:
|
||||
- Task description and goals
|
||||
- Acceptance criteria
|
||||
- Constraints and requirements
|
||||
- Related issue/ticket information
|
||||
"""
|
||||
|
||||
# Task-specific fields
|
||||
title: str = field(default="")
|
||||
status: TaskStatus = field(default=TaskStatus.PENDING)
|
||||
complexity: TaskComplexity = field(default=TaskComplexity.MODERATE)
|
||||
issue_id: str | None = field(default=None)
|
||||
project_id: str | None = field(default=None)
|
||||
acceptance_criteria: list[str] = field(default_factory=list)
|
||||
constraints: list[str] = field(default_factory=list)
|
||||
parent_task_id: str | None = field(default=None)
|
||||
|
||||
# Note: TaskContext should typically have HIGH priority,
|
||||
# but we don't auto-promote to allow explicit priority setting.
|
||||
# Use TaskContext.create() for default HIGH priority behavior.
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return TASK context type."""
|
||||
return ContextType.TASK
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with task-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"title": self.title,
|
||||
"status": self.status.value,
|
||||
"complexity": self.complexity.value,
|
||||
"issue_id": self.issue_id,
|
||||
"project_id": self.project_id,
|
||||
"acceptance_criteria": self.acceptance_criteria,
|
||||
"constraints": self.constraints,
|
||||
"parent_task_id": self.parent_task_id,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "TaskContext":
|
||||
"""Create TaskContext from dictionary."""
|
||||
status = data.get("status", "pending")
|
||||
if isinstance(status, str):
|
||||
status = TaskStatus(status)
|
||||
|
||||
complexity = data.get("complexity", "moderate")
|
||||
if isinstance(complexity, str):
|
||||
complexity = TaskComplexity(complexity)
|
||||
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data.get("source", "task"),
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.HIGH.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
title=data.get("title", ""),
|
||||
status=status,
|
||||
complexity=complexity,
|
||||
issue_id=data.get("issue_id"),
|
||||
project_id=data.get("project_id"),
|
||||
acceptance_criteria=data.get("acceptance_criteria", []),
|
||||
constraints=data.get("constraints", []),
|
||||
parent_task_id=data.get("parent_task_id"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
title: str,
|
||||
description: str,
|
||||
acceptance_criteria: list[str] | None = None,
|
||||
constraints: list[str] | None = None,
|
||||
issue_id: str | None = None,
|
||||
project_id: str | None = None,
|
||||
complexity: TaskComplexity | str = TaskComplexity.MODERATE,
|
||||
) -> "TaskContext":
|
||||
"""
|
||||
Create a task context.
|
||||
|
||||
Args:
|
||||
title: Task title
|
||||
description: Task description
|
||||
acceptance_criteria: List of acceptance criteria
|
||||
constraints: List of constraints
|
||||
issue_id: Related issue ID
|
||||
project_id: Project ID
|
||||
complexity: Task complexity
|
||||
|
||||
Returns:
|
||||
TaskContext instance
|
||||
"""
|
||||
if isinstance(complexity, str):
|
||||
complexity = TaskComplexity(complexity)
|
||||
|
||||
return cls(
|
||||
content=description,
|
||||
source=f"task:{issue_id}" if issue_id else "task",
|
||||
title=title,
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
complexity=complexity,
|
||||
issue_id=issue_id,
|
||||
project_id=project_id,
|
||||
acceptance_criteria=acceptance_criteria or [],
|
||||
constraints=constraints or [],
|
||||
)
|
||||
|
||||
def format_for_prompt(self) -> str:
|
||||
"""
|
||||
Format task for inclusion in prompt.
|
||||
|
||||
Returns:
|
||||
Formatted task string
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if self.title:
|
||||
parts.append(f"Task: {self.title}")
|
||||
parts.append("")
|
||||
|
||||
parts.append(self.content)
|
||||
|
||||
if self.acceptance_criteria:
|
||||
parts.append("")
|
||||
parts.append("Acceptance Criteria:")
|
||||
for criterion in self.acceptance_criteria:
|
||||
parts.append(f"- {criterion}")
|
||||
|
||||
if self.constraints:
|
||||
parts.append("")
|
||||
parts.append("Constraints:")
|
||||
for constraint in self.constraints:
|
||||
parts.append(f"- {constraint}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Check if task is currently active."""
|
||||
return self.status in (TaskStatus.PENDING, TaskStatus.IN_PROGRESS)
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if task is complete."""
|
||||
return self.status == TaskStatus.COMPLETED
|
||||
|
||||
def is_blocked(self) -> bool:
|
||||
"""Check if task is blocked."""
|
||||
return self.status == TaskStatus.BLOCKED
|
||||
211
backend/app/services/context/types/tool.py
Normal file
211
backend/app/services/context/types/tool.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Tool Context Type.
|
||||
|
||||
Represents available tools and recent tool execution results.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class ToolResultStatus(str, Enum):
|
||||
"""Status of a tool execution result."""
|
||||
|
||||
SUCCESS = "success"
|
||||
ERROR = "error"
|
||||
TIMEOUT = "timeout"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ToolContext(BaseContext):
|
||||
"""
|
||||
Context for tools and tool execution results.
|
||||
|
||||
Tool context includes:
|
||||
- Tool descriptions and parameters
|
||||
- Recent tool execution results
|
||||
- Tool availability information
|
||||
|
||||
This helps the LLM understand what tools are available
|
||||
and what results previous tool calls produced.
|
||||
"""
|
||||
|
||||
# Tool-specific fields
|
||||
tool_name: str = field(default="")
|
||||
tool_description: str = field(default="")
|
||||
is_result: bool = field(default=False)
|
||||
result_status: ToolResultStatus | None = field(default=None)
|
||||
execution_time_ms: float | None = field(default=None)
|
||||
parameters: dict[str, Any] = field(default_factory=dict)
|
||||
server_name: str | None = field(default=None)
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return TOOL context type."""
|
||||
return ContextType.TOOL
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with tool-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"tool_name": self.tool_name,
|
||||
"tool_description": self.tool_description,
|
||||
"is_result": self.is_result,
|
||||
"result_status": self.result_status.value
|
||||
if self.result_status
|
||||
else None,
|
||||
"execution_time_ms": self.execution_time_ms,
|
||||
"parameters": self.parameters,
|
||||
"server_name": self.server_name,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ToolContext":
|
||||
"""Create ToolContext from dictionary."""
|
||||
result_status = data.get("result_status")
|
||||
if isinstance(result_status, str):
|
||||
result_status = ToolResultStatus(result_status)
|
||||
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data.get("source", "tool"),
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
tool_name=data.get("tool_name", ""),
|
||||
tool_description=data.get("tool_description", ""),
|
||||
is_result=data.get("is_result", False),
|
||||
result_status=result_status,
|
||||
execution_time_ms=data.get("execution_time_ms"),
|
||||
parameters=data.get("parameters", {}),
|
||||
server_name=data.get("server_name"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_tool_definition(
|
||||
cls,
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
server_name: str | None = None,
|
||||
) -> "ToolContext":
|
||||
"""
|
||||
Create a ToolContext from a tool definition.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
description: Tool description
|
||||
parameters: Tool parameter schema
|
||||
server_name: MCP server name
|
||||
|
||||
Returns:
|
||||
ToolContext instance
|
||||
"""
|
||||
# Format content as tool documentation
|
||||
content_parts = [f"Tool: {name}", "", description]
|
||||
|
||||
if parameters:
|
||||
content_parts.append("")
|
||||
content_parts.append("Parameters:")
|
||||
for param_name, param_info in parameters.items():
|
||||
param_type = param_info.get("type", "any")
|
||||
param_desc = param_info.get("description", "")
|
||||
required = param_info.get("required", False)
|
||||
req_marker = " (required)" if required else ""
|
||||
content_parts.append(f" - {param_name}: {param_type}{req_marker}")
|
||||
if param_desc:
|
||||
content_parts.append(f" {param_desc}")
|
||||
|
||||
return cls(
|
||||
content="\n".join(content_parts),
|
||||
source=f"tool:{server_name}:{name}" if server_name else f"tool:{name}",
|
||||
tool_name=name,
|
||||
tool_description=description,
|
||||
is_result=False,
|
||||
parameters=parameters or {},
|
||||
server_name=server_name,
|
||||
priority=ContextPriority.LOW.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_tool_result(
|
||||
cls,
|
||||
tool_name: str,
|
||||
result: Any,
|
||||
status: ToolResultStatus = ToolResultStatus.SUCCESS,
|
||||
execution_time_ms: float | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
server_name: str | None = None,
|
||||
) -> "ToolContext":
|
||||
"""
|
||||
Create a ToolContext from a tool execution result.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool that was executed
|
||||
result: Result content (will be converted to string)
|
||||
status: Execution status
|
||||
execution_time_ms: Execution time in milliseconds
|
||||
parameters: Parameters that were passed to the tool
|
||||
server_name: MCP server name
|
||||
|
||||
Returns:
|
||||
ToolContext instance
|
||||
"""
|
||||
# Convert result to string content
|
||||
if isinstance(result, str):
|
||||
content = result
|
||||
elif isinstance(result, dict):
|
||||
import json
|
||||
|
||||
try:
|
||||
content = json.dumps(result, indent=2)
|
||||
except (TypeError, ValueError):
|
||||
content = str(result)
|
||||
else:
|
||||
content = str(result)
|
||||
|
||||
return cls(
|
||||
content=content,
|
||||
source=f"tool_result:{server_name}:{tool_name}"
|
||||
if server_name
|
||||
else f"tool_result:{tool_name}",
|
||||
tool_name=tool_name,
|
||||
is_result=True,
|
||||
result_status=status,
|
||||
execution_time_ms=execution_time_ms,
|
||||
parameters=parameters or {},
|
||||
server_name=server_name,
|
||||
priority=ContextPriority.HIGH.value, # Recent results are high priority
|
||||
)
|
||||
|
||||
def is_successful(self) -> bool:
|
||||
"""Check if this is a successful tool result."""
|
||||
return self.is_result and self.result_status == ToolResultStatus.SUCCESS
|
||||
|
||||
def is_error(self) -> bool:
|
||||
"""Check if this is an error result."""
|
||||
return self.is_result and self.result_status == ToolResultStatus.ERROR
|
||||
|
||||
def format_for_prompt(self) -> str:
|
||||
"""
|
||||
Format tool context for inclusion in prompt.
|
||||
|
||||
Returns:
|
||||
Formatted tool string
|
||||
"""
|
||||
if self.is_result:
|
||||
status_str = self.result_status.value if self.result_status else "unknown"
|
||||
header = f"Tool Result ({self.tool_name}, {status_str}):"
|
||||
return f"{header}\n{self.content}"
|
||||
else:
|
||||
return self.content
|
||||
85
backend/app/services/mcp/__init__.py
Normal file
85
backend/app/services/mcp/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
MCP Client Service Package
|
||||
|
||||
Provides infrastructure for communicating with MCP (Model Context Protocol)
|
||||
servers. This is the foundation for AI agent tool integration.
|
||||
|
||||
Usage:
|
||||
from app.services.mcp import get_mcp_client, MCPClientManager
|
||||
|
||||
# In FastAPI route
|
||||
async def my_route(mcp: MCPClientManager = Depends(get_mcp_client)):
|
||||
result = await mcp.call_tool("llm-gateway", "chat", {"prompt": "Hello"})
|
||||
|
||||
# Direct usage
|
||||
manager = MCPClientManager()
|
||||
await manager.initialize()
|
||||
result = await manager.call_tool("issues", "create_issue", {...})
|
||||
await manager.shutdown()
|
||||
"""
|
||||
|
||||
from .client_manager import (
|
||||
MCPClientManager,
|
||||
ServerHealth,
|
||||
get_mcp_client,
|
||||
reset_mcp_client,
|
||||
shutdown_mcp_client,
|
||||
)
|
||||
from .config import (
|
||||
MCPConfig,
|
||||
MCPServerConfig,
|
||||
TransportType,
|
||||
create_default_config,
|
||||
load_mcp_config,
|
||||
)
|
||||
from .connection import ConnectionPool, ConnectionState, MCPConnection
|
||||
from .exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPConnectionError,
|
||||
MCPError,
|
||||
MCPServerNotFoundError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
MCPValidationError,
|
||||
)
|
||||
from .registry import MCPServerRegistry, ServerCapabilities, get_registry
|
||||
from .routing import AsyncCircuitBreaker, CircuitState, ToolInfo, ToolResult, ToolRouter
|
||||
|
||||
__all__ = [
|
||||
# Main facade
|
||||
"MCPClientManager",
|
||||
"get_mcp_client",
|
||||
"shutdown_mcp_client",
|
||||
"reset_mcp_client",
|
||||
"ServerHealth",
|
||||
# Configuration
|
||||
"MCPConfig",
|
||||
"MCPServerConfig",
|
||||
"TransportType",
|
||||
"load_mcp_config",
|
||||
"create_default_config",
|
||||
# Registry
|
||||
"MCPServerRegistry",
|
||||
"ServerCapabilities",
|
||||
"get_registry",
|
||||
# Connection
|
||||
"ConnectionPool",
|
||||
"ConnectionState",
|
||||
"MCPConnection",
|
||||
# Routing
|
||||
"ToolRouter",
|
||||
"ToolInfo",
|
||||
"ToolResult",
|
||||
"AsyncCircuitBreaker",
|
||||
"CircuitState",
|
||||
# Exceptions
|
||||
"MCPError",
|
||||
"MCPConnectionError",
|
||||
"MCPTimeoutError",
|
||||
"MCPToolError",
|
||||
"MCPServerNotFoundError",
|
||||
"MCPToolNotFoundError",
|
||||
"MCPCircuitOpenError",
|
||||
"MCPValidationError",
|
||||
]
|
||||
438
backend/app/services/mcp/client_manager.py
Normal file
438
backend/app/services/mcp/client_manager.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
MCP Client Manager
|
||||
|
||||
Main facade for all MCP operations. Manages server connections,
|
||||
tool discovery, and provides a unified interface for tool calls.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
||||
from .connection import ConnectionPool, ConnectionState
|
||||
from .exceptions import MCPServerNotFoundError
|
||||
from .registry import MCPServerRegistry, get_registry
|
||||
from .routing import ToolInfo, ToolResult, ToolRouter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerHealth:
|
||||
"""Health status for an MCP server."""
|
||||
|
||||
name: str
|
||||
healthy: bool
|
||||
state: str
|
||||
url: str
|
||||
error: str | None = None
|
||||
tools_count: int = 0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"healthy": self.healthy,
|
||||
"state": self.state,
|
||||
"url": self.url,
|
||||
"error": self.error,
|
||||
"tools_count": self.tools_count,
|
||||
}
|
||||
|
||||
|
||||
class MCPClientManager:
|
||||
"""
|
||||
Central manager for all MCP client operations.
|
||||
|
||||
Provides a unified interface for:
|
||||
- Connecting to MCP servers
|
||||
- Discovering and calling tools
|
||||
- Health monitoring
|
||||
- Connection lifecycle management
|
||||
|
||||
This is the main entry point for MCP operations in the application.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MCPConfig | None = None,
|
||||
registry: MCPServerRegistry | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the MCP client manager.
|
||||
|
||||
Args:
|
||||
config: Optional MCP configuration. If None, loads from default.
|
||||
registry: Optional registry instance. If None, uses singleton.
|
||||
"""
|
||||
self._registry = registry or get_registry()
|
||||
self._pool = ConnectionPool()
|
||||
self._router: ToolRouter | None = None
|
||||
self._initialized = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Load configuration if provided
|
||||
if config is not None:
|
||||
self._registry.load_config(config)
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the manager is initialized."""
|
||||
return self._initialized
|
||||
|
||||
async def initialize(self, config: MCPConfig | None = None) -> None:
|
||||
"""
|
||||
Initialize the MCP client manager.
|
||||
|
||||
Loads configuration, creates connections, and discovers tools.
|
||||
|
||||
Args:
|
||||
config: Optional configuration to load
|
||||
"""
|
||||
async with self._lock:
|
||||
if self._initialized:
|
||||
logger.warning("MCPClientManager already initialized")
|
||||
return
|
||||
|
||||
logger.info("Initializing MCP Client Manager")
|
||||
|
||||
# Load configuration
|
||||
if config is not None:
|
||||
self._registry.load_config(config)
|
||||
elif len(self._registry.list_servers()) == 0:
|
||||
# Try to load from default location
|
||||
self._registry.load_config(load_mcp_config())
|
||||
|
||||
# Create router
|
||||
self._router = ToolRouter(self._registry, self._pool)
|
||||
|
||||
# Connect to all enabled servers
|
||||
await self._connect_all_servers()
|
||||
|
||||
# Discover tools from all servers
|
||||
if self._router:
|
||||
await self._router.discover_tools()
|
||||
|
||||
self._initialized = True
|
||||
logger.info(
|
||||
"MCP Client Manager initialized with %d servers",
|
||||
len(self._registry.list_enabled_servers()),
|
||||
)
|
||||
|
||||
async def _connect_all_servers(self) -> None:
|
||||
"""Connect to all enabled MCP servers concurrently."""
|
||||
import asyncio
|
||||
|
||||
enabled_servers = self._registry.get_enabled_configs()
|
||||
|
||||
async def connect_server(name: str, config: "MCPServerConfig") -> None:
|
||||
try:
|
||||
await self._pool.get_connection(name, config)
|
||||
logger.info("Connected to MCP server: %s", name)
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to MCP server %s: %s", name, e)
|
||||
|
||||
# Connect to all servers concurrently for faster startup
|
||||
await asyncio.gather(
|
||||
*(connect_server(name, config) for name, config in enabled_servers.items()),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Shutdown the MCP client manager.
|
||||
|
||||
Closes all connections and cleans up resources.
|
||||
"""
|
||||
async with self._lock:
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
logger.info("Shutting down MCP Client Manager")
|
||||
|
||||
await self._pool.close_all()
|
||||
self._initialized = False
|
||||
|
||||
logger.info("MCP Client Manager shutdown complete")
|
||||
|
||||
async def connect(self, server_name: str) -> None:
|
||||
"""
|
||||
Connect to a specific MCP server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server to connect to
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
config = self._registry.get(server_name)
|
||||
await self._pool.get_connection(server_name, config)
|
||||
logger.info("Connected to MCP server: %s", server_name)
|
||||
|
||||
async def disconnect(self, server_name: str) -> None:
|
||||
"""
|
||||
Disconnect from a specific MCP server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server to disconnect from
|
||||
"""
|
||||
await self._pool.close_connection(server_name)
|
||||
logger.info("Disconnected from MCP server: %s", server_name)
|
||||
|
||||
async def disconnect_all(self) -> None:
|
||||
"""Disconnect from all MCP servers."""
|
||||
await self._pool.close_all()
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
server: str,
|
||||
tool: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Call a tool on a specific MCP server.
|
||||
|
||||
Args:
|
||||
server: Name of the MCP server
|
||||
tool: Name of the tool to call
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
if not self._initialized or self._router is None:
|
||||
await self.initialize()
|
||||
|
||||
assert self._router is not None # Guaranteed after initialize()
|
||||
return await self._router.call_tool(
|
||||
server_name=server,
|
||||
tool_name=tool,
|
||||
arguments=args,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def route_tool(
|
||||
self,
|
||||
tool: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Route a tool call to the appropriate server automatically.
|
||||
|
||||
Args:
|
||||
tool: Name of the tool to call
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
if not self._initialized or self._router is None:
|
||||
await self.initialize()
|
||||
|
||||
assert self._router is not None # Guaranteed after initialize()
|
||||
return await self._router.route_tool(
|
||||
tool_name=tool,
|
||||
arguments=args,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def list_tools(self, server: str) -> list[ToolInfo]:
|
||||
"""
|
||||
List all tools available on a specific server.
|
||||
|
||||
Args:
|
||||
server: Name of the MCP server
|
||||
|
||||
Returns:
|
||||
List of tool information
|
||||
"""
|
||||
capabilities = await self._registry.get_capabilities(server)
|
||||
return [
|
||||
ToolInfo(
|
||||
name=t.get("name", ""),
|
||||
description=t.get("description"),
|
||||
server_name=server,
|
||||
input_schema=t.get("input_schema"),
|
||||
)
|
||||
for t in capabilities.tools
|
||||
]
|
||||
|
||||
async def list_all_tools(self) -> list[ToolInfo]:
|
||||
"""
|
||||
List all tools from all servers.
|
||||
|
||||
Returns:
|
||||
List of tool information
|
||||
"""
|
||||
if not self._initialized or self._router is None:
|
||||
await self.initialize()
|
||||
|
||||
assert self._router is not None # Guaranteed after initialize()
|
||||
return await self._router.list_all_tools()
|
||||
|
||||
async def health_check(self) -> dict[str, ServerHealth]:
|
||||
"""
|
||||
Perform health check on all MCP servers.
|
||||
|
||||
Returns:
|
||||
Dict mapping server names to health status
|
||||
"""
|
||||
results: dict[str, ServerHealth] = {}
|
||||
pool_status = self._pool.get_status()
|
||||
pool_health = await self._pool.health_check_all()
|
||||
|
||||
for server_name in self._registry.list_servers():
|
||||
try:
|
||||
config = self._registry.get(server_name)
|
||||
status = pool_status.get(server_name, {})
|
||||
healthy = pool_health.get(server_name, False)
|
||||
|
||||
capabilities = self._registry.get_cached_capabilities(server_name)
|
||||
|
||||
results[server_name] = ServerHealth(
|
||||
name=server_name,
|
||||
healthy=healthy,
|
||||
state=status.get("state", ConnectionState.DISCONNECTED.value),
|
||||
url=config.url,
|
||||
tools_count=len(capabilities.tools),
|
||||
)
|
||||
except MCPServerNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
results[server_name] = ServerHealth(
|
||||
name=server_name,
|
||||
healthy=False,
|
||||
state=ConnectionState.ERROR.value,
|
||||
url="unknown",
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def list_servers(self) -> list[str]:
|
||||
"""Get list of all registered server names."""
|
||||
return self._registry.list_servers()
|
||||
|
||||
def list_enabled_servers(self) -> list[str]:
|
||||
"""Get list of enabled server names."""
|
||||
return self._registry.list_enabled_servers()
|
||||
|
||||
def get_server_config(self, server_name: str) -> MCPServerConfig:
|
||||
"""
|
||||
Get configuration for a specific server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
|
||||
Returns:
|
||||
Server configuration
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
return self._registry.get(server_name)
|
||||
|
||||
def register_server(
|
||||
self,
|
||||
name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Register a new MCP server at runtime.
|
||||
|
||||
Args:
|
||||
name: Unique server name
|
||||
config: Server configuration
|
||||
"""
|
||||
self._registry.register(name, config)
|
||||
|
||||
def unregister_server(self, name: str) -> bool:
|
||||
"""
|
||||
Unregister an MCP server.
|
||||
|
||||
Args:
|
||||
name: Server name to unregister
|
||||
|
||||
Returns:
|
||||
True if server was found and removed
|
||||
"""
|
||||
return self._registry.unregister(name)
|
||||
|
||||
def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get status of all circuit breakers."""
|
||||
if self._router is None:
|
||||
return {}
|
||||
return self._router.get_circuit_breaker_status()
|
||||
|
||||
async def reset_circuit_breaker(self, server_name: str) -> bool:
|
||||
"""
|
||||
Reset a circuit breaker for a server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
|
||||
Returns:
|
||||
True if circuit breaker was reset
|
||||
"""
|
||||
if self._router is None:
|
||||
return False
|
||||
return await self._router.reset_circuit_breaker(server_name)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_manager_instance: MCPClientManager | None = None
|
||||
_manager_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_mcp_client() -> MCPClientManager:
|
||||
"""
|
||||
Get the global MCP client manager instance.
|
||||
|
||||
This is the main dependency injection point for FastAPI.
|
||||
Uses proper locking to avoid race conditions in async contexts.
|
||||
"""
|
||||
global _manager_instance
|
||||
|
||||
# Use lock for the entire check-and-create operation to avoid race conditions
|
||||
async with _manager_lock:
|
||||
if _manager_instance is None:
|
||||
_manager_instance = MCPClientManager()
|
||||
await _manager_instance.initialize()
|
||||
|
||||
return _manager_instance
|
||||
|
||||
|
||||
async def shutdown_mcp_client() -> None:
|
||||
"""Shutdown the global MCP client manager."""
|
||||
global _manager_instance
|
||||
|
||||
# Use lock to prevent race with get_mcp_client()
|
||||
async with _manager_lock:
|
||||
if _manager_instance is not None:
|
||||
await _manager_instance.shutdown()
|
||||
_manager_instance = None
|
||||
|
||||
|
||||
async def reset_mcp_client() -> None:
|
||||
"""
|
||||
Reset the global MCP client manager (for testing).
|
||||
|
||||
This is an async function to properly acquire the manager lock
|
||||
and avoid race conditions with get_mcp_client().
|
||||
"""
|
||||
global _manager_instance
|
||||
|
||||
async with _manager_lock:
|
||||
if _manager_instance is not None:
|
||||
# Shutdown gracefully before resetting
|
||||
try:
|
||||
await _manager_instance.shutdown()
|
||||
except Exception: # noqa: S110
|
||||
pass # Ignore errors during test cleanup
|
||||
_manager_instance = None
|
||||
245
backend/app/services/mcp/config.py
Normal file
245
backend/app/services/mcp/config.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
MCP Configuration System
|
||||
|
||||
Pydantic models for MCP server configuration with YAML file loading
|
||||
and environment variable overrides.
|
||||
"""
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class TransportType(str, Enum):
|
||||
"""Supported MCP transport types."""
|
||||
|
||||
HTTP = "http"
|
||||
STDIO = "stdio"
|
||||
SSE = "sse"
|
||||
|
||||
|
||||
class MCPServerConfig(BaseModel):
|
||||
"""Configuration for a single MCP server."""
|
||||
|
||||
url: str = Field(..., description="Server URL (supports ${ENV_VAR} syntax)")
|
||||
transport: TransportType = Field(
|
||||
default=TransportType.HTTP,
|
||||
description="Transport protocol to use",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
le=600,
|
||||
description="Request timeout in seconds",
|
||||
)
|
||||
retry_attempts: int = Field(
|
||||
default=3,
|
||||
ge=0,
|
||||
le=10,
|
||||
description="Number of retry attempts on failure",
|
||||
)
|
||||
retry_delay: float = Field(
|
||||
default=1.0,
|
||||
ge=0.1,
|
||||
le=60.0,
|
||||
description="Initial delay between retries in seconds",
|
||||
)
|
||||
retry_max_delay: float = Field(
|
||||
default=30.0,
|
||||
ge=1.0,
|
||||
le=300.0,
|
||||
description="Maximum delay between retries in seconds",
|
||||
)
|
||||
circuit_breaker_threshold: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=50,
|
||||
description="Number of failures before opening circuit",
|
||||
)
|
||||
circuit_breaker_timeout: float = Field(
|
||||
default=30.0,
|
||||
ge=5.0,
|
||||
le=300.0,
|
||||
description="Seconds to wait before attempting to close circuit",
|
||||
)
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether this server is enabled",
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="Human-readable description of the server",
|
||||
)
|
||||
|
||||
@field_validator("url", mode="before")
|
||||
@classmethod
|
||||
def expand_env_vars(cls, v: str) -> str:
|
||||
"""Expand environment variables in URL using ${VAR:-default} syntax."""
|
||||
if not isinstance(v, str):
|
||||
return v
|
||||
|
||||
result = v
|
||||
# Find all ${VAR} or ${VAR:-default} patterns
|
||||
import re
|
||||
|
||||
pattern = r"\$\{([^}]+)\}"
|
||||
matches = re.findall(pattern, v)
|
||||
|
||||
for match in matches:
|
||||
if ":-" in match:
|
||||
var_name, default = match.split(":-", 1)
|
||||
else:
|
||||
var_name, default = match, ""
|
||||
|
||||
env_value = os.environ.get(var_name.strip(), default)
|
||||
result = result.replace(f"${{{match}}}", env_value)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class MCPConfig(BaseModel):
|
||||
"""Root configuration for all MCP servers."""
|
||||
|
||||
mcp_servers: dict[str, MCPServerConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of server names to their configurations",
|
||||
)
|
||||
|
||||
# Global defaults
|
||||
default_timeout: int = Field(
|
||||
default=30,
|
||||
description="Default timeout for all servers",
|
||||
)
|
||||
default_retry_attempts: int = Field(
|
||||
default=3,
|
||||
description="Default retry attempts for all servers",
|
||||
)
|
||||
connection_pool_size: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Maximum connections per server",
|
||||
)
|
||||
health_check_interval: int = Field(
|
||||
default=30,
|
||||
ge=5,
|
||||
le=300,
|
||||
description="Seconds between health checks",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> "MCPConfig":
|
||||
"""Load configuration from a YAML file."""
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"MCP config file not found: {path}")
|
||||
|
||||
with path.open("r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if data is None:
|
||||
data = {}
|
||||
|
||||
return cls.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "MCPConfig":
|
||||
"""Load configuration from a dictionary."""
|
||||
return cls.model_validate(data)
|
||||
|
||||
def get_server(self, name: str) -> MCPServerConfig | None:
|
||||
"""Get a server configuration by name."""
|
||||
return self.mcp_servers.get(name)
|
||||
|
||||
def get_enabled_servers(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all enabled server configurations."""
|
||||
return {
|
||||
name: config for name, config in self.mcp_servers.items() if config.enabled
|
||||
}
|
||||
|
||||
def list_server_names(self) -> list[str]:
|
||||
"""Get list of all configured server names."""
|
||||
return list(self.mcp_servers.keys())
|
||||
|
||||
|
||||
# Default configuration path
|
||||
DEFAULT_CONFIG_PATH = Path(__file__).parent.parent.parent.parent / "mcp_servers.yaml"
|
||||
|
||||
|
||||
def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
|
||||
"""
|
||||
Load MCP configuration from file or environment.
|
||||
|
||||
Priority:
|
||||
1. Explicit path parameter
|
||||
2. MCP_CONFIG_PATH environment variable
|
||||
3. Default path (backend/mcp_servers.yaml)
|
||||
4. Empty config if no file exists
|
||||
|
||||
In test mode (IS_TEST=True), retry settings are reduced for faster tests.
|
||||
"""
|
||||
if path is None:
|
||||
path = os.environ.get("MCP_CONFIG_PATH", str(DEFAULT_CONFIG_PATH))
|
||||
|
||||
path = Path(path)
|
||||
|
||||
if not path.exists():
|
||||
# Return empty config if no file exists (allows runtime registration)
|
||||
return MCPConfig()
|
||||
|
||||
config = MCPConfig.from_yaml(path)
|
||||
|
||||
# In test mode, reduce retry settings to speed up tests
|
||||
is_test = os.environ.get("IS_TEST", "").lower() in ("true", "1", "yes")
|
||||
if is_test:
|
||||
for server_config in config.mcp_servers.values():
|
||||
server_config.retry_attempts = 1 # Single attempt
|
||||
server_config.retry_delay = 0.1 # 100ms instead of 1s
|
||||
server_config.retry_max_delay = 0.5 # 500ms max
|
||||
server_config.timeout = 2 # 2s timeout instead of 30-120s
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_default_config() -> MCPConfig:
|
||||
"""
|
||||
Create a default MCP configuration with standard servers.
|
||||
|
||||
This is useful for development and as a template.
|
||||
"""
|
||||
return MCPConfig(
|
||||
mcp_servers={
|
||||
"llm-gateway": MCPServerConfig(
|
||||
url="${LLM_GATEWAY_URL:-http://localhost:8001}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=60,
|
||||
description="LLM Gateway for multi-provider AI interactions",
|
||||
),
|
||||
"knowledge-base": MCPServerConfig(
|
||||
url="${KNOWLEDGE_BASE_URL:-http://localhost:8002}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=30,
|
||||
description="Knowledge Base for RAG and document retrieval",
|
||||
),
|
||||
"git-ops": MCPServerConfig(
|
||||
url="${GIT_OPS_URL:-http://localhost:8003}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=120,
|
||||
description="Git Operations for repository management",
|
||||
),
|
||||
"issues": MCPServerConfig(
|
||||
url="${ISSUES_URL:-http://localhost:8004}",
|
||||
transport=TransportType.HTTP,
|
||||
timeout=30,
|
||||
description="Issue Tracker for Gitea/GitHub/GitLab",
|
||||
),
|
||||
},
|
||||
default_timeout=30,
|
||||
default_retry_attempts=3,
|
||||
connection_pool_size=10,
|
||||
health_check_interval=30,
|
||||
)
|
||||
473
backend/app/services/mcp/connection.py
Normal file
473
backend/app/services/mcp/connection.py
Normal file
@@ -0,0 +1,473 @@
|
||||
"""
|
||||
MCP Connection Management
|
||||
|
||||
Handles connection lifecycle, pooling, and automatic reconnection
|
||||
for MCP servers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import MCPServerConfig, TransportType
|
||||
from .exceptions import MCPConnectionError, MCPTimeoutError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionState(str, Enum):
|
||||
"""Connection state enumeration."""
|
||||
|
||||
DISCONNECTED = "disconnected"
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
RECONNECTING = "reconnecting"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class MCPConnection:
|
||||
"""
|
||||
Manages a single connection to an MCP server.
|
||||
|
||||
Handles connection lifecycle, health checking, and automatic reconnection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the MCP server
|
||||
config: Server configuration
|
||||
"""
|
||||
self.server_name = server_name
|
||||
self.config = config
|
||||
self._state = ConnectionState.DISCONNECTED
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._last_activity: float | None = None
|
||||
self._connection_attempts = 0
|
||||
self._last_error: Exception | None = None
|
||||
|
||||
# Reconnection settings
|
||||
self._base_delay = config.retry_delay
|
||||
self._max_delay = config.retry_max_delay
|
||||
self._max_attempts = config.retry_attempts
|
||||
|
||||
@property
|
||||
def state(self) -> ConnectionState:
|
||||
"""Get current connection state."""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connection is established."""
|
||||
return self._state == ConnectionState.CONNECTED
|
||||
|
||||
@property
|
||||
def last_error(self) -> Exception | None:
|
||||
"""Get the last error that occurred."""
|
||||
return self._last_error
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
Establish connection to the MCP server.
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: If connection fails after all retries
|
||||
"""
|
||||
async with self._lock:
|
||||
if self._state == ConnectionState.CONNECTED:
|
||||
return
|
||||
|
||||
self._state = ConnectionState.CONNECTING
|
||||
self._connection_attempts = 0
|
||||
self._last_error = None
|
||||
|
||||
while self._connection_attempts < self._max_attempts:
|
||||
try:
|
||||
await self._do_connect()
|
||||
self._state = ConnectionState.CONNECTED
|
||||
self._last_activity = time.time()
|
||||
logger.info(
|
||||
"Connected to MCP server: %s at %s",
|
||||
self.server_name,
|
||||
self.config.url,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
self._connection_attempts += 1
|
||||
self._last_error = e
|
||||
logger.warning(
|
||||
"Connection attempt %d/%d failed for %s: %s",
|
||||
self._connection_attempts,
|
||||
self._max_attempts,
|
||||
self.server_name,
|
||||
e,
|
||||
)
|
||||
|
||||
if self._connection_attempts < self._max_attempts:
|
||||
delay = self._calculate_backoff_delay()
|
||||
logger.debug(
|
||||
"Retrying connection to %s in %.1fs",
|
||||
self.server_name,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All attempts failed
|
||||
self._state = ConnectionState.ERROR
|
||||
raise MCPConnectionError(
|
||||
f"Failed to connect after {self._max_attempts} attempts",
|
||||
server_name=self.server_name,
|
||||
url=self.config.url,
|
||||
cause=self._last_error,
|
||||
)
|
||||
|
||||
async def _do_connect(self) -> None:
|
||||
"""Perform the actual connection (transport-specific)."""
|
||||
if self.config.transport == TransportType.HTTP:
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self.config.url,
|
||||
timeout=httpx.Timeout(self.config.timeout),
|
||||
headers={
|
||||
"User-Agent": "Syndarix-MCP-Client/1.0",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
# Verify connectivity with a simple request
|
||||
try:
|
||||
# Try to hit the MCP capabilities endpoint
|
||||
response = await self._client.get("/mcp/capabilities")
|
||||
if response.status_code not in (200, 404):
|
||||
# 404 is acceptable - server might not have capabilities endpoint
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code != 404:
|
||||
raise
|
||||
except httpx.ConnectError as e:
|
||||
raise MCPConnectionError(
|
||||
"Failed to connect to server",
|
||||
server_name=self.server_name,
|
||||
url=self.config.url,
|
||||
cause=e,
|
||||
) from e
|
||||
else:
|
||||
# For STDIO and SSE transports, we'll implement later
|
||||
raise NotImplementedError(
|
||||
f"Transport {self.config.transport} not yet implemented"
|
||||
)
|
||||
|
||||
def _calculate_backoff_delay(self) -> float:
|
||||
"""Calculate exponential backoff delay with jitter."""
|
||||
import random
|
||||
|
||||
delay = self._base_delay * (2 ** (self._connection_attempts - 1))
|
||||
delay = min(delay, self._max_delay)
|
||||
# Add jitter (±25%)
|
||||
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||
return delay + jitter
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the MCP server."""
|
||||
async with self._lock:
|
||||
if self._client is not None:
|
||||
try:
|
||||
await self._client.aclose()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error closing connection to %s: %s",
|
||||
self.server_name,
|
||||
e,
|
||||
)
|
||||
finally:
|
||||
self._client = None
|
||||
|
||||
self._state = ConnectionState.DISCONNECTED
|
||||
logger.info("Disconnected from MCP server: %s", self.server_name)
|
||||
|
||||
async def reconnect(self) -> None:
|
||||
"""Reconnect to the MCP server."""
|
||||
async with self._lock:
|
||||
self._state = ConnectionState.RECONNECTING
|
||||
await self.disconnect()
|
||||
await self.connect()
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
Perform a health check on the connection.
|
||||
|
||||
Returns:
|
||||
True if connection is healthy
|
||||
"""
|
||||
if not self.is_connected or self._client is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
if self.config.transport == TransportType.HTTP:
|
||||
response = await self._client.get(
|
||||
"/health",
|
||||
timeout=5.0,
|
||||
)
|
||||
return response.status_code == 200
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Health check failed for %s: %s",
|
||||
self.server_name,
|
||||
e,
|
||||
)
|
||||
return False
|
||||
|
||||
async def execute_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute an HTTP request to the MCP server.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: Request path
|
||||
data: Optional request body
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Response data
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: If not connected
|
||||
MCPTimeoutError: If request times out
|
||||
"""
|
||||
if not self.is_connected or self._client is None:
|
||||
raise MCPConnectionError(
|
||||
"Not connected to server",
|
||||
server_name=self.server_name,
|
||||
)
|
||||
|
||||
effective_timeout = timeout or self.config.timeout
|
||||
|
||||
try:
|
||||
if method.upper() == "GET":
|
||||
response = await self._client.get(
|
||||
path,
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = await self._client.post(
|
||||
path,
|
||||
json=data,
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
else:
|
||||
response = await self._client.request(
|
||||
method.upper(),
|
||||
path,
|
||||
json=data,
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
|
||||
self._last_activity = time.time()
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
raise MCPTimeoutError(
|
||||
"Request timed out",
|
||||
server_name=self.server_name,
|
||||
timeout_seconds=effective_timeout,
|
||||
operation=f"{method} {path}",
|
||||
) from e
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise MCPConnectionError(
|
||||
f"HTTP error: {e.response.status_code}",
|
||||
server_name=self.server_name,
|
||||
url=f"{self.config.url}{path}",
|
||||
cause=e,
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise MCPConnectionError(
|
||||
f"Request failed: {e}",
|
||||
server_name=self.server_name,
|
||||
cause=e,
|
||||
) from e
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
"""
|
||||
Pool of connections to MCP servers.
|
||||
|
||||
Manages connection lifecycle and provides connection reuse.
|
||||
"""
|
||||
|
||||
def __init__(self, max_connections_per_server: int = 10) -> None:
|
||||
"""
|
||||
Initialize connection pool.
|
||||
|
||||
Args:
|
||||
max_connections_per_server: Maximum connections per server
|
||||
"""
|
||||
self._connections: dict[str, MCPConnection] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._per_server_locks: dict[str, asyncio.Lock] = {}
|
||||
self._max_per_server = max_connections_per_server
|
||||
|
||||
def _get_server_lock(self, server_name: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific server.
|
||||
|
||||
Uses setdefault for atomic dict access to prevent race conditions
|
||||
where two coroutines could create different locks for the same server.
|
||||
"""
|
||||
# setdefault is atomic - if key exists, returns existing value
|
||||
# if key doesn't exist, inserts new value and returns it
|
||||
return self._per_server_locks.setdefault(server_name, asyncio.Lock())
|
||||
|
||||
async def get_connection(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> MCPConnection:
|
||||
"""
|
||||
Get or create a connection to a server.
|
||||
|
||||
Uses per-server locking to avoid blocking all connections
|
||||
when establishing a new connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
config: Server configuration
|
||||
|
||||
Returns:
|
||||
Active connection
|
||||
"""
|
||||
# Quick check without lock - if connection exists and is connected, return it
|
||||
if server_name in self._connections:
|
||||
connection = self._connections[server_name]
|
||||
if connection.is_connected:
|
||||
return connection
|
||||
|
||||
# Need to create or reconnect - use per-server lock to avoid blocking others
|
||||
async with self._lock:
|
||||
server_lock = self._get_server_lock(server_name)
|
||||
|
||||
async with server_lock:
|
||||
# Double-check after acquiring per-server lock
|
||||
if server_name in self._connections:
|
||||
connection = self._connections[server_name]
|
||||
if connection.is_connected:
|
||||
return connection
|
||||
# Connection exists but not connected - reconnect
|
||||
await connection.connect()
|
||||
return connection
|
||||
|
||||
# Create new connection (outside global lock, under per-server lock)
|
||||
connection = MCPConnection(server_name, config)
|
||||
await connection.connect()
|
||||
|
||||
# Store connection under global lock
|
||||
async with self._lock:
|
||||
self._connections[server_name] = connection
|
||||
|
||||
return connection
|
||||
|
||||
async def release_connection(self, server_name: str) -> None:
|
||||
"""
|
||||
Release a connection (currently just tracks usage).
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
"""
|
||||
# For now, we keep connections alive
|
||||
# Future: implement connection reaping for idle connections
|
||||
|
||||
async def close_connection(self, server_name: str) -> None:
|
||||
"""
|
||||
Close and remove a connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name in self._connections:
|
||||
await self._connections[server_name].disconnect()
|
||||
del self._connections[server_name]
|
||||
# Clean up per-server lock
|
||||
if server_name in self._per_server_locks:
|
||||
del self._per_server_locks[server_name]
|
||||
|
||||
async def close_all(self) -> None:
|
||||
"""Close all connections in the pool."""
|
||||
async with self._lock:
|
||||
for connection in self._connections.values():
|
||||
try:
|
||||
await connection.disconnect()
|
||||
except Exception as e:
|
||||
logger.warning("Error closing connection: %s", e)
|
||||
|
||||
self._connections.clear()
|
||||
self._per_server_locks.clear()
|
||||
logger.info("Closed all MCP connections")
|
||||
|
||||
async def health_check_all(self) -> dict[str, bool]:
|
||||
"""
|
||||
Perform health check on all connections.
|
||||
|
||||
Returns:
|
||||
Dict mapping server names to health status
|
||||
"""
|
||||
# Copy connections under lock to prevent modification during iteration
|
||||
async with self._lock:
|
||||
connections_snapshot = dict(self._connections)
|
||||
|
||||
results = {}
|
||||
for name, connection in connections_snapshot.items():
|
||||
results[name] = await connection.health_check()
|
||||
return results
|
||||
|
||||
def get_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
Get status of all connections.
|
||||
|
||||
Returns:
|
||||
Dict mapping server names to status info
|
||||
"""
|
||||
return {
|
||||
name: {
|
||||
"state": conn.state.value,
|
||||
"is_connected": conn.is_connected,
|
||||
"url": conn.config.url,
|
||||
}
|
||||
for name, conn in self._connections.items()
|
||||
}
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> AsyncGenerator[MCPConnection, None]:
|
||||
"""
|
||||
Context manager for getting a connection.
|
||||
|
||||
Usage:
|
||||
async with pool.connection("server", config) as conn:
|
||||
result = await conn.execute_request("POST", "/tool", data)
|
||||
"""
|
||||
conn = await self.get_connection(server_name, config)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
await self.release_connection(server_name)
|
||||
201
backend/app/services/mcp/exceptions.py
Normal file
201
backend/app/services/mcp/exceptions.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
MCP Exception Classes
|
||||
|
||||
Custom exceptions for MCP client operations with detailed error context.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class MCPError(Exception):
|
||||
"""Base exception for all MCP-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.server_name = server_name
|
||||
self.details = details or {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
parts = [self.message]
|
||||
if self.server_name:
|
||||
parts.append(f"server={self.server_name}")
|
||||
if self.details:
|
||||
parts.append(f"details={self.details}")
|
||||
return " | ".join(parts)
|
||||
|
||||
|
||||
class MCPConnectionError(MCPError):
|
||||
"""Raised when connection to an MCP server fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
url: str | None = None,
|
||||
cause: Exception | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.url = url
|
||||
self.cause = cause
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.url:
|
||||
base = f"{base} | url={self.url}"
|
||||
if self.cause:
|
||||
base = f"{base} | cause={type(self.cause).__name__}: {self.cause}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPTimeoutError(MCPError):
|
||||
"""Raised when an MCP operation times out."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
operation: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.operation = operation
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.timeout_seconds is not None:
|
||||
base = f"{base} | timeout={self.timeout_seconds}s"
|
||||
if self.operation:
|
||||
base = f"{base} | operation={self.operation}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPToolError(MCPError):
|
||||
"""Raised when a tool execution fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
tool_args: dict[str, Any] | None = None,
|
||||
error_code: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.tool_name = tool_name
|
||||
self.tool_args = tool_args
|
||||
self.error_code = error_code
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.tool_name:
|
||||
base = f"{base} | tool={self.tool_name}"
|
||||
if self.error_code:
|
||||
base = f"{base} | error_code={self.error_code}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPServerNotFoundError(MCPError):
|
||||
"""Raised when a requested MCP server is not registered."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name: str,
|
||||
*,
|
||||
available_servers: list[str] | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
message = f"MCP server not found: {server_name}"
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.available_servers = available_servers or []
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.available_servers:
|
||||
base = f"{base} | available={self.available_servers}"
|
||||
return base
|
||||
|
||||
|
||||
class MCPToolNotFoundError(MCPError):
|
||||
"""Raised when a requested tool is not found on any server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_name: str,
|
||||
*,
|
||||
server_name: str | None = None,
|
||||
available_tools: list[str] | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
message = f"Tool not found: {tool_name}"
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.tool_name = tool_name
|
||||
self.available_tools = available_tools or []
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.available_tools:
|
||||
base = f"{base} | available_tools={self.available_tools[:5]}..."
|
||||
return base
|
||||
|
||||
|
||||
class MCPCircuitOpenError(MCPError):
|
||||
"""Raised when a circuit breaker is open (server temporarily unavailable)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name: str,
|
||||
*,
|
||||
failure_count: int | None = None,
|
||||
reset_timeout: float | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
message = f"Circuit breaker open for server: {server_name}"
|
||||
super().__init__(message, server_name=server_name, details=details)
|
||||
self.failure_count = failure_count
|
||||
self.reset_timeout = reset_timeout
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.failure_count is not None:
|
||||
base = f"{base} | failures={self.failure_count}"
|
||||
if self.reset_timeout is not None:
|
||||
base = f"{base} | reset_in={self.reset_timeout}s"
|
||||
return base
|
||||
|
||||
|
||||
class MCPValidationError(MCPError):
|
||||
"""Raised when tool arguments fail validation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
tool_name: str | None = None,
|
||||
field_errors: dict[str, str] | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, details=details)
|
||||
self.tool_name = tool_name
|
||||
self.field_errors = field_errors or {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
base = super().__str__()
|
||||
if self.tool_name:
|
||||
base = f"{base} | tool={self.tool_name}"
|
||||
if self.field_errors:
|
||||
base = f"{base} | fields={list(self.field_errors.keys())}"
|
||||
return base
|
||||
305
backend/app/services/mcp/registry.py
Normal file
305
backend/app/services/mcp/registry.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
MCP Server Registry
|
||||
|
||||
Thread-safe singleton registry for managing MCP server configurations
|
||||
and their capabilities.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from .config import MCPConfig, MCPServerConfig, load_mcp_config
|
||||
from .exceptions import MCPServerNotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServerCapabilities:
|
||||
"""Cached capabilities for an MCP server."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
resources: list[dict[str, Any]] | None = None,
|
||||
prompts: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
self.tools = tools or []
|
||||
self.resources = resources or []
|
||||
self.prompts = prompts or []
|
||||
self._loaded = False
|
||||
self._load_time: float | None = None
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if capabilities have been loaded."""
|
||||
return self._loaded
|
||||
|
||||
@property
|
||||
def tool_names(self) -> list[str]:
|
||||
"""Get list of tool names."""
|
||||
return [t.get("name", "") for t in self.tools if t.get("name")]
|
||||
|
||||
def mark_loaded(self) -> None:
|
||||
"""Mark capabilities as loaded."""
|
||||
import time
|
||||
|
||||
self._loaded = True
|
||||
self._load_time = time.time()
|
||||
|
||||
|
||||
class MCPServerRegistry:
|
||||
"""
|
||||
Thread-safe singleton registry for MCP servers.
|
||||
|
||||
Manages server configurations and caches their capabilities.
|
||||
"""
|
||||
|
||||
_instance: "MCPServerRegistry | None" = None
|
||||
_lock = Lock()
|
||||
|
||||
def __new__(cls) -> "MCPServerRegistry":
|
||||
"""Ensure singleton pattern."""
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize registry (only runs once due to singleton)."""
|
||||
if getattr(self, "_initialized", False):
|
||||
return
|
||||
|
||||
self._config: MCPConfig = MCPConfig()
|
||||
self._capabilities: dict[str, ServerCapabilities] = {}
|
||||
self._capabilities_lock = asyncio.Lock()
|
||||
self._initialized = True
|
||||
|
||||
logger.info("MCP Server Registry initialized")
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "MCPServerRegistry":
|
||||
"""Get the singleton registry instance."""
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""Reset the singleton (for testing)."""
|
||||
with cls._lock:
|
||||
cls._instance = None
|
||||
|
||||
def load_config(self, config: MCPConfig | None = None) -> None:
|
||||
"""
|
||||
Load configuration into the registry.
|
||||
|
||||
Args:
|
||||
config: Optional config to load. If None, loads from default path.
|
||||
"""
|
||||
if config is None:
|
||||
config = load_mcp_config()
|
||||
|
||||
self._config = config
|
||||
self._capabilities.clear()
|
||||
|
||||
logger.info(
|
||||
"Loaded MCP configuration with %d servers",
|
||||
len(config.mcp_servers),
|
||||
)
|
||||
for name in config.list_server_names():
|
||||
logger.debug("Registered MCP server: %s", name)
|
||||
|
||||
def register(self, name: str, config: MCPServerConfig) -> None:
|
||||
"""
|
||||
Register a new MCP server.
|
||||
|
||||
Args:
|
||||
name: Unique server name
|
||||
config: Server configuration
|
||||
"""
|
||||
self._config.mcp_servers[name] = config
|
||||
self._capabilities.pop(name, None) # Clear any cached capabilities
|
||||
|
||||
logger.info("Registered MCP server: %s at %s", name, config.url)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""
|
||||
Unregister an MCP server.
|
||||
|
||||
Args:
|
||||
name: Server name to unregister
|
||||
|
||||
Returns:
|
||||
True if server was found and removed
|
||||
"""
|
||||
if name in self._config.mcp_servers:
|
||||
del self._config.mcp_servers[name]
|
||||
self._capabilities.pop(name, None)
|
||||
logger.info("Unregistered MCP server: %s", name)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get(self, name: str) -> MCPServerConfig:
|
||||
"""
|
||||
Get a server configuration by name.
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
|
||||
Returns:
|
||||
Server configuration
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
config = self._config.get_server(name)
|
||||
if config is None:
|
||||
raise MCPServerNotFoundError(
|
||||
server_name=name,
|
||||
available_servers=self.list_servers(),
|
||||
)
|
||||
return config
|
||||
|
||||
def get_or_none(self, name: str) -> MCPServerConfig | None:
|
||||
"""
|
||||
Get a server configuration by name, or None if not found.
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
|
||||
Returns:
|
||||
Server configuration or None
|
||||
"""
|
||||
return self._config.get_server(name)
|
||||
|
||||
def list_servers(self) -> list[str]:
|
||||
"""Get list of all registered server names."""
|
||||
return self._config.list_server_names()
|
||||
|
||||
def list_enabled_servers(self) -> list[str]:
|
||||
"""Get list of enabled server names."""
|
||||
return list(self._config.get_enabled_servers().keys())
|
||||
|
||||
def get_all_configs(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all server configurations."""
|
||||
return dict(self._config.mcp_servers)
|
||||
|
||||
def get_enabled_configs(self) -> dict[str, MCPServerConfig]:
|
||||
"""Get all enabled server configurations."""
|
||||
return self._config.get_enabled_servers()
|
||||
|
||||
async def get_capabilities(
|
||||
self,
|
||||
name: str,
|
||||
force_refresh: bool = False,
|
||||
) -> ServerCapabilities:
|
||||
"""
|
||||
Get capabilities for a server (lazy-loaded and cached).
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
force_refresh: If True, refresh cached capabilities
|
||||
|
||||
Returns:
|
||||
Server capabilities
|
||||
|
||||
Raises:
|
||||
MCPServerNotFoundError: If server is not registered
|
||||
"""
|
||||
# Verify server exists
|
||||
self.get(name)
|
||||
|
||||
async with self._capabilities_lock:
|
||||
if name not in self._capabilities or force_refresh:
|
||||
# Will be populated by connection manager when connecting
|
||||
self._capabilities[name] = ServerCapabilities()
|
||||
|
||||
return self._capabilities[name]
|
||||
|
||||
def set_capabilities(
|
||||
self,
|
||||
name: str,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
resources: list[dict[str, Any]] | None = None,
|
||||
prompts: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Set capabilities for a server (called by connection manager).
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
tools: List of tool definitions
|
||||
resources: List of resource definitions
|
||||
prompts: List of prompt definitions
|
||||
"""
|
||||
capabilities = ServerCapabilities(
|
||||
tools=tools,
|
||||
resources=resources,
|
||||
prompts=prompts,
|
||||
)
|
||||
capabilities.mark_loaded()
|
||||
self._capabilities[name] = capabilities
|
||||
|
||||
logger.debug(
|
||||
"Updated capabilities for %s: %d tools, %d resources, %d prompts",
|
||||
name,
|
||||
len(capabilities.tools),
|
||||
len(capabilities.resources),
|
||||
len(capabilities.prompts),
|
||||
)
|
||||
|
||||
def get_cached_capabilities(self, name: str) -> ServerCapabilities:
|
||||
"""
|
||||
Get cached capabilities without async loading.
|
||||
|
||||
Use this for synchronous access when you only need
|
||||
cached values (e.g., for health check responses).
|
||||
|
||||
Args:
|
||||
name: Server name
|
||||
|
||||
Returns:
|
||||
Cached capabilities or empty ServerCapabilities
|
||||
"""
|
||||
return self._capabilities.get(name, ServerCapabilities())
|
||||
|
||||
def find_server_for_tool(self, tool_name: str) -> str | None:
|
||||
"""
|
||||
Find which server provides a specific tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to find
|
||||
|
||||
Returns:
|
||||
Server name or None if not found
|
||||
"""
|
||||
for name, caps in self._capabilities.items():
|
||||
if tool_name in caps.tool_names:
|
||||
return name
|
||||
return None
|
||||
|
||||
def get_all_tools(self) -> dict[str, list[dict[str, Any]]]:
|
||||
"""
|
||||
Get all tools from all servers.
|
||||
|
||||
Returns:
|
||||
Dict mapping server name to list of tool definitions
|
||||
"""
|
||||
return {
|
||||
name: caps.tools
|
||||
for name, caps in self._capabilities.items()
|
||||
if caps.is_loaded
|
||||
}
|
||||
|
||||
@property
|
||||
def global_config(self) -> MCPConfig:
|
||||
"""Get the global MCP configuration."""
|
||||
return self._config
|
||||
|
||||
|
||||
# Module-level convenience function
|
||||
def get_registry() -> MCPServerRegistry:
|
||||
"""Get the global MCP server registry instance."""
|
||||
return MCPServerRegistry.get_instance()
|
||||
619
backend/app/services/mcp/routing.py
Normal file
619
backend/app/services/mcp/routing.py
Normal file
@@ -0,0 +1,619 @@
|
||||
"""
|
||||
MCP Tool Call Routing
|
||||
|
||||
Routes tool calls to appropriate servers with retry logic,
|
||||
circuit breakers, and request/response serialization.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .config import MCPServerConfig
|
||||
from .connection import ConnectionPool, MCPConnection
|
||||
from .exceptions import (
|
||||
MCPCircuitOpenError,
|
||||
MCPError,
|
||||
MCPTimeoutError,
|
||||
MCPToolError,
|
||||
MCPToolNotFoundError,
|
||||
)
|
||||
from .registry import MCPServerRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker states."""
|
||||
|
||||
CLOSED = "closed"
|
||||
OPEN = "open"
|
||||
HALF_OPEN = "half-open"
|
||||
|
||||
|
||||
class AsyncCircuitBreaker:
|
||||
"""
|
||||
Async-compatible circuit breaker implementation.
|
||||
|
||||
Unlike pybreaker which wraps sync functions, this implementation
|
||||
provides explicit success/failure tracking for async code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fail_max: int = 5,
|
||||
reset_timeout: float = 30.0,
|
||||
name: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
fail_max: Maximum failures before opening circuit
|
||||
reset_timeout: Seconds to wait before trying again
|
||||
name: Name for logging
|
||||
"""
|
||||
self.fail_max = fail_max
|
||||
self.reset_timeout = reset_timeout
|
||||
self.name = name
|
||||
self._state = CircuitState.CLOSED
|
||||
self._fail_counter = 0
|
||||
self._last_failure_time: float | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def current_state(self) -> str:
|
||||
"""Get current state as string."""
|
||||
# Check if we should transition from OPEN to HALF_OPEN
|
||||
if self._state == CircuitState.OPEN:
|
||||
if self._should_try_reset():
|
||||
return CircuitState.HALF_OPEN.value
|
||||
return self._state.value
|
||||
|
||||
@property
|
||||
def fail_counter(self) -> int:
|
||||
"""Get current failure count."""
|
||||
return self._fail_counter
|
||||
|
||||
def _should_try_reset(self) -> bool:
|
||||
"""Check if enough time has passed to try resetting."""
|
||||
if self._last_failure_time is None:
|
||||
return True
|
||||
return (time.time() - self._last_failure_time) >= self.reset_timeout
|
||||
|
||||
async def success(self) -> None:
|
||||
"""Record a successful call."""
|
||||
async with self._lock:
|
||||
self._fail_counter = 0
|
||||
self._state = CircuitState.CLOSED
|
||||
self._last_failure_time = None
|
||||
|
||||
async def failure(self) -> None:
|
||||
"""Record a failed call."""
|
||||
async with self._lock:
|
||||
self._fail_counter += 1
|
||||
self._last_failure_time = time.time()
|
||||
|
||||
if self._fail_counter >= self.fail_max:
|
||||
self._state = CircuitState.OPEN
|
||||
logger.warning(
|
||||
"Circuit breaker %s opened after %d failures",
|
||||
self.name,
|
||||
self._fail_counter,
|
||||
)
|
||||
|
||||
def is_open(self) -> bool:
|
||||
"""Check if circuit is open (not allowing calls)."""
|
||||
if self._state == CircuitState.OPEN:
|
||||
return not self._should_try_reset()
|
||||
return False
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Manually reset the circuit breaker."""
|
||||
async with self._lock:
|
||||
self._state = CircuitState.CLOSED
|
||||
self._fail_counter = 0
|
||||
self._last_failure_time = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolInfo:
|
||||
"""Information about an available tool."""
|
||||
|
||||
name: str
|
||||
description: str | None = None
|
||||
server_name: str | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"server_name": self.server_name,
|
||||
"input_schema": self.input_schema,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Result of a tool execution."""
|
||||
|
||||
success: bool
|
||||
data: Any = None
|
||||
error: str | None = None
|
||||
error_code: str | None = None
|
||||
tool_name: str | None = None
|
||||
server_name: str | None = None
|
||||
execution_time_ms: float = 0.0
|
||||
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"success": self.success,
|
||||
"data": self.data,
|
||||
"error": self.error,
|
||||
"error_code": self.error_code,
|
||||
"tool_name": self.tool_name,
|
||||
"server_name": self.server_name,
|
||||
"execution_time_ms": self.execution_time_ms,
|
||||
"request_id": self.request_id,
|
||||
}
|
||||
|
||||
|
||||
class ToolRouter:
|
||||
"""
|
||||
Routes tool calls to the appropriate MCP server.
|
||||
|
||||
Features:
|
||||
- Tool name to server mapping
|
||||
- Retry logic with exponential backoff
|
||||
- Circuit breaker pattern for fault tolerance
|
||||
- Request/response serialization
|
||||
- Execution timing and metrics
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
registry: MCPServerRegistry,
|
||||
connection_pool: ConnectionPool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the tool router.
|
||||
|
||||
Args:
|
||||
registry: MCP server registry
|
||||
connection_pool: Connection pool for servers
|
||||
"""
|
||||
self._registry = registry
|
||||
self._pool = connection_pool
|
||||
self._circuit_breakers: dict[str, AsyncCircuitBreaker] = {}
|
||||
self._tool_to_server: dict[str, str] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def _get_circuit_breaker(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
) -> AsyncCircuitBreaker:
|
||||
"""Get or create a circuit breaker for a server."""
|
||||
if server_name not in self._circuit_breakers:
|
||||
self._circuit_breakers[server_name] = AsyncCircuitBreaker(
|
||||
fail_max=config.circuit_breaker_threshold,
|
||||
reset_timeout=config.circuit_breaker_timeout,
|
||||
name=f"mcp-{server_name}",
|
||||
)
|
||||
return self._circuit_breakers[server_name]
|
||||
|
||||
async def register_tool_mapping(
|
||||
self,
|
||||
tool_name: str,
|
||||
server_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Register a mapping from tool name to server.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
server_name: Name of the server providing the tool
|
||||
"""
|
||||
async with self._lock:
|
||||
self._tool_to_server[tool_name] = server_name
|
||||
logger.debug("Registered tool %s -> server %s", tool_name, server_name)
|
||||
|
||||
async def discover_tools(self) -> None:
|
||||
"""
|
||||
Discover all tools from registered servers and build mappings.
|
||||
"""
|
||||
for server_name in self._registry.list_enabled_servers():
|
||||
try:
|
||||
config = self._registry.get(server_name)
|
||||
connection = await self._pool.get_connection(server_name, config)
|
||||
|
||||
# Fetch tools from server
|
||||
tools = await self._fetch_tools_from_server(connection)
|
||||
|
||||
# Update registry with capabilities
|
||||
self._registry.set_capabilities(
|
||||
server_name,
|
||||
tools=[t.to_dict() for t in tools],
|
||||
)
|
||||
|
||||
# Update tool mappings
|
||||
for tool in tools:
|
||||
await self.register_tool_mapping(tool.name, server_name)
|
||||
|
||||
logger.info(
|
||||
"Discovered %d tools from server %s",
|
||||
len(tools),
|
||||
server_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to discover tools from %s: %s",
|
||||
server_name,
|
||||
e,
|
||||
)
|
||||
|
||||
async def _fetch_tools_from_server(
|
||||
self,
|
||||
connection: MCPConnection,
|
||||
) -> list[ToolInfo]:
|
||||
"""Fetch available tools from an MCP server."""
|
||||
try:
|
||||
response = await connection.execute_request(
|
||||
"GET",
|
||||
"/mcp/tools",
|
||||
)
|
||||
|
||||
tools = []
|
||||
for tool_data in response.get("tools", []):
|
||||
tools.append(
|
||||
ToolInfo(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description"),
|
||||
server_name=connection.server_name,
|
||||
input_schema=tool_data.get("inputSchema"),
|
||||
)
|
||||
)
|
||||
return tools
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error fetching tools from %s: %s",
|
||||
connection.server_name,
|
||||
e,
|
||||
)
|
||||
return []
|
||||
|
||||
def find_server_for_tool(self, tool_name: str) -> str | None:
|
||||
"""
|
||||
Find which server provides a specific tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Server name or None if not found
|
||||
"""
|
||||
return self._tool_to_server.get(tool_name)
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
server_name: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Call a tool on a specific server.
|
||||
|
||||
Args:
|
||||
server_name: Name of the MCP server
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
start_time = time.time()
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
logger.debug(
|
||||
"Tool call [%s]: %s.%s with args %s",
|
||||
request_id,
|
||||
server_name,
|
||||
tool_name,
|
||||
arguments,
|
||||
)
|
||||
|
||||
try:
|
||||
config = self._registry.get(server_name)
|
||||
circuit_breaker = self._get_circuit_breaker(server_name, config)
|
||||
|
||||
# Check circuit breaker state
|
||||
if circuit_breaker.is_open():
|
||||
raise MCPCircuitOpenError(
|
||||
server_name=server_name,
|
||||
failure_count=circuit_breaker.fail_counter,
|
||||
reset_timeout=config.circuit_breaker_timeout,
|
||||
)
|
||||
|
||||
# Execute with retry logic
|
||||
result = await self._execute_with_retry(
|
||||
server_name=server_name,
|
||||
config=config,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments or {},
|
||||
timeout=timeout,
|
||||
circuit_breaker=circuit_breaker,
|
||||
)
|
||||
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
|
||||
return ToolResult(
|
||||
success=True,
|
||||
data=result,
|
||||
tool_name=tool_name,
|
||||
server_name=server_name,
|
||||
execution_time_ms=execution_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
except MCPCircuitOpenError:
|
||||
raise
|
||||
except MCPError as e:
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
logger.error(
|
||||
"Tool call failed [%s]: %s.%s - %s",
|
||||
request_id,
|
||||
server_name,
|
||||
tool_name,
|
||||
e,
|
||||
)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_code=type(e).__name__,
|
||||
tool_name=tool_name,
|
||||
server_name=server_name,
|
||||
execution_time_ms=execution_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
except Exception as e:
|
||||
execution_time = (time.time() - start_time) * 1000
|
||||
logger.exception(
|
||||
"Unexpected error in tool call [%s]: %s.%s",
|
||||
request_id,
|
||||
server_name,
|
||||
tool_name,
|
||||
)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_code="UnexpectedError",
|
||||
tool_name=tool_name,
|
||||
server_name=server_name,
|
||||
execution_time_ms=execution_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
async def _execute_with_retry(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
timeout: float | None,
|
||||
circuit_breaker: AsyncCircuitBreaker,
|
||||
) -> Any:
|
||||
"""Execute tool call with retry logic."""
|
||||
last_error: Exception | None = None
|
||||
attempts = 0
|
||||
max_attempts = config.retry_attempts + 1 # +1 for initial attempt
|
||||
|
||||
while attempts < max_attempts:
|
||||
attempts += 1
|
||||
|
||||
try:
|
||||
# Use circuit breaker to track failures
|
||||
result = await self._execute_tool_call(
|
||||
server_name=server_name,
|
||||
config=config,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Success - record it
|
||||
await circuit_breaker.success()
|
||||
return result
|
||||
|
||||
except MCPCircuitOpenError:
|
||||
raise
|
||||
except MCPTimeoutError:
|
||||
# Timeout - don't retry
|
||||
await circuit_breaker.failure()
|
||||
raise
|
||||
except MCPToolError:
|
||||
# Tool-level error - don't retry (user error)
|
||||
raise
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
await circuit_breaker.failure()
|
||||
|
||||
if attempts < max_attempts:
|
||||
delay = self._calculate_retry_delay(attempts, config)
|
||||
logger.warning(
|
||||
"Tool call attempt %d/%d failed for %s.%s: %s. "
|
||||
"Retrying in %.1fs",
|
||||
attempts,
|
||||
max_attempts,
|
||||
server_name,
|
||||
tool_name,
|
||||
e,
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All attempts failed
|
||||
raise MCPToolError(
|
||||
f"Tool call failed after {max_attempts} attempts",
|
||||
server_name=server_name,
|
||||
tool_name=tool_name,
|
||||
tool_args=arguments,
|
||||
details={"last_error": str(last_error)},
|
||||
)
|
||||
|
||||
def _calculate_retry_delay(
|
||||
self,
|
||||
attempt: int,
|
||||
config: MCPServerConfig,
|
||||
) -> float:
|
||||
"""Calculate exponential backoff delay with jitter."""
|
||||
import random
|
||||
|
||||
delay = config.retry_delay * (2 ** (attempt - 1))
|
||||
delay = min(delay, config.retry_max_delay)
|
||||
# Add jitter (±25%)
|
||||
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||
return max(0.1, delay + jitter)
|
||||
|
||||
async def _execute_tool_call(
|
||||
self,
|
||||
server_name: str,
|
||||
config: MCPServerConfig,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
timeout: float | None,
|
||||
) -> Any:
|
||||
"""Execute a single tool call."""
|
||||
connection = await self._pool.get_connection(server_name, config)
|
||||
|
||||
# Build MCP tool call request
|
||||
request_body = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
"id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
response = await connection.execute_request(
|
||||
method="POST",
|
||||
path="/mcp",
|
||||
data=request_body,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Handle JSON-RPC response
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPToolError(
|
||||
error.get("message", "Tool execution failed"),
|
||||
server_name=server_name,
|
||||
tool_name=tool_name,
|
||||
tool_args=arguments,
|
||||
error_code=str(error.get("code", "UNKNOWN")),
|
||||
)
|
||||
|
||||
return response.get("result")
|
||||
|
||||
async def route_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Route a tool call to the appropriate server.
|
||||
|
||||
Automatically discovers which server provides the tool.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
|
||||
Raises:
|
||||
MCPToolNotFoundError: If no server provides the tool
|
||||
"""
|
||||
server_name = self.find_server_for_tool(tool_name)
|
||||
|
||||
if server_name is None:
|
||||
# Try to find from registry
|
||||
server_name = self._registry.find_server_for_tool(tool_name)
|
||||
|
||||
if server_name is None:
|
||||
raise MCPToolNotFoundError(
|
||||
tool_name=tool_name,
|
||||
available_tools=list(self._tool_to_server.keys()),
|
||||
)
|
||||
|
||||
return await self.call_tool(
|
||||
server_name=server_name,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def list_all_tools(self) -> list[ToolInfo]:
|
||||
"""
|
||||
Get all available tools from all servers.
|
||||
|
||||
Returns:
|
||||
List of tool information
|
||||
"""
|
||||
tools = []
|
||||
all_server_tools = self._registry.get_all_tools()
|
||||
|
||||
for server_name, server_tools in all_server_tools.items():
|
||||
for tool_data in server_tools:
|
||||
tools.append(
|
||||
ToolInfo(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description"),
|
||||
server_name=server_name,
|
||||
input_schema=tool_data.get("input_schema"),
|
||||
)
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
def get_circuit_breaker_status(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get status of all circuit breakers."""
|
||||
return {
|
||||
name: {
|
||||
"state": cb.current_state,
|
||||
"failure_count": cb.fail_counter,
|
||||
}
|
||||
for name, cb in self._circuit_breakers.items()
|
||||
}
|
||||
|
||||
async def reset_circuit_breaker(self, server_name: str) -> bool:
|
||||
"""
|
||||
Manually reset a circuit breaker.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
|
||||
Returns:
|
||||
True if circuit breaker was reset
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name in self._circuit_breakers:
|
||||
# Reset by removing (will be recreated on next call)
|
||||
del self._circuit_breakers[server_name]
|
||||
logger.info("Reset circuit breaker for %s", server_name)
|
||||
return True
|
||||
return False
|
||||
141
backend/app/services/memory/__init__.py
Normal file
141
backend/app/services/memory/__init__.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Agent Memory System
|
||||
|
||||
Multi-tier cognitive memory for AI agents, providing:
|
||||
- Working Memory: Session-scoped ephemeral state (Redis/In-memory)
|
||||
- Episodic Memory: Experiential records of past tasks (PostgreSQL)
|
||||
- Semantic Memory: Learned facts and knowledge (PostgreSQL + pgvector)
|
||||
- Procedural Memory: Learned skills and procedures (PostgreSQL)
|
||||
|
||||
Usage:
|
||||
from app.services.memory import (
|
||||
MemoryManager,
|
||||
MemorySettings,
|
||||
get_memory_settings,
|
||||
MemoryType,
|
||||
ScopeLevel,
|
||||
)
|
||||
|
||||
# Create a manager for a session
|
||||
manager = MemoryManager.for_session(
|
||||
session_id="sess-123",
|
||||
project_id=uuid,
|
||||
)
|
||||
|
||||
async with manager:
|
||||
# Working memory
|
||||
await manager.set_working("key", {"data": "value"})
|
||||
value = await manager.get_working("key")
|
||||
|
||||
# Episodic memory
|
||||
episode = await manager.record_episode(episode_data)
|
||||
similar = await manager.search_episodes("query")
|
||||
|
||||
# Semantic memory
|
||||
fact = await manager.store_fact(fact_data)
|
||||
facts = await manager.search_facts("query")
|
||||
|
||||
# Procedural memory
|
||||
procedure = await manager.record_procedure(procedure_data)
|
||||
procedures = await manager.find_procedures("context")
|
||||
"""
|
||||
|
||||
# Configuration
|
||||
from .config import (
|
||||
MemorySettings,
|
||||
get_default_settings,
|
||||
get_memory_settings,
|
||||
reset_memory_settings,
|
||||
)
|
||||
|
||||
# Exceptions
|
||||
from .exceptions import (
|
||||
CheckpointError,
|
||||
EmbeddingError,
|
||||
MemoryCapacityError,
|
||||
MemoryConflictError,
|
||||
MemoryConsolidationError,
|
||||
MemoryError,
|
||||
MemoryExpiredError,
|
||||
MemoryNotFoundError,
|
||||
MemoryRetrievalError,
|
||||
MemoryScopeError,
|
||||
MemorySerializationError,
|
||||
MemoryStorageError,
|
||||
)
|
||||
|
||||
# Manager
|
||||
from .manager import MemoryManager
|
||||
|
||||
# Types
|
||||
from .types import (
|
||||
ConsolidationStatus,
|
||||
ConsolidationType,
|
||||
Episode,
|
||||
EpisodeCreate,
|
||||
Fact,
|
||||
FactCreate,
|
||||
MemoryItem,
|
||||
MemoryStats,
|
||||
MemoryStore,
|
||||
MemoryType,
|
||||
Outcome,
|
||||
Procedure,
|
||||
ProcedureCreate,
|
||||
RetrievalResult,
|
||||
ScopeContext,
|
||||
ScopeLevel,
|
||||
Step,
|
||||
TaskState,
|
||||
WorkingMemoryItem,
|
||||
)
|
||||
|
||||
# Reflection (lazy import available)
|
||||
# Import directly: from app.services.memory.reflection import MemoryReflection
|
||||
|
||||
__all__ = [
|
||||
"CheckpointError",
|
||||
"ConsolidationStatus",
|
||||
"ConsolidationType",
|
||||
"EmbeddingError",
|
||||
"Episode",
|
||||
"EpisodeCreate",
|
||||
"Fact",
|
||||
"FactCreate",
|
||||
"MemoryCapacityError",
|
||||
"MemoryConflictError",
|
||||
"MemoryConsolidationError",
|
||||
# Exceptions
|
||||
"MemoryError",
|
||||
"MemoryExpiredError",
|
||||
"MemoryItem",
|
||||
# Manager
|
||||
"MemoryManager",
|
||||
"MemoryNotFoundError",
|
||||
"MemoryRetrievalError",
|
||||
"MemoryScopeError",
|
||||
"MemorySerializationError",
|
||||
# Configuration
|
||||
"MemorySettings",
|
||||
"MemoryStats",
|
||||
"MemoryStorageError",
|
||||
# Types - Abstract
|
||||
"MemoryStore",
|
||||
# Types - Enums
|
||||
"MemoryType",
|
||||
"Outcome",
|
||||
"Procedure",
|
||||
"ProcedureCreate",
|
||||
"RetrievalResult",
|
||||
# Types - Data Classes
|
||||
"ScopeContext",
|
||||
"ScopeLevel",
|
||||
"Step",
|
||||
"TaskState",
|
||||
"WorkingMemoryItem",
|
||||
"get_default_settings",
|
||||
"get_memory_settings",
|
||||
"reset_memory_settings",
|
||||
# MCP Tools - lazy import to avoid circular dependencies
|
||||
# Import directly: from app.services.memory.mcp import MemoryToolService
|
||||
]
|
||||
21
backend/app/services/memory/cache/__init__.py
vendored
Normal file
21
backend/app/services/memory/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
# app/services/memory/cache/__init__.py
|
||||
"""
|
||||
Memory Caching Layer.
|
||||
|
||||
Provides caching for memory operations:
|
||||
- Hot Memory Cache: LRU cache for frequently accessed memories
|
||||
- Embedding Cache: Cache embeddings by content hash
|
||||
- Cache Manager: Unified cache management with invalidation
|
||||
"""
|
||||
|
||||
from .cache_manager import CacheManager, CacheStats, get_cache_manager
|
||||
from .embedding_cache import EmbeddingCache
|
||||
from .hot_cache import HotMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"CacheManager",
|
||||
"CacheStats",
|
||||
"EmbeddingCache",
|
||||
"HotMemoryCache",
|
||||
"get_cache_manager",
|
||||
]
|
||||
505
backend/app/services/memory/cache/cache_manager.py
vendored
Normal file
505
backend/app/services/memory/cache/cache_manager.py
vendored
Normal file
@@ -0,0 +1,505 @@
|
||||
# app/services/memory/cache/cache_manager.py
|
||||
"""
|
||||
Cache Manager.
|
||||
|
||||
Unified cache management for memory operations.
|
||||
Coordinates hot cache, embedding cache, and retrieval cache.
|
||||
Provides centralized invalidation and statistics.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.config import get_memory_settings
|
||||
|
||||
from .embedding_cache import EmbeddingCache, create_embedding_cache
|
||||
from .hot_cache import CacheKey, HotMemoryCache, create_hot_cache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.services.memory.indexing.retrieval import RetrievalCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
"""Aggregated cache statistics."""
|
||||
|
||||
hot_cache: dict[str, Any] = field(default_factory=dict)
|
||||
embedding_cache: dict[str, Any] = field(default_factory=dict)
|
||||
retrieval_cache: dict[str, Any] = field(default_factory=dict)
|
||||
overall_hit_rate: float = 0.0
|
||||
last_cleanup: datetime | None = None
|
||||
cleanup_count: int = 0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"hot_cache": self.hot_cache,
|
||||
"embedding_cache": self.embedding_cache,
|
||||
"retrieval_cache": self.retrieval_cache,
|
||||
"overall_hit_rate": self.overall_hit_rate,
|
||||
"last_cleanup": self.last_cleanup.isoformat()
|
||||
if self.last_cleanup
|
||||
else None,
|
||||
"cleanup_count": self.cleanup_count,
|
||||
}
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""
|
||||
Unified cache manager for memory operations.
|
||||
|
||||
Provides:
|
||||
- Centralized cache configuration
|
||||
- Coordinated invalidation across caches
|
||||
- Aggregated statistics
|
||||
- Automatic cleanup scheduling
|
||||
|
||||
Performance targets:
|
||||
- Overall cache hit rate > 80%
|
||||
- Cache operations < 1ms (memory), < 5ms (Redis)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hot_cache: HotMemoryCache[Any] | None = None,
|
||||
embedding_cache: EmbeddingCache | None = None,
|
||||
retrieval_cache: "RetrievalCache | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the cache manager.
|
||||
|
||||
Args:
|
||||
hot_cache: Optional pre-configured hot cache
|
||||
embedding_cache: Optional pre-configured embedding cache
|
||||
retrieval_cache: Optional pre-configured retrieval cache
|
||||
redis: Optional Redis connection for persistence
|
||||
"""
|
||||
self._settings = get_memory_settings()
|
||||
self._redis = redis
|
||||
self._enabled = self._settings.cache_enabled
|
||||
|
||||
# Initialize caches
|
||||
if hot_cache:
|
||||
self._hot_cache = hot_cache
|
||||
else:
|
||||
self._hot_cache = create_hot_cache(
|
||||
max_size=self._settings.cache_max_items,
|
||||
default_ttl_seconds=self._settings.cache_ttl_seconds,
|
||||
)
|
||||
|
||||
if embedding_cache:
|
||||
self._embedding_cache = embedding_cache
|
||||
else:
|
||||
self._embedding_cache = create_embedding_cache(
|
||||
max_size=self._settings.cache_max_items,
|
||||
default_ttl_seconds=self._settings.cache_ttl_seconds
|
||||
* 12, # 1hr for embeddings
|
||||
redis=redis,
|
||||
)
|
||||
|
||||
self._retrieval_cache = retrieval_cache
|
||||
|
||||
# Stats tracking
|
||||
self._last_cleanup: datetime | None = None
|
||||
self._cleanup_count = 0
|
||||
self._lock = threading.RLock()
|
||||
|
||||
logger.info(
|
||||
f"Initialized CacheManager: enabled={self._enabled}, "
|
||||
f"redis={'connected' if redis else 'disabled'}"
|
||||
)
|
||||
|
||||
def set_redis(self, redis: "Redis") -> None:
|
||||
"""Set Redis connection for all caches."""
|
||||
self._redis = redis
|
||||
self._embedding_cache.set_redis(redis)
|
||||
|
||||
def set_retrieval_cache(self, cache: "RetrievalCache") -> None:
|
||||
"""Set retrieval cache instance."""
|
||||
self._retrieval_cache = cache
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if caching is enabled."""
|
||||
return self._enabled
|
||||
|
||||
@property
|
||||
def hot_cache(self) -> HotMemoryCache[Any]:
|
||||
"""Get the hot memory cache."""
|
||||
return self._hot_cache
|
||||
|
||||
@property
|
||||
def embedding_cache(self) -> EmbeddingCache:
|
||||
"""Get the embedding cache."""
|
||||
return self._embedding_cache
|
||||
|
||||
@property
|
||||
def retrieval_cache(self) -> "RetrievalCache | None":
|
||||
"""Get the retrieval cache."""
|
||||
return self._retrieval_cache
|
||||
|
||||
# =========================================================================
|
||||
# Hot Memory Cache Operations
|
||||
# =========================================================================
|
||||
|
||||
def get_memory(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
scope: str | None = None,
|
||||
) -> Any | None:
|
||||
"""
|
||||
Get a memory from hot cache.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
scope: Optional scope
|
||||
|
||||
Returns:
|
||||
Cached memory or None
|
||||
"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
return self._hot_cache.get_by_id(memory_type, memory_id, scope)
|
||||
|
||||
def cache_memory(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
memory: Any,
|
||||
scope: str | None = None,
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache a memory in hot cache.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
memory: Memory object
|
||||
scope: Optional scope
|
||||
ttl_seconds: Optional TTL override
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
self._hot_cache.put_by_id(memory_type, memory_id, memory, scope, ttl_seconds)
|
||||
|
||||
# =========================================================================
|
||||
# Embedding Cache Operations
|
||||
# =========================================================================
|
||||
|
||||
async def get_embedding(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "default",
|
||||
) -> list[float] | None:
|
||||
"""
|
||||
Get a cached embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Cached embedding or None
|
||||
"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
return await self._embedding_cache.get(content, model)
|
||||
|
||||
async def cache_embedding(
|
||||
self,
|
||||
content: str,
|
||||
embedding: list[float],
|
||||
model: str = "default",
|
||||
ttl_seconds: float | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Cache an embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
embedding: Embedding vector
|
||||
model: Model name
|
||||
ttl_seconds: Optional TTL override
|
||||
|
||||
Returns:
|
||||
Content hash
|
||||
"""
|
||||
if not self._enabled:
|
||||
return EmbeddingCache.hash_content(content)
|
||||
return await self._embedding_cache.put(content, embedding, model, ttl_seconds)
|
||||
|
||||
# =========================================================================
|
||||
# Invalidation
|
||||
# =========================================================================
|
||||
|
||||
async def invalidate_memory(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
scope: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Invalidate a memory across all caches.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
scope: Optional scope
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
count = 0
|
||||
|
||||
# Invalidate hot cache
|
||||
if self._hot_cache.invalidate_by_id(memory_type, memory_id, scope):
|
||||
count += 1
|
||||
|
||||
# Invalidate retrieval cache
|
||||
if self._retrieval_cache:
|
||||
uuid_id = (
|
||||
UUID(str(memory_id)) if not isinstance(memory_id, UUID) else memory_id
|
||||
)
|
||||
count += self._retrieval_cache.invalidate_by_memory(uuid_id)
|
||||
|
||||
logger.debug(f"Invalidated {count} cache entries for {memory_type}:{memory_id}")
|
||||
return count
|
||||
|
||||
async def invalidate_by_type(self, memory_type: str) -> int:
|
||||
"""
|
||||
Invalidate all entries of a memory type.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
count = self._hot_cache.invalidate_by_type(memory_type)
|
||||
|
||||
if self._retrieval_cache:
|
||||
count += self._retrieval_cache.clear()
|
||||
|
||||
logger.info(f"Invalidated {count} cache entries for type {memory_type}")
|
||||
return count
|
||||
|
||||
async def invalidate_by_scope(self, scope: str) -> int:
|
||||
"""
|
||||
Invalidate all entries in a scope.
|
||||
|
||||
Args:
|
||||
scope: Scope to invalidate (e.g., project_id)
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
count = self._hot_cache.invalidate_by_scope(scope)
|
||||
|
||||
# Retrieval cache doesn't support scope-based invalidation
|
||||
# so we clear it entirely for safety
|
||||
if self._retrieval_cache:
|
||||
count += self._retrieval_cache.clear()
|
||||
|
||||
logger.info(f"Invalidated {count} cache entries for scope {scope}")
|
||||
return count
|
||||
|
||||
async def invalidate_embedding(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "default",
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate a cached embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
return await self._embedding_cache.invalidate(content, model)
|
||||
|
||||
async def clear_all(self) -> int:
|
||||
"""
|
||||
Clear all caches.
|
||||
|
||||
Returns:
|
||||
Total number of entries cleared
|
||||
"""
|
||||
count = 0
|
||||
|
||||
count += self._hot_cache.clear()
|
||||
count += await self._embedding_cache.clear()
|
||||
|
||||
if self._retrieval_cache:
|
||||
count += self._retrieval_cache.clear()
|
||||
|
||||
logger.info(f"Cleared {count} entries from all caches")
|
||||
return count
|
||||
|
||||
# =========================================================================
|
||||
# Cleanup
|
||||
# =========================================================================
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Clean up expired entries from all caches.
|
||||
|
||||
Returns:
|
||||
Number of entries cleaned up
|
||||
"""
|
||||
with self._lock:
|
||||
count = 0
|
||||
|
||||
count += self._hot_cache.cleanup_expired()
|
||||
count += self._embedding_cache.cleanup_expired()
|
||||
|
||||
# Retrieval cache doesn't have a cleanup method,
|
||||
# but entries expire on access
|
||||
|
||||
self._last_cleanup = _utcnow()
|
||||
self._cleanup_count += 1
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired cache entries")
|
||||
|
||||
return count
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
def get_stats(self) -> CacheStats:
|
||||
"""
|
||||
Get aggregated cache statistics.
|
||||
|
||||
Returns:
|
||||
CacheStats with all cache metrics
|
||||
"""
|
||||
hot_stats = self._hot_cache.get_stats().to_dict()
|
||||
emb_stats = self._embedding_cache.get_stats().to_dict()
|
||||
|
||||
retrieval_stats: dict[str, Any] = {}
|
||||
if self._retrieval_cache:
|
||||
retrieval_stats = self._retrieval_cache.get_stats()
|
||||
|
||||
# Calculate overall hit rate
|
||||
total_hits = hot_stats.get("hits", 0) + emb_stats.get("hits", 0)
|
||||
total_misses = hot_stats.get("misses", 0) + emb_stats.get("misses", 0)
|
||||
|
||||
if retrieval_stats:
|
||||
# Retrieval cache doesn't track hits/misses the same way
|
||||
pass
|
||||
|
||||
total_requests = total_hits + total_misses
|
||||
overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0
|
||||
|
||||
return CacheStats(
|
||||
hot_cache=hot_stats,
|
||||
embedding_cache=emb_stats,
|
||||
retrieval_cache=retrieval_stats,
|
||||
overall_hit_rate=overall_hit_rate,
|
||||
last_cleanup=self._last_cleanup,
|
||||
cleanup_count=self._cleanup_count,
|
||||
)
|
||||
|
||||
def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
|
||||
"""
|
||||
Get the most frequently accessed memories.
|
||||
|
||||
Args:
|
||||
limit: Maximum number to return
|
||||
|
||||
Returns:
|
||||
List of (key, access_count) tuples
|
||||
"""
|
||||
return self._hot_cache.get_hot_memories(limit)
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset all cache statistics."""
|
||||
self._hot_cache.reset_stats()
|
||||
self._embedding_cache.reset_stats()
|
||||
|
||||
# =========================================================================
|
||||
# Warmup
|
||||
# =========================================================================
|
||||
|
||||
async def warmup(
|
||||
self,
|
||||
memories: list[tuple[str, UUID | str, Any]],
|
||||
scope: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Warm up the hot cache with memories.
|
||||
|
||||
Args:
|
||||
memories: List of (memory_type, memory_id, memory) tuples
|
||||
scope: Optional scope for all memories
|
||||
|
||||
Returns:
|
||||
Number of memories cached
|
||||
"""
|
||||
if not self._enabled:
|
||||
return 0
|
||||
|
||||
for memory_type, memory_id, memory in memories:
|
||||
self._hot_cache.put_by_id(memory_type, memory_id, memory, scope)
|
||||
|
||||
logger.info(f"Warmed up cache with {len(memories)} memories")
|
||||
return len(memories)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_cache_manager: CacheManager | None = None
|
||||
_cache_manager_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_cache_manager(
|
||||
redis: "Redis | None" = None,
|
||||
reset: bool = False,
|
||||
) -> CacheManager:
|
||||
"""
|
||||
Get the global CacheManager instance.
|
||||
|
||||
Thread-safe with double-checked locking pattern.
|
||||
|
||||
Args:
|
||||
redis: Optional Redis connection
|
||||
reset: Force create a new instance
|
||||
|
||||
Returns:
|
||||
CacheManager instance
|
||||
"""
|
||||
global _cache_manager
|
||||
|
||||
if reset or _cache_manager is None:
|
||||
with _cache_manager_lock:
|
||||
if reset or _cache_manager is None:
|
||||
_cache_manager = CacheManager(redis=redis)
|
||||
|
||||
return _cache_manager
|
||||
|
||||
|
||||
def reset_cache_manager() -> None:
|
||||
"""Reset the global cache manager instance."""
|
||||
global _cache_manager
|
||||
with _cache_manager_lock:
|
||||
_cache_manager = None
|
||||
623
backend/app/services/memory/cache/embedding_cache.py
vendored
Normal file
623
backend/app/services/memory/cache/embedding_cache.py
vendored
Normal file
@@ -0,0 +1,623 @@
|
||||
# app/services/memory/cache/embedding_cache.py
|
||||
"""
|
||||
Embedding Cache.
|
||||
|
||||
Caches embeddings by content hash to avoid recomputing.
|
||||
Provides significant performance improvement for repeated content.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingEntry:
|
||||
"""A cached embedding entry."""
|
||||
|
||||
embedding: list[float]
|
||||
content_hash: str
|
||||
model: str
|
||||
created_at: datetime
|
||||
ttl_seconds: float = 3600.0 # 1 hour default
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this entry has expired."""
|
||||
age = (_utcnow() - self.created_at).total_seconds()
|
||||
return age > self.ttl_seconds
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingCacheStats:
|
||||
"""Statistics for the embedding cache."""
|
||||
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
evictions: int = 0
|
||||
expirations: int = 0
|
||||
current_size: int = 0
|
||||
max_size: int = 0
|
||||
bytes_saved: int = 0 # Estimated bytes saved by caching
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate."""
|
||||
total = self.hits + self.misses
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return self.hits / total
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"evictions": self.evictions,
|
||||
"expirations": self.expirations,
|
||||
"current_size": self.current_size,
|
||||
"max_size": self.max_size,
|
||||
"hit_rate": self.hit_rate,
|
||||
"bytes_saved": self.bytes_saved,
|
||||
}
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""
|
||||
Cache for embeddings by content hash.
|
||||
|
||||
Features:
|
||||
- Content-hash based deduplication
|
||||
- LRU eviction
|
||||
- TTL-based expiration
|
||||
- Optional Redis backing for persistence
|
||||
- Thread-safe operations
|
||||
|
||||
Performance targets:
|
||||
- Cache hit rate > 90% for repeated content
|
||||
- Get/put operations < 1ms (memory), < 5ms (Redis)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int = 50000,
|
||||
default_ttl_seconds: float = 3600.0,
|
||||
redis: "Redis | None" = None,
|
||||
redis_prefix: str = "mem:emb",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the embedding cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries in memory cache
|
||||
default_ttl_seconds: Default TTL for entries (1 hour)
|
||||
redis: Optional Redis connection for persistence
|
||||
redis_prefix: Prefix for Redis keys
|
||||
"""
|
||||
self._max_size = max_size
|
||||
self._default_ttl = default_ttl_seconds
|
||||
self._cache: OrderedDict[str, EmbeddingEntry] = OrderedDict()
|
||||
self._lock = threading.RLock()
|
||||
self._stats = EmbeddingCacheStats(max_size=max_size)
|
||||
self._redis = redis
|
||||
self._redis_prefix = redis_prefix
|
||||
|
||||
logger.info(
|
||||
f"Initialized EmbeddingCache with max_size={max_size}, "
|
||||
f"ttl={default_ttl_seconds}s, redis={'enabled' if redis else 'disabled'}"
|
||||
)
|
||||
|
||||
def set_redis(self, redis: "Redis") -> None:
|
||||
"""Set Redis connection for persistence."""
|
||||
self._redis = redis
|
||||
|
||||
@staticmethod
|
||||
def hash_content(content: str) -> str:
|
||||
"""
|
||||
Compute hash of content for cache key.
|
||||
|
||||
Args:
|
||||
content: Content to hash
|
||||
|
||||
Returns:
|
||||
32-character hex hash
|
||||
"""
|
||||
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||
|
||||
def _cache_key(self, content_hash: str, model: str) -> str:
|
||||
"""Build cache key from content hash and model."""
|
||||
return f"{content_hash}:{model}"
|
||||
|
||||
def _redis_key(self, content_hash: str, model: str) -> str:
|
||||
"""Build Redis key from content hash and model."""
|
||||
return f"{self._redis_prefix}:{content_hash}:{model}"
|
||||
|
||||
async def get(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "default",
|
||||
) -> list[float] | None:
|
||||
"""
|
||||
Get a cached embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Cached embedding or None if not found/expired
|
||||
"""
|
||||
content_hash = self.hash_content(content)
|
||||
cache_key = self._cache_key(content_hash, model)
|
||||
|
||||
# Check memory cache first
|
||||
with self._lock:
|
||||
if cache_key in self._cache:
|
||||
entry = self._cache[cache_key]
|
||||
if entry.is_expired():
|
||||
del self._cache[cache_key]
|
||||
self._stats.expirations += 1
|
||||
self._stats.current_size = len(self._cache)
|
||||
else:
|
||||
# Move to end (most recently used)
|
||||
self._cache.move_to_end(cache_key)
|
||||
self._stats.hits += 1
|
||||
return entry.embedding
|
||||
|
||||
# Check Redis if available
|
||||
if self._redis:
|
||||
try:
|
||||
redis_key = self._redis_key(content_hash, model)
|
||||
data = await self._redis.get(redis_key)
|
||||
if data:
|
||||
import json
|
||||
|
||||
embedding = json.loads(data)
|
||||
# Store in memory cache for faster access
|
||||
self._put_memory(content_hash, model, embedding)
|
||||
self._stats.hits += 1
|
||||
return embedding
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis get error: {e}")
|
||||
|
||||
self._stats.misses += 1
|
||||
return None
|
||||
|
||||
async def get_by_hash(
|
||||
self,
|
||||
content_hash: str,
|
||||
model: str = "default",
|
||||
) -> list[float] | None:
|
||||
"""
|
||||
Get a cached embedding by hash.
|
||||
|
||||
Args:
|
||||
content_hash: Content hash
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Cached embedding or None if not found/expired
|
||||
"""
|
||||
cache_key = self._cache_key(content_hash, model)
|
||||
|
||||
with self._lock:
|
||||
if cache_key in self._cache:
|
||||
entry = self._cache[cache_key]
|
||||
if entry.is_expired():
|
||||
del self._cache[cache_key]
|
||||
self._stats.expirations += 1
|
||||
self._stats.current_size = len(self._cache)
|
||||
else:
|
||||
self._cache.move_to_end(cache_key)
|
||||
self._stats.hits += 1
|
||||
return entry.embedding
|
||||
|
||||
# Check Redis
|
||||
if self._redis:
|
||||
try:
|
||||
redis_key = self._redis_key(content_hash, model)
|
||||
data = await self._redis.get(redis_key)
|
||||
if data:
|
||||
import json
|
||||
|
||||
embedding = json.loads(data)
|
||||
self._put_memory(content_hash, model, embedding)
|
||||
self._stats.hits += 1
|
||||
return embedding
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis get error: {e}")
|
||||
|
||||
self._stats.misses += 1
|
||||
return None
|
||||
|
||||
async def put(
|
||||
self,
|
||||
content: str,
|
||||
embedding: list[float],
|
||||
model: str = "default",
|
||||
ttl_seconds: float | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Cache an embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
embedding: Embedding vector
|
||||
model: Model name
|
||||
ttl_seconds: Optional TTL override
|
||||
|
||||
Returns:
|
||||
Content hash
|
||||
"""
|
||||
content_hash = self.hash_content(content)
|
||||
ttl = ttl_seconds or self._default_ttl
|
||||
|
||||
# Store in memory
|
||||
self._put_memory(content_hash, model, embedding, ttl)
|
||||
|
||||
# Store in Redis if available
|
||||
if self._redis:
|
||||
try:
|
||||
import json
|
||||
|
||||
redis_key = self._redis_key(content_hash, model)
|
||||
await self._redis.setex(
|
||||
redis_key,
|
||||
int(ttl),
|
||||
json.dumps(embedding),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis put error: {e}")
|
||||
|
||||
return content_hash
|
||||
|
||||
def _put_memory(
|
||||
self,
|
||||
content_hash: str,
|
||||
model: str,
|
||||
embedding: list[float],
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""Store in memory cache."""
|
||||
with self._lock:
|
||||
# Evict if at capacity
|
||||
self._evict_if_needed()
|
||||
|
||||
cache_key = self._cache_key(content_hash, model)
|
||||
entry = EmbeddingEntry(
|
||||
embedding=embedding,
|
||||
content_hash=content_hash,
|
||||
model=model,
|
||||
created_at=_utcnow(),
|
||||
ttl_seconds=ttl_seconds or self._default_ttl,
|
||||
)
|
||||
|
||||
self._cache[cache_key] = entry
|
||||
self._cache.move_to_end(cache_key)
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict entries if cache is at capacity."""
|
||||
while len(self._cache) >= self._max_size:
|
||||
if self._cache:
|
||||
self._cache.popitem(last=False)
|
||||
self._stats.evictions += 1
|
||||
|
||||
async def put_batch(
|
||||
self,
|
||||
items: list[tuple[str, list[float]]],
|
||||
model: str = "default",
|
||||
ttl_seconds: float | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Cache multiple embeddings.
|
||||
|
||||
Args:
|
||||
items: List of (content, embedding) tuples
|
||||
model: Model name
|
||||
ttl_seconds: Optional TTL override
|
||||
|
||||
Returns:
|
||||
List of content hashes
|
||||
"""
|
||||
hashes = []
|
||||
for content, embedding in items:
|
||||
content_hash = await self.put(content, embedding, model, ttl_seconds)
|
||||
hashes.append(content_hash)
|
||||
return hashes
|
||||
|
||||
async def invalidate(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "default",
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate a cached embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
content_hash = self.hash_content(content)
|
||||
return await self.invalidate_by_hash(content_hash, model)
|
||||
|
||||
async def invalidate_by_hash(
|
||||
self,
|
||||
content_hash: str,
|
||||
model: str = "default",
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate a cached embedding by hash.
|
||||
|
||||
Args:
|
||||
content_hash: Content hash
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
cache_key = self._cache_key(content_hash, model)
|
||||
removed = False
|
||||
|
||||
with self._lock:
|
||||
if cache_key in self._cache:
|
||||
del self._cache[cache_key]
|
||||
self._stats.current_size = len(self._cache)
|
||||
removed = True
|
||||
|
||||
# Remove from Redis
|
||||
if self._redis:
|
||||
try:
|
||||
redis_key = self._redis_key(content_hash, model)
|
||||
await self._redis.delete(redis_key)
|
||||
removed = True
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis delete error: {e}")
|
||||
|
||||
return removed
|
||||
|
||||
async def invalidate_by_model(self, model: str) -> int:
|
||||
"""
|
||||
Invalidate all embeddings for a model.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
count = 0
|
||||
|
||||
with self._lock:
|
||||
keys_to_remove = [k for k, v in self._cache.items() if v.model == model]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
count += 1
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
# Note: Redis pattern deletion would require SCAN which is expensive
|
||||
# For now, we only clear memory cache for model-based invalidation
|
||||
|
||||
return count
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""
|
||||
Clear all cache entries.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
with self._lock:
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._stats.current_size = 0
|
||||
|
||||
# Clear Redis entries
|
||||
if self._redis:
|
||||
try:
|
||||
pattern = f"{self._redis_prefix}:*"
|
||||
deleted = 0
|
||||
async for key in self._redis.scan_iter(match=pattern):
|
||||
await self._redis.delete(key)
|
||||
deleted += 1
|
||||
count = max(count, deleted)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis clear error: {e}")
|
||||
|
||||
logger.info(f"Cleared {count} entries from embedding cache")
|
||||
return count
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Remove all expired entries from memory cache.
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_remove = [k for k, v in self._cache.items() if v.is_expired()]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
self._stats.expirations += 1
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
if keys_to_remove:
|
||||
logger.debug(f"Cleaned up {len(keys_to_remove)} expired embeddings")
|
||||
|
||||
return len(keys_to_remove)
|
||||
|
||||
def get_stats(self) -> EmbeddingCacheStats:
|
||||
"""Get cache statistics."""
|
||||
with self._lock:
|
||||
self._stats.current_size = len(self._cache)
|
||||
return self._stats
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset cache statistics."""
|
||||
with self._lock:
|
||||
self._stats = EmbeddingCacheStats(
|
||||
max_size=self._max_size,
|
||||
current_size=len(self._cache),
|
||||
)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Get current cache size."""
|
||||
return len(self._cache)
|
||||
|
||||
@property
|
||||
def max_size(self) -> int:
|
||||
"""Get maximum cache size."""
|
||||
return self._max_size
|
||||
|
||||
|
||||
class CachedEmbeddingGenerator:
|
||||
"""
|
||||
Wrapper for embedding generators with caching.
|
||||
|
||||
Wraps an embedding generator to cache results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
generator: Any,
|
||||
cache: EmbeddingCache,
|
||||
model: str = "default",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the cached embedding generator.
|
||||
|
||||
Args:
|
||||
generator: Underlying embedding generator
|
||||
cache: Embedding cache
|
||||
model: Model name for cache keys
|
||||
"""
|
||||
self._generator = generator
|
||||
self._cache = cache
|
||||
self._model = model
|
||||
self._call_count = 0
|
||||
self._cache_hit_count = 0
|
||||
|
||||
async def generate(self, text: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding with caching.
|
||||
|
||||
Args:
|
||||
text: Text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector
|
||||
"""
|
||||
self._call_count += 1
|
||||
|
||||
# Check cache first
|
||||
cached = await self._cache.get(text, self._model)
|
||||
if cached is not None:
|
||||
self._cache_hit_count += 1
|
||||
return cached
|
||||
|
||||
# Generate and cache
|
||||
embedding = await self._generator.generate(text)
|
||||
await self._cache.put(text, embedding, self._model)
|
||||
|
||||
return embedding
|
||||
|
||||
async def generate_batch(
|
||||
self,
|
||||
texts: list[str],
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Generate embeddings for multiple texts with caching.
|
||||
|
||||
Args:
|
||||
texts: Texts to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
"""
|
||||
results: list[list[float] | None] = [None] * len(texts)
|
||||
to_generate: list[tuple[int, str]] = []
|
||||
|
||||
# Check cache for each text
|
||||
for i, text in enumerate(texts):
|
||||
cached = await self._cache.get(text, self._model)
|
||||
if cached is not None:
|
||||
results[i] = cached
|
||||
self._cache_hit_count += 1
|
||||
else:
|
||||
to_generate.append((i, text))
|
||||
|
||||
self._call_count += len(texts)
|
||||
|
||||
# Generate missing embeddings
|
||||
if to_generate:
|
||||
if hasattr(self._generator, "generate_batch"):
|
||||
texts_to_gen = [t for _, t in to_generate]
|
||||
embeddings = await self._generator.generate_batch(texts_to_gen)
|
||||
|
||||
for (idx, text), embedding in zip(to_generate, embeddings, strict=True):
|
||||
results[idx] = embedding
|
||||
await self._cache.put(text, embedding, self._model)
|
||||
else:
|
||||
# Fallback to individual generation
|
||||
for idx, text in to_generate:
|
||||
embedding = await self._generator.generate(text)
|
||||
results[idx] = embedding
|
||||
await self._cache.put(text, embedding, self._model)
|
||||
|
||||
return results # type: ignore[return-value]
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get generator statistics."""
|
||||
return {
|
||||
"call_count": self._call_count,
|
||||
"cache_hit_count": self._cache_hit_count,
|
||||
"cache_hit_rate": (
|
||||
self._cache_hit_count / self._call_count
|
||||
if self._call_count > 0
|
||||
else 0.0
|
||||
),
|
||||
"cache_stats": self._cache.get_stats().to_dict(),
|
||||
}
|
||||
|
||||
|
||||
# Factory function
|
||||
def create_embedding_cache(
|
||||
max_size: int = 50000,
|
||||
default_ttl_seconds: float = 3600.0,
|
||||
redis: "Redis | None" = None,
|
||||
) -> EmbeddingCache:
|
||||
"""
|
||||
Create an embedding cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries
|
||||
default_ttl_seconds: Default TTL for entries
|
||||
redis: Optional Redis connection
|
||||
|
||||
Returns:
|
||||
Configured EmbeddingCache instance
|
||||
"""
|
||||
return EmbeddingCache(
|
||||
max_size=max_size,
|
||||
default_ttl_seconds=default_ttl_seconds,
|
||||
redis=redis,
|
||||
)
|
||||
461
backend/app/services/memory/cache/hot_cache.py
vendored
Normal file
461
backend/app/services/memory/cache/hot_cache.py
vendored
Normal file
@@ -0,0 +1,461 @@
|
||||
# app/services/memory/cache/hot_cache.py
|
||||
"""
|
||||
Hot Memory Cache.
|
||||
|
||||
LRU cache for frequently accessed memories.
|
||||
Provides fast access to recently used memories without database queries.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry[T]:
|
||||
"""A cached memory entry with metadata."""
|
||||
|
||||
value: T
|
||||
created_at: datetime
|
||||
last_accessed_at: datetime
|
||||
access_count: int = 1
|
||||
ttl_seconds: float = 300.0
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this entry has expired."""
|
||||
age = (_utcnow() - self.created_at).total_seconds()
|
||||
return age > self.ttl_seconds
|
||||
|
||||
def touch(self) -> None:
|
||||
"""Update access time and count."""
|
||||
self.last_accessed_at = _utcnow()
|
||||
self.access_count += 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheKey:
|
||||
"""A structured cache key with components."""
|
||||
|
||||
memory_type: str
|
||||
memory_id: str
|
||||
scope: str | None = None
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.memory_type, self.memory_id, self.scope))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, CacheKey):
|
||||
return False
|
||||
return (
|
||||
self.memory_type == other.memory_type
|
||||
and self.memory_id == other.memory_id
|
||||
and self.scope == other.scope
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.scope:
|
||||
return f"{self.memory_type}:{self.scope}:{self.memory_id}"
|
||||
return f"{self.memory_type}:{self.memory_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HotCacheStats:
|
||||
"""Statistics for the hot memory cache."""
|
||||
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
evictions: int = 0
|
||||
expirations: int = 0
|
||||
current_size: int = 0
|
||||
max_size: int = 0
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate."""
|
||||
total = self.hits + self.misses
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return self.hits / total
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"evictions": self.evictions,
|
||||
"expirations": self.expirations,
|
||||
"current_size": self.current_size,
|
||||
"max_size": self.max_size,
|
||||
"hit_rate": self.hit_rate,
|
||||
}
|
||||
|
||||
|
||||
class HotMemoryCache[T]:
|
||||
"""
|
||||
LRU cache for frequently accessed memories.
|
||||
|
||||
Features:
|
||||
- LRU eviction when capacity is reached
|
||||
- TTL-based expiration
|
||||
- Access count tracking for hot memory identification
|
||||
- Thread-safe operations
|
||||
- Scoped invalidation
|
||||
|
||||
Performance targets:
|
||||
- Cache hit rate > 80% for hot memories
|
||||
- Get/put operations < 1ms
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int = 10000,
|
||||
default_ttl_seconds: float = 300.0,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the hot memory cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries
|
||||
default_ttl_seconds: Default TTL for entries (5 minutes)
|
||||
"""
|
||||
self._max_size = max_size
|
||||
self._default_ttl = default_ttl_seconds
|
||||
self._cache: OrderedDict[CacheKey, CacheEntry[T]] = OrderedDict()
|
||||
self._lock = threading.RLock()
|
||||
self._stats = HotCacheStats(max_size=max_size)
|
||||
logger.info(
|
||||
f"Initialized HotMemoryCache with max_size={max_size}, "
|
||||
f"ttl={default_ttl_seconds}s"
|
||||
)
|
||||
|
||||
def get(self, key: CacheKey) -> T | None:
|
||||
"""
|
||||
Get a memory from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found/expired
|
||||
"""
|
||||
with self._lock:
|
||||
if key not in self._cache:
|
||||
self._stats.misses += 1
|
||||
return None
|
||||
|
||||
entry = self._cache[key]
|
||||
|
||||
# Check expiration
|
||||
if entry.is_expired():
|
||||
del self._cache[key]
|
||||
self._stats.expirations += 1
|
||||
self._stats.misses += 1
|
||||
self._stats.current_size = len(self._cache)
|
||||
return None
|
||||
|
||||
# Move to end (most recently used)
|
||||
self._cache.move_to_end(key)
|
||||
entry.touch()
|
||||
|
||||
self._stats.hits += 1
|
||||
return entry.value
|
||||
|
||||
def get_by_id(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
scope: str | None = None,
|
||||
) -> T | None:
|
||||
"""
|
||||
Get a memory by type and ID.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory (episodic, semantic, procedural)
|
||||
memory_id: Memory ID
|
||||
scope: Optional scope (project_id, agent_id)
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found/expired
|
||||
"""
|
||||
key = CacheKey(
|
||||
memory_type=memory_type,
|
||||
memory_id=str(memory_id),
|
||||
scope=scope,
|
||||
)
|
||||
return self.get(key)
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: CacheKey,
|
||||
value: T,
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Put a memory into cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl_seconds: Optional TTL override
|
||||
"""
|
||||
with self._lock:
|
||||
# Evict if at capacity
|
||||
self._evict_if_needed()
|
||||
|
||||
now = _utcnow()
|
||||
entry = CacheEntry(
|
||||
value=value,
|
||||
created_at=now,
|
||||
last_accessed_at=now,
|
||||
access_count=1,
|
||||
ttl_seconds=ttl_seconds or self._default_ttl,
|
||||
)
|
||||
|
||||
self._cache[key] = entry
|
||||
self._cache.move_to_end(key)
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
def put_by_id(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
value: T,
|
||||
scope: str | None = None,
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Put a memory by type and ID.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
value: Value to cache
|
||||
scope: Optional scope
|
||||
ttl_seconds: Optional TTL override
|
||||
"""
|
||||
key = CacheKey(
|
||||
memory_type=memory_type,
|
||||
memory_id=str(memory_id),
|
||||
scope=scope,
|
||||
)
|
||||
self.put(key, value, ttl_seconds)
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict entries if cache is at capacity."""
|
||||
while len(self._cache) >= self._max_size:
|
||||
# Remove least recently used (first item)
|
||||
if self._cache:
|
||||
self._cache.popitem(last=False)
|
||||
self._stats.evictions += 1
|
||||
|
||||
def invalidate(self, key: CacheKey) -> bool:
|
||||
"""
|
||||
Invalidate a specific cache entry.
|
||||
|
||||
Args:
|
||||
key: Cache key to invalidate
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
del self._cache[key]
|
||||
self._stats.current_size = len(self._cache)
|
||||
return True
|
||||
return False
|
||||
|
||||
def invalidate_by_id(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
scope: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate a memory by type and ID.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
scope: Optional scope
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
key = CacheKey(
|
||||
memory_type=memory_type,
|
||||
memory_id=str(memory_id),
|
||||
scope=scope,
|
||||
)
|
||||
return self.invalidate(key)
|
||||
|
||||
def invalidate_by_type(self, memory_type: str) -> int:
|
||||
"""
|
||||
Invalidate all entries of a memory type.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory to invalidate
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_remove = [
|
||||
k for k in self._cache.keys() if k.memory_type == memory_type
|
||||
]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def invalidate_by_scope(self, scope: str) -> int:
|
||||
"""
|
||||
Invalidate all entries in a scope.
|
||||
|
||||
Args:
|
||||
scope: Scope to invalidate (e.g., project_id)
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_remove = [k for k in self._cache.keys() if k.scope == scope]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def invalidate_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
Invalidate entries matching a pattern.
|
||||
|
||||
Pattern can include * as wildcard.
|
||||
|
||||
Args:
|
||||
pattern: Pattern to match (e.g., "episodic:*")
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
import fnmatch
|
||||
|
||||
with self._lock:
|
||||
keys_to_remove = [
|
||||
k for k in self._cache.keys() if fnmatch.fnmatch(str(k), pattern)
|
||||
]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def clear(self) -> int:
|
||||
"""
|
||||
Clear all cache entries.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
with self._lock:
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._stats.current_size = 0
|
||||
logger.info(f"Cleared {count} entries from hot cache")
|
||||
return count
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Remove all expired entries.
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_remove = [k for k, v in self._cache.items() if v.is_expired()]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
self._stats.expirations += 1
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
if keys_to_remove:
|
||||
logger.debug(f"Cleaned up {len(keys_to_remove)} expired entries")
|
||||
|
||||
return len(keys_to_remove)
|
||||
|
||||
def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
|
||||
"""
|
||||
Get the most frequently accessed memories.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of memories to return
|
||||
|
||||
Returns:
|
||||
List of (key, access_count) tuples sorted by access count
|
||||
"""
|
||||
with self._lock:
|
||||
entries = [
|
||||
(k, v.access_count)
|
||||
for k, v in self._cache.items()
|
||||
if not v.is_expired()
|
||||
]
|
||||
entries.sort(key=lambda x: x[1], reverse=True)
|
||||
return entries[:limit]
|
||||
|
||||
def get_stats(self) -> HotCacheStats:
|
||||
"""Get cache statistics."""
|
||||
with self._lock:
|
||||
self._stats.current_size = len(self._cache)
|
||||
return self._stats
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset cache statistics."""
|
||||
with self._lock:
|
||||
self._stats = HotCacheStats(
|
||||
max_size=self._max_size,
|
||||
current_size=len(self._cache),
|
||||
)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Get current cache size."""
|
||||
return len(self._cache)
|
||||
|
||||
@property
|
||||
def max_size(self) -> int:
|
||||
"""Get maximum cache size."""
|
||||
return self._max_size
|
||||
|
||||
|
||||
# Factory function for typed caches
|
||||
def create_hot_cache(
|
||||
max_size: int = 10000,
|
||||
default_ttl_seconds: float = 300.0,
|
||||
) -> HotMemoryCache[Any]:
|
||||
"""
|
||||
Create a hot memory cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries
|
||||
default_ttl_seconds: Default TTL for entries
|
||||
|
||||
Returns:
|
||||
Configured HotMemoryCache instance
|
||||
"""
|
||||
return HotMemoryCache(
|
||||
max_size=max_size,
|
||||
default_ttl_seconds=default_ttl_seconds,
|
||||
)
|
||||
410
backend/app/services/memory/config.py
Normal file
410
backend/app/services/memory/config.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Memory System Configuration.
|
||||
|
||||
Provides Pydantic settings for the Agent Memory System,
|
||||
including storage backends, capacity limits, and consolidation policies.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class MemorySettings(BaseSettings):
|
||||
"""
|
||||
Configuration for the Agent Memory System.
|
||||
|
||||
All settings can be overridden via environment variables
|
||||
with the MEM_ prefix.
|
||||
"""
|
||||
|
||||
# Working Memory Settings
|
||||
working_memory_backend: str = Field(
|
||||
default="redis",
|
||||
description="Backend for working memory: 'redis' or 'memory'",
|
||||
)
|
||||
working_memory_default_ttl_seconds: int = Field(
|
||||
default=3600,
|
||||
ge=60,
|
||||
le=86400,
|
||||
description="Default TTL for working memory items (1 hour default)",
|
||||
)
|
||||
working_memory_max_items_per_session: int = Field(
|
||||
default=1000,
|
||||
ge=100,
|
||||
le=100000,
|
||||
description="Maximum items per session in working memory",
|
||||
)
|
||||
working_memory_max_value_size_bytes: int = Field(
|
||||
default=1048576, # 1MB
|
||||
ge=1024,
|
||||
le=104857600, # 100MB
|
||||
description="Maximum size of a single value in working memory",
|
||||
)
|
||||
working_memory_checkpoint_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable checkpointing for working memory recovery",
|
||||
)
|
||||
|
||||
# Redis Settings (for working memory)
|
||||
redis_url: str = Field(
|
||||
default="redis://localhost:6379/0",
|
||||
description="Redis connection URL",
|
||||
)
|
||||
redis_prefix: str = Field(
|
||||
default="mem",
|
||||
description="Redis key prefix for memory items",
|
||||
)
|
||||
redis_connection_timeout_seconds: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=60,
|
||||
description="Redis connection timeout",
|
||||
)
|
||||
|
||||
# Episodic Memory Settings
|
||||
episodic_max_episodes_per_project: int = Field(
|
||||
default=10000,
|
||||
ge=100,
|
||||
le=1000000,
|
||||
description="Maximum episodes to retain per project",
|
||||
)
|
||||
episodic_default_importance: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Default importance score for new episodes",
|
||||
)
|
||||
episodic_retention_days: int = Field(
|
||||
default=365,
|
||||
ge=7,
|
||||
le=3650,
|
||||
description="Days to retain episodes before archival",
|
||||
)
|
||||
|
||||
# Semantic Memory Settings
|
||||
semantic_max_facts_per_project: int = Field(
|
||||
default=50000,
|
||||
ge=1000,
|
||||
le=10000000,
|
||||
description="Maximum facts to retain per project",
|
||||
)
|
||||
semantic_confidence_decay_days: int = Field(
|
||||
default=90,
|
||||
ge=7,
|
||||
le=365,
|
||||
description="Days until confidence decays by 50%",
|
||||
)
|
||||
semantic_min_confidence: float = Field(
|
||||
default=0.1,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum confidence before fact is pruned",
|
||||
)
|
||||
|
||||
# Procedural Memory Settings
|
||||
procedural_max_procedures_per_project: int = Field(
|
||||
default=1000,
|
||||
ge=10,
|
||||
le=100000,
|
||||
description="Maximum procedures per project",
|
||||
)
|
||||
procedural_min_success_rate: float = Field(
|
||||
default=0.3,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum success rate before procedure is pruned",
|
||||
)
|
||||
procedural_min_uses_before_suggest: int = Field(
|
||||
default=3,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Minimum uses before procedure is suggested",
|
||||
)
|
||||
|
||||
# Embedding Settings
|
||||
embedding_model: str = Field(
|
||||
default="text-embedding-3-small",
|
||||
description="Model to use for embeddings",
|
||||
)
|
||||
embedding_dimensions: int = Field(
|
||||
default=1536,
|
||||
ge=256,
|
||||
le=4096,
|
||||
description="Embedding vector dimensions",
|
||||
)
|
||||
embedding_batch_size: int = Field(
|
||||
default=100,
|
||||
ge=1,
|
||||
le=1000,
|
||||
description="Batch size for embedding generation",
|
||||
)
|
||||
embedding_cache_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable caching of embeddings",
|
||||
)
|
||||
|
||||
# Retrieval Settings
|
||||
retrieval_default_limit: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Default limit for retrieval queries",
|
||||
)
|
||||
retrieval_max_limit: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=1000,
|
||||
description="Maximum limit for retrieval queries",
|
||||
)
|
||||
retrieval_min_similarity: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum similarity score for retrieval",
|
||||
)
|
||||
|
||||
# Consolidation Settings
|
||||
consolidation_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable automatic memory consolidation",
|
||||
)
|
||||
consolidation_batch_size: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=1000,
|
||||
description="Batch size for consolidation jobs",
|
||||
)
|
||||
consolidation_schedule_cron: str = Field(
|
||||
default="0 3 * * *",
|
||||
description="Cron expression for nightly consolidation (3 AM)",
|
||||
)
|
||||
consolidation_working_to_episodic_delay_minutes: int = Field(
|
||||
default=30,
|
||||
ge=5,
|
||||
le=1440,
|
||||
description="Minutes after session end before consolidating to episodic",
|
||||
)
|
||||
|
||||
# Pruning Settings
|
||||
pruning_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable automatic memory pruning",
|
||||
)
|
||||
pruning_min_age_days: int = Field(
|
||||
default=7,
|
||||
ge=1,
|
||||
le=365,
|
||||
description="Minimum age before memory can be pruned",
|
||||
)
|
||||
pruning_importance_threshold: float = Field(
|
||||
default=0.2,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Importance threshold below which memory can be pruned",
|
||||
)
|
||||
|
||||
# Caching Settings
|
||||
cache_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable caching for memory retrieval",
|
||||
)
|
||||
cache_ttl_seconds: int = Field(
|
||||
default=300,
|
||||
ge=10,
|
||||
le=3600,
|
||||
description="Cache TTL for retrieval results",
|
||||
)
|
||||
cache_max_items: int = Field(
|
||||
default=10000,
|
||||
ge=100,
|
||||
le=1000000,
|
||||
description="Maximum items in memory cache",
|
||||
)
|
||||
|
||||
# Performance Settings
|
||||
max_retrieval_time_ms: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=5000,
|
||||
description="Target maximum retrieval time in milliseconds",
|
||||
)
|
||||
parallel_retrieval: bool = Field(
|
||||
default=True,
|
||||
description="Enable parallel retrieval from multiple memory types",
|
||||
)
|
||||
max_parallel_retrievals: int = Field(
|
||||
default=4,
|
||||
ge=1,
|
||||
le=10,
|
||||
description="Maximum concurrent retrieval operations",
|
||||
)
|
||||
|
||||
@field_validator("working_memory_backend")
|
||||
@classmethod
|
||||
def validate_backend(cls, v: str) -> str:
|
||||
"""Validate working memory backend."""
|
||||
valid_backends = {"redis", "memory"}
|
||||
if v not in valid_backends:
|
||||
raise ValueError(f"backend must be one of: {valid_backends}")
|
||||
return v
|
||||
|
||||
@field_validator("embedding_model")
|
||||
@classmethod
|
||||
def validate_embedding_model(cls, v: str) -> str:
|
||||
"""Validate embedding model name."""
|
||||
valid_models = {
|
||||
"text-embedding-3-small",
|
||||
"text-embedding-3-large",
|
||||
"text-embedding-ada-002",
|
||||
}
|
||||
if v not in valid_models:
|
||||
raise ValueError(f"embedding_model must be one of: {valid_models}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_limits(self) -> "MemorySettings":
|
||||
"""Validate that limits are consistent."""
|
||||
if self.retrieval_default_limit > self.retrieval_max_limit:
|
||||
raise ValueError(
|
||||
f"retrieval_default_limit ({self.retrieval_default_limit}) "
|
||||
f"cannot exceed retrieval_max_limit ({self.retrieval_max_limit})"
|
||||
)
|
||||
return self
|
||||
|
||||
def get_working_memory_config(self) -> dict[str, Any]:
|
||||
"""Get working memory configuration as a dictionary."""
|
||||
return {
|
||||
"backend": self.working_memory_backend,
|
||||
"default_ttl_seconds": self.working_memory_default_ttl_seconds,
|
||||
"max_items_per_session": self.working_memory_max_items_per_session,
|
||||
"max_value_size_bytes": self.working_memory_max_value_size_bytes,
|
||||
"checkpoint_enabled": self.working_memory_checkpoint_enabled,
|
||||
}
|
||||
|
||||
def get_redis_config(self) -> dict[str, Any]:
|
||||
"""Get Redis configuration as a dictionary."""
|
||||
return {
|
||||
"url": self.redis_url,
|
||||
"prefix": self.redis_prefix,
|
||||
"connection_timeout_seconds": self.redis_connection_timeout_seconds,
|
||||
}
|
||||
|
||||
def get_embedding_config(self) -> dict[str, Any]:
|
||||
"""Get embedding configuration as a dictionary."""
|
||||
return {
|
||||
"model": self.embedding_model,
|
||||
"dimensions": self.embedding_dimensions,
|
||||
"batch_size": self.embedding_batch_size,
|
||||
"cache_enabled": self.embedding_cache_enabled,
|
||||
}
|
||||
|
||||
def get_consolidation_config(self) -> dict[str, Any]:
|
||||
"""Get consolidation configuration as a dictionary."""
|
||||
return {
|
||||
"enabled": self.consolidation_enabled,
|
||||
"batch_size": self.consolidation_batch_size,
|
||||
"schedule_cron": self.consolidation_schedule_cron,
|
||||
"working_to_episodic_delay_minutes": (
|
||||
self.consolidation_working_to_episodic_delay_minutes
|
||||
),
|
||||
}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert settings to dictionary for logging/debugging."""
|
||||
return {
|
||||
"working_memory": self.get_working_memory_config(),
|
||||
"redis": self.get_redis_config(),
|
||||
"episodic": {
|
||||
"max_episodes_per_project": self.episodic_max_episodes_per_project,
|
||||
"default_importance": self.episodic_default_importance,
|
||||
"retention_days": self.episodic_retention_days,
|
||||
},
|
||||
"semantic": {
|
||||
"max_facts_per_project": self.semantic_max_facts_per_project,
|
||||
"confidence_decay_days": self.semantic_confidence_decay_days,
|
||||
"min_confidence": self.semantic_min_confidence,
|
||||
},
|
||||
"procedural": {
|
||||
"max_procedures_per_project": self.procedural_max_procedures_per_project,
|
||||
"min_success_rate": self.procedural_min_success_rate,
|
||||
"min_uses_before_suggest": self.procedural_min_uses_before_suggest,
|
||||
},
|
||||
"embedding": self.get_embedding_config(),
|
||||
"retrieval": {
|
||||
"default_limit": self.retrieval_default_limit,
|
||||
"max_limit": self.retrieval_max_limit,
|
||||
"min_similarity": self.retrieval_min_similarity,
|
||||
},
|
||||
"consolidation": self.get_consolidation_config(),
|
||||
"pruning": {
|
||||
"enabled": self.pruning_enabled,
|
||||
"min_age_days": self.pruning_min_age_days,
|
||||
"importance_threshold": self.pruning_importance_threshold,
|
||||
},
|
||||
"cache": {
|
||||
"enabled": self.cache_enabled,
|
||||
"ttl_seconds": self.cache_ttl_seconds,
|
||||
"max_items": self.cache_max_items,
|
||||
},
|
||||
"performance": {
|
||||
"max_retrieval_time_ms": self.max_retrieval_time_ms,
|
||||
"parallel_retrieval": self.parallel_retrieval,
|
||||
"max_parallel_retrievals": self.max_parallel_retrievals,
|
||||
},
|
||||
}
|
||||
|
||||
model_config = {
|
||||
"env_prefix": "MEM_",
|
||||
"env_file": ".env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": False,
|
||||
"extra": "ignore",
|
||||
}
|
||||
|
||||
|
||||
# Thread-safe singleton pattern
|
||||
_settings: MemorySettings | None = None
|
||||
_settings_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_memory_settings() -> MemorySettings:
|
||||
"""
|
||||
Get the global MemorySettings instance.
|
||||
|
||||
Thread-safe with double-checked locking pattern.
|
||||
|
||||
Returns:
|
||||
MemorySettings instance
|
||||
"""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
with _settings_lock:
|
||||
if _settings is None:
|
||||
_settings = MemorySettings()
|
||||
return _settings
|
||||
|
||||
|
||||
def reset_memory_settings() -> None:
|
||||
"""
|
||||
Reset the global settings instance.
|
||||
|
||||
Primarily used for testing.
|
||||
"""
|
||||
global _settings
|
||||
with _settings_lock:
|
||||
_settings = None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_default_settings() -> MemorySettings:
|
||||
"""
|
||||
Get default settings (cached).
|
||||
|
||||
Use this for read-only access to defaults.
|
||||
For mutable access, use get_memory_settings().
|
||||
"""
|
||||
return MemorySettings()
|
||||
29
backend/app/services/memory/consolidation/__init__.py
Normal file
29
backend/app/services/memory/consolidation/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# app/services/memory/consolidation/__init__.py
|
||||
"""
|
||||
Memory Consolidation.
|
||||
|
||||
Transfers and extracts knowledge between memory tiers:
|
||||
- Working -> Episodic (session end)
|
||||
- Episodic -> Semantic (learn facts)
|
||||
- Episodic -> Procedural (learn procedures)
|
||||
|
||||
Also handles memory pruning and importance-based retention.
|
||||
"""
|
||||
|
||||
from .service import (
|
||||
ConsolidationConfig,
|
||||
ConsolidationResult,
|
||||
MemoryConsolidationService,
|
||||
NightlyConsolidationResult,
|
||||
SessionConsolidationResult,
|
||||
get_consolidation_service,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConsolidationConfig",
|
||||
"ConsolidationResult",
|
||||
"MemoryConsolidationService",
|
||||
"NightlyConsolidationResult",
|
||||
"SessionConsolidationResult",
|
||||
"get_consolidation_service",
|
||||
]
|
||||
913
backend/app/services/memory/consolidation/service.py
Normal file
913
backend/app/services/memory/consolidation/service.py
Normal file
@@ -0,0 +1,913 @@
|
||||
# app/services/memory/consolidation/service.py
|
||||
"""
|
||||
Memory Consolidation Service.
|
||||
|
||||
Transfers and extracts knowledge between memory tiers:
|
||||
- Working -> Episodic (session end)
|
||||
- Episodic -> Semantic (learn facts)
|
||||
- Episodic -> Procedural (learn procedures)
|
||||
|
||||
Also handles memory pruning and importance-based retention.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.memory.episodic.memory import EpisodicMemory
|
||||
from app.services.memory.procedural.memory import ProceduralMemory
|
||||
from app.services.memory.semantic.extraction import FactExtractor, get_fact_extractor
|
||||
from app.services.memory.semantic.memory import SemanticMemory
|
||||
from app.services.memory.types import (
|
||||
Episode,
|
||||
EpisodeCreate,
|
||||
Outcome,
|
||||
ProcedureCreate,
|
||||
TaskState,
|
||||
)
|
||||
from app.services.memory.working.memory import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsolidationConfig:
|
||||
"""Configuration for memory consolidation."""
|
||||
|
||||
# Working -> Episodic thresholds
|
||||
min_steps_for_episode: int = 2
|
||||
min_duration_seconds: float = 5.0
|
||||
|
||||
# Episodic -> Semantic thresholds
|
||||
min_confidence_for_fact: float = 0.6
|
||||
max_facts_per_episode: int = 10
|
||||
reinforce_existing_facts: bool = True
|
||||
|
||||
# Episodic -> Procedural thresholds
|
||||
min_episodes_for_procedure: int = 3
|
||||
min_success_rate_for_procedure: float = 0.7
|
||||
min_steps_for_procedure: int = 2
|
||||
|
||||
# Pruning thresholds
|
||||
max_episode_age_days: int = 90
|
||||
min_importance_to_keep: float = 0.2
|
||||
keep_all_failures: bool = True
|
||||
keep_all_with_lessons: bool = True
|
||||
|
||||
# Batch sizes
|
||||
batch_size: int = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsolidationResult:
|
||||
"""Result of a consolidation operation."""
|
||||
|
||||
source_type: str
|
||||
target_type: str
|
||||
items_processed: int = 0
|
||||
items_created: int = 0
|
||||
items_updated: int = 0
|
||||
items_skipped: int = 0
|
||||
items_pruned: int = 0
|
||||
errors: list[str] = field(default_factory=list)
|
||||
duration_seconds: float = 0.0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"source_type": self.source_type,
|
||||
"target_type": self.target_type,
|
||||
"items_processed": self.items_processed,
|
||||
"items_created": self.items_created,
|
||||
"items_updated": self.items_updated,
|
||||
"items_skipped": self.items_skipped,
|
||||
"items_pruned": self.items_pruned,
|
||||
"errors": self.errors,
|
||||
"duration_seconds": self.duration_seconds,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionConsolidationResult:
|
||||
"""Result of consolidating a session's working memory to episodic."""
|
||||
|
||||
session_id: str
|
||||
episode_created: bool = False
|
||||
episode_id: UUID | None = None
|
||||
scratchpad_entries: int = 0
|
||||
variables_captured: int = 0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NightlyConsolidationResult:
|
||||
"""Result of nightly consolidation run."""
|
||||
|
||||
started_at: datetime
|
||||
completed_at: datetime | None = None
|
||||
episodic_to_semantic: ConsolidationResult | None = None
|
||||
episodic_to_procedural: ConsolidationResult | None = None
|
||||
pruning: ConsolidationResult | None = None
|
||||
total_episodes_processed: int = 0
|
||||
total_facts_created: int = 0
|
||||
total_procedures_created: int = 0
|
||||
total_pruned: int = 0
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"started_at": self.started_at.isoformat(),
|
||||
"completed_at": self.completed_at.isoformat()
|
||||
if self.completed_at
|
||||
else None,
|
||||
"episodic_to_semantic": (
|
||||
self.episodic_to_semantic.to_dict()
|
||||
if self.episodic_to_semantic
|
||||
else None
|
||||
),
|
||||
"episodic_to_procedural": (
|
||||
self.episodic_to_procedural.to_dict()
|
||||
if self.episodic_to_procedural
|
||||
else None
|
||||
),
|
||||
"pruning": self.pruning.to_dict() if self.pruning else None,
|
||||
"total_episodes_processed": self.total_episodes_processed,
|
||||
"total_facts_created": self.total_facts_created,
|
||||
"total_procedures_created": self.total_procedures_created,
|
||||
"total_pruned": self.total_pruned,
|
||||
"errors": self.errors,
|
||||
}
|
||||
|
||||
|
||||
class MemoryConsolidationService:
|
||||
"""
|
||||
Service for consolidating memories between tiers.
|
||||
|
||||
Responsibilities:
|
||||
- Transfer working memory to episodic at session end
|
||||
- Extract facts from episodes to semantic memory
|
||||
- Learn procedures from successful episode patterns
|
||||
- Prune old/low-value memories
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
config: ConsolidationConfig | None = None,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize consolidation service.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
config: Consolidation configuration
|
||||
embedding_generator: Optional embedding generator
|
||||
"""
|
||||
self._session = session
|
||||
self._config = config or ConsolidationConfig()
|
||||
self._embedding_generator = embedding_generator
|
||||
self._fact_extractor: FactExtractor = get_fact_extractor()
|
||||
|
||||
# Memory services (lazy initialized)
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
self._semantic: SemanticMemory | None = None
|
||||
self._procedural: ProceduralMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
async def _get_semantic(self) -> SemanticMemory:
|
||||
"""Get or create semantic memory service."""
|
||||
if self._semantic is None:
|
||||
self._semantic = await SemanticMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._semantic
|
||||
|
||||
async def _get_procedural(self) -> ProceduralMemory:
|
||||
"""Get or create procedural memory service."""
|
||||
if self._procedural is None:
|
||||
self._procedural = await ProceduralMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._procedural
|
||||
|
||||
# =========================================================================
|
||||
# Working -> Episodic Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_session(
|
||||
self,
|
||||
working_memory: WorkingMemory,
|
||||
project_id: UUID,
|
||||
session_id: str,
|
||||
task_type: str = "session_task",
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> SessionConsolidationResult:
|
||||
"""
|
||||
Consolidate a session's working memory to episodic memory.
|
||||
|
||||
Called at session end to transfer relevant session data
|
||||
into a persistent episode.
|
||||
|
||||
Args:
|
||||
working_memory: The session's working memory
|
||||
project_id: Project ID
|
||||
session_id: Session ID
|
||||
task_type: Type of task performed
|
||||
agent_instance_id: Optional agent instance
|
||||
agent_type_id: Optional agent type
|
||||
|
||||
Returns:
|
||||
SessionConsolidationResult with outcome details
|
||||
"""
|
||||
result = SessionConsolidationResult(session_id=session_id)
|
||||
|
||||
try:
|
||||
# Get task state
|
||||
task_state = await working_memory.get_task_state()
|
||||
|
||||
# Check if there's enough content to consolidate
|
||||
if not self._should_consolidate_session(task_state):
|
||||
logger.debug(
|
||||
f"Skipping consolidation for session {session_id}: insufficient content"
|
||||
)
|
||||
return result
|
||||
|
||||
# Gather scratchpad entries
|
||||
scratchpad = await working_memory.get_scratchpad()
|
||||
result.scratchpad_entries = len(scratchpad)
|
||||
|
||||
# Gather user variables
|
||||
all_data = await working_memory.get_all()
|
||||
result.variables_captured = len(all_data)
|
||||
|
||||
# Determine outcome
|
||||
outcome = self._determine_session_outcome(task_state)
|
||||
|
||||
# Build actions from scratchpad and variables
|
||||
actions = self._build_actions_from_session(scratchpad, all_data, task_state)
|
||||
|
||||
# Build context summary
|
||||
context_summary = self._build_context_summary(task_state, all_data)
|
||||
|
||||
# Extract lessons learned
|
||||
lessons = self._extract_session_lessons(task_state, outcome)
|
||||
|
||||
# Calculate importance
|
||||
importance = self._calculate_session_importance(
|
||||
task_state, outcome, actions
|
||||
)
|
||||
|
||||
# Create episode
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
task_description=task_state.description
|
||||
if task_state
|
||||
else "Session task",
|
||||
actions=actions,
|
||||
context_summary=context_summary,
|
||||
outcome=outcome,
|
||||
outcome_details=task_state.status if task_state else "",
|
||||
duration_seconds=self._calculate_duration(task_state),
|
||||
tokens_used=0, # Would need to track this in working memory
|
||||
lessons_learned=lessons,
|
||||
importance_score=importance,
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
|
||||
episodic = await self._get_episodic()
|
||||
episode = await episodic.record_episode(episode_data)
|
||||
|
||||
result.episode_created = True
|
||||
result.episode_id = episode.id
|
||||
|
||||
logger.info(
|
||||
f"Consolidated session {session_id} to episode {episode.id} "
|
||||
f"({len(actions)} actions, outcome={outcome.value})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result.error = str(e)
|
||||
logger.exception(f"Failed to consolidate session {session_id}")
|
||||
|
||||
return result
|
||||
|
||||
def _should_consolidate_session(self, task_state: TaskState | None) -> bool:
|
||||
"""Check if session has enough content to consolidate."""
|
||||
if task_state is None:
|
||||
return False
|
||||
|
||||
# Check minimum steps
|
||||
if task_state.current_step < self._config.min_steps_for_episode:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _determine_session_outcome(self, task_state: TaskState | None) -> Outcome:
|
||||
"""Determine outcome from task state."""
|
||||
if task_state is None:
|
||||
return Outcome.PARTIAL
|
||||
|
||||
status = task_state.status.lower() if task_state.status else ""
|
||||
progress = task_state.progress_percent
|
||||
|
||||
if "success" in status or "complete" in status or progress >= 100:
|
||||
return Outcome.SUCCESS
|
||||
if "fail" in status or "error" in status:
|
||||
return Outcome.FAILURE
|
||||
if progress >= 50:
|
||||
return Outcome.PARTIAL
|
||||
return Outcome.FAILURE
|
||||
|
||||
def _build_actions_from_session(
|
||||
self,
|
||||
scratchpad: list[str],
|
||||
variables: dict[str, Any],
|
||||
task_state: TaskState | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build action list from session data."""
|
||||
actions: list[dict[str, Any]] = []
|
||||
|
||||
# Add scratchpad entries as actions
|
||||
for i, entry in enumerate(scratchpad):
|
||||
actions.append(
|
||||
{
|
||||
"step": i + 1,
|
||||
"type": "reasoning",
|
||||
"content": entry[:500], # Truncate long entries
|
||||
}
|
||||
)
|
||||
|
||||
# Add final state
|
||||
if task_state:
|
||||
actions.append(
|
||||
{
|
||||
"step": len(scratchpad) + 1,
|
||||
"type": "final_state",
|
||||
"current_step": task_state.current_step,
|
||||
"total_steps": task_state.total_steps,
|
||||
"progress": task_state.progress_percent,
|
||||
"status": task_state.status,
|
||||
}
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
def _build_context_summary(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
variables: dict[str, Any],
|
||||
) -> str:
|
||||
"""Build context summary from session data."""
|
||||
parts = []
|
||||
|
||||
if task_state:
|
||||
parts.append(f"Task: {task_state.description}")
|
||||
parts.append(f"Progress: {task_state.progress_percent:.1f}%")
|
||||
parts.append(f"Steps: {task_state.current_step}/{task_state.total_steps}")
|
||||
|
||||
# Include key variables
|
||||
key_vars = {k: v for k, v in variables.items() if len(str(v)) < 100}
|
||||
if key_vars:
|
||||
var_str = ", ".join(f"{k}={v}" for k, v in list(key_vars.items())[:5])
|
||||
parts.append(f"Variables: {var_str}")
|
||||
|
||||
return "; ".join(parts) if parts else "Session completed"
|
||||
|
||||
def _extract_session_lessons(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
outcome: Outcome,
|
||||
) -> list[str]:
|
||||
"""Extract lessons from session."""
|
||||
lessons: list[str] = []
|
||||
|
||||
if task_state and task_state.status:
|
||||
if outcome == Outcome.FAILURE:
|
||||
lessons.append(
|
||||
f"Task failed at step {task_state.current_step}: {task_state.status}"
|
||||
)
|
||||
elif outcome == Outcome.SUCCESS:
|
||||
lessons.append(
|
||||
f"Successfully completed in {task_state.current_step} steps"
|
||||
)
|
||||
|
||||
return lessons
|
||||
|
||||
def _calculate_session_importance(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
outcome: Outcome,
|
||||
actions: list[dict[str, Any]],
|
||||
) -> float:
|
||||
"""Calculate importance score for session."""
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Failures are important to learn from
|
||||
if outcome == Outcome.FAILURE:
|
||||
score += 0.3
|
||||
|
||||
# Many steps means complex task
|
||||
if task_state and task_state.total_steps >= 5:
|
||||
score += 0.1
|
||||
|
||||
# Many actions means detailed reasoning
|
||||
if len(actions) >= 5:
|
||||
score += 0.1
|
||||
|
||||
return min(1.0, score)
|
||||
|
||||
def _calculate_duration(self, task_state: TaskState | None) -> float:
|
||||
"""Calculate session duration."""
|
||||
if task_state is None:
|
||||
return 0.0
|
||||
|
||||
if task_state.started_at and task_state.updated_at:
|
||||
delta = task_state.updated_at - task_state.started_at
|
||||
return delta.total_seconds()
|
||||
|
||||
return 0.0
|
||||
|
||||
# =========================================================================
|
||||
# Episodic -> Semantic Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_episodes_to_facts(
|
||||
self,
|
||||
project_id: UUID,
|
||||
since: datetime | None = None,
|
||||
limit: int | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Extract facts from episodic memories to semantic memory.
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
since: Only process episodes since this time
|
||||
limit: Maximum episodes to process
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with extraction statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
semantic = await self._get_semantic()
|
||||
|
||||
# Get episodes to process
|
||||
since_time = since or datetime.now(UTC) - timedelta(days=1)
|
||||
episodes = await episodic.get_recent(
|
||||
project_id,
|
||||
limit=limit or self._config.batch_size,
|
||||
since=since_time,
|
||||
)
|
||||
|
||||
for episode in episodes:
|
||||
result.items_processed += 1
|
||||
|
||||
try:
|
||||
# Extract facts using the extractor
|
||||
extracted_facts = self._fact_extractor.extract_from_episode(episode)
|
||||
|
||||
for extracted_fact in extracted_facts:
|
||||
if (
|
||||
extracted_fact.confidence
|
||||
< self._config.min_confidence_for_fact
|
||||
):
|
||||
result.items_skipped += 1
|
||||
continue
|
||||
|
||||
# Create fact (store_fact handles deduplication/reinforcement)
|
||||
fact_create = extracted_fact.to_fact_create(
|
||||
project_id=project_id,
|
||||
source_episode_ids=[episode.id],
|
||||
)
|
||||
|
||||
# store_fact automatically reinforces if fact already exists
|
||||
fact = await semantic.store_fact(fact_create)
|
||||
|
||||
# Check if this was a new fact or reinforced existing
|
||||
if fact.reinforcement_count == 1:
|
||||
result.items_created += 1
|
||||
else:
|
||||
result.items_updated += 1
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Episode {episode.id}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to extract facts from episode {episode.id}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Consolidation failed: {e}")
|
||||
logger.exception("Failed episodic -> semantic consolidation")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episodic -> Semantic consolidation: "
|
||||
f"{result.items_processed} processed, "
|
||||
f"{result.items_created} created, "
|
||||
f"{result.items_updated} reinforced"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
# =========================================================================
|
||||
# Episodic -> Procedural Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_episodes_to_procedures(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Learn procedures from patterns in episodic memories.
|
||||
|
||||
Identifies recurring successful patterns and creates/updates
|
||||
procedures to capture them.
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
agent_type_id: Optional filter by agent type
|
||||
since: Only process episodes since this time
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with procedure statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="procedural",
|
||||
)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
procedural = await self._get_procedural()
|
||||
|
||||
# Get successful episodes
|
||||
since_time = since or datetime.now(UTC) - timedelta(days=7)
|
||||
episodes = await episodic.get_by_outcome(
|
||||
project_id,
|
||||
outcome=Outcome.SUCCESS,
|
||||
limit=self._config.batch_size,
|
||||
agent_instance_id=None, # Get all agent instances
|
||||
)
|
||||
|
||||
# Group by task type
|
||||
task_groups: dict[str, list[Episode]] = {}
|
||||
for episode in episodes:
|
||||
if episode.occurred_at >= since_time:
|
||||
if episode.task_type not in task_groups:
|
||||
task_groups[episode.task_type] = []
|
||||
task_groups[episode.task_type].append(episode)
|
||||
|
||||
result.items_processed = len(episodes)
|
||||
|
||||
# Process each task type group
|
||||
for task_type, group in task_groups.items():
|
||||
if len(group) < self._config.min_episodes_for_procedure:
|
||||
result.items_skipped += len(group)
|
||||
continue
|
||||
|
||||
try:
|
||||
procedure_result = await self._learn_procedure_from_episodes(
|
||||
procedural,
|
||||
project_id,
|
||||
agent_type_id,
|
||||
task_type,
|
||||
group,
|
||||
)
|
||||
|
||||
if procedure_result == "created":
|
||||
result.items_created += 1
|
||||
elif procedure_result == "updated":
|
||||
result.items_updated += 1
|
||||
else:
|
||||
result.items_skipped += 1
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Task type '{task_type}': {e}")
|
||||
logger.warning(f"Failed to learn procedure for '{task_type}': {e}")
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Consolidation failed: {e}")
|
||||
logger.exception("Failed episodic -> procedural consolidation")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episodic -> Procedural consolidation: "
|
||||
f"{result.items_processed} processed, "
|
||||
f"{result.items_created} created, "
|
||||
f"{result.items_updated} updated"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _learn_procedure_from_episodes(
|
||||
self,
|
||||
procedural: ProceduralMemory,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None,
|
||||
task_type: str,
|
||||
episodes: list[Episode],
|
||||
) -> str:
|
||||
"""Learn or update a procedure from a set of episodes."""
|
||||
# Calculate success rate for this pattern
|
||||
success_count = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
|
||||
total_count = len(episodes)
|
||||
success_rate = success_count / total_count if total_count > 0 else 0
|
||||
|
||||
if success_rate < self._config.min_success_rate_for_procedure:
|
||||
return "skipped"
|
||||
|
||||
# Extract common steps from episodes
|
||||
steps = self._extract_common_steps(episodes)
|
||||
|
||||
if len(steps) < self._config.min_steps_for_procedure:
|
||||
return "skipped"
|
||||
|
||||
# Check for existing procedure
|
||||
matching = await procedural.find_matching(
|
||||
context=task_type,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=1,
|
||||
)
|
||||
|
||||
if matching:
|
||||
# Update existing procedure with new success
|
||||
await procedural.record_outcome(
|
||||
matching[0].id,
|
||||
success=True,
|
||||
)
|
||||
return "updated"
|
||||
else:
|
||||
# Create new procedure
|
||||
# Note: success_count starts at 1 in record_procedure
|
||||
procedure_data = ProcedureCreate(
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
name=f"Procedure for {task_type}",
|
||||
trigger_pattern=task_type,
|
||||
steps=steps,
|
||||
)
|
||||
await procedural.record_procedure(procedure_data)
|
||||
return "created"
|
||||
|
||||
def _extract_common_steps(self, episodes: list[Episode]) -> list[dict[str, Any]]:
|
||||
"""Extract common action steps from multiple episodes."""
|
||||
# Simple heuristic: take the steps from the most successful episode
|
||||
# with the most detailed actions
|
||||
|
||||
best_episode = max(
|
||||
episodes,
|
||||
key=lambda e: (
|
||||
e.outcome == Outcome.SUCCESS,
|
||||
len(e.actions),
|
||||
e.importance_score,
|
||||
),
|
||||
)
|
||||
|
||||
steps: list[dict[str, Any]] = []
|
||||
for i, action in enumerate(best_episode.actions):
|
||||
step = {
|
||||
"order": i + 1,
|
||||
"action": action.get("type", "action"),
|
||||
"description": action.get("content", str(action))[:500],
|
||||
"parameters": action,
|
||||
}
|
||||
steps.append(step)
|
||||
|
||||
return steps
|
||||
|
||||
# =========================================================================
|
||||
# Memory Pruning
|
||||
# =========================================================================
|
||||
|
||||
async def prune_old_episodes(
|
||||
self,
|
||||
project_id: UUID,
|
||||
max_age_days: int | None = None,
|
||||
min_importance: float | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Prune old, low-value episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to prune
|
||||
max_age_days: Maximum age in days (default from config)
|
||||
min_importance: Minimum importance to keep (default from config)
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with pruning statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="pruned",
|
||||
)
|
||||
|
||||
max_age = max_age_days or self._config.max_episode_age_days
|
||||
min_imp = min_importance or self._config.min_importance_to_keep
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=max_age)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
# Get old episodes
|
||||
# Note: In production, this would use a more efficient query
|
||||
all_episodes = await episodic.get_recent(
|
||||
project_id,
|
||||
limit=self._config.batch_size * 10,
|
||||
since=cutoff_date - timedelta(days=365), # Search past year
|
||||
)
|
||||
|
||||
for episode in all_episodes:
|
||||
result.items_processed += 1
|
||||
|
||||
# Check if should be pruned
|
||||
if not self._should_prune_episode(episode, cutoff_date, min_imp):
|
||||
result.items_skipped += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
deleted = await episodic.delete(episode.id)
|
||||
if deleted:
|
||||
result.items_pruned += 1
|
||||
else:
|
||||
result.items_skipped += 1
|
||||
except Exception as e:
|
||||
result.errors.append(f"Episode {episode.id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Pruning failed: {e}")
|
||||
logger.exception("Failed episode pruning")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episode pruning: {result.items_processed} processed, "
|
||||
f"{result.items_pruned} pruned"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _should_prune_episode(
|
||||
self,
|
||||
episode: Episode,
|
||||
cutoff_date: datetime,
|
||||
min_importance: float,
|
||||
) -> bool:
|
||||
"""Determine if an episode should be pruned."""
|
||||
# Keep recent episodes
|
||||
if episode.occurred_at >= cutoff_date:
|
||||
return False
|
||||
|
||||
# Keep failures if configured
|
||||
if self._config.keep_all_failures and episode.outcome == Outcome.FAILURE:
|
||||
return False
|
||||
|
||||
# Keep episodes with lessons if configured
|
||||
if self._config.keep_all_with_lessons and episode.lessons_learned:
|
||||
return False
|
||||
|
||||
# Keep high-importance episodes
|
||||
if episode.importance_score >= min_importance:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# =========================================================================
|
||||
# Nightly Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def run_nightly_consolidation(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> NightlyConsolidationResult:
|
||||
"""
|
||||
Run full nightly consolidation workflow.
|
||||
|
||||
This includes:
|
||||
1. Extract facts from recent episodes
|
||||
2. Learn procedures from successful patterns
|
||||
3. Prune old, low-value memories
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
NightlyConsolidationResult with all outcomes
|
||||
"""
|
||||
result = NightlyConsolidationResult(started_at=datetime.now(UTC))
|
||||
|
||||
logger.info(f"Starting nightly consolidation for project {project_id}")
|
||||
|
||||
try:
|
||||
# Step 1: Episodic -> Semantic (last 24 hours)
|
||||
since_yesterday = datetime.now(UTC) - timedelta(days=1)
|
||||
result.episodic_to_semantic = await self.consolidate_episodes_to_facts(
|
||||
project_id=project_id,
|
||||
since=since_yesterday,
|
||||
)
|
||||
result.total_facts_created = result.episodic_to_semantic.items_created
|
||||
|
||||
# Step 2: Episodic -> Procedural (last 7 days)
|
||||
since_week = datetime.now(UTC) - timedelta(days=7)
|
||||
result.episodic_to_procedural = (
|
||||
await self.consolidate_episodes_to_procedures(
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
since=since_week,
|
||||
)
|
||||
)
|
||||
result.total_procedures_created = (
|
||||
result.episodic_to_procedural.items_created
|
||||
)
|
||||
|
||||
# Step 3: Prune old memories
|
||||
result.pruning = await self.prune_old_episodes(project_id=project_id)
|
||||
result.total_pruned = result.pruning.items_pruned
|
||||
|
||||
# Calculate totals
|
||||
result.total_episodes_processed = (
|
||||
result.episodic_to_semantic.items_processed
|
||||
if result.episodic_to_semantic
|
||||
else 0
|
||||
) + (
|
||||
result.episodic_to_procedural.items_processed
|
||||
if result.episodic_to_procedural
|
||||
else 0
|
||||
)
|
||||
|
||||
# Collect all errors
|
||||
if result.episodic_to_semantic and result.episodic_to_semantic.errors:
|
||||
result.errors.extend(result.episodic_to_semantic.errors)
|
||||
if result.episodic_to_procedural and result.episodic_to_procedural.errors:
|
||||
result.errors.extend(result.episodic_to_procedural.errors)
|
||||
if result.pruning and result.pruning.errors:
|
||||
result.errors.extend(result.pruning.errors)
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Nightly consolidation failed: {e}")
|
||||
logger.exception("Nightly consolidation failed")
|
||||
|
||||
result.completed_at = datetime.now(UTC)
|
||||
|
||||
duration = (result.completed_at - result.started_at).total_seconds()
|
||||
logger.info(
|
||||
f"Nightly consolidation completed in {duration:.1f}s: "
|
||||
f"{result.total_facts_created} facts, "
|
||||
f"{result.total_procedures_created} procedures, "
|
||||
f"{result.total_pruned} pruned"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Factory function - no singleton to avoid stale session issues
|
||||
async def get_consolidation_service(
|
||||
session: AsyncSession,
|
||||
config: ConsolidationConfig | None = None,
|
||||
) -> MemoryConsolidationService:
|
||||
"""
|
||||
Create a memory consolidation service for the given session.
|
||||
|
||||
Note: This creates a new instance each time to avoid stale session issues.
|
||||
The service is lightweight and safe to recreate per-request.
|
||||
|
||||
Args:
|
||||
session: Database session (must be active)
|
||||
config: Optional configuration
|
||||
|
||||
Returns:
|
||||
MemoryConsolidationService instance
|
||||
"""
|
||||
return MemoryConsolidationService(session=session, config=config)
|
||||
17
backend/app/services/memory/episodic/__init__.py
Normal file
17
backend/app/services/memory/episodic/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# app/services/memory/episodic/__init__.py
|
||||
"""
|
||||
Episodic Memory Package.
|
||||
|
||||
Provides experiential memory storage and retrieval for agent learning.
|
||||
"""
|
||||
|
||||
from .memory import EpisodicMemory
|
||||
from .recorder import EpisodeRecorder
|
||||
from .retrieval import EpisodeRetriever, RetrievalStrategy
|
||||
|
||||
__all__ = [
|
||||
"EpisodeRecorder",
|
||||
"EpisodeRetriever",
|
||||
"EpisodicMemory",
|
||||
"RetrievalStrategy",
|
||||
]
|
||||
490
backend/app/services/memory/episodic/memory.py
Normal file
490
backend/app/services/memory/episodic/memory.py
Normal file
@@ -0,0 +1,490 @@
|
||||
# app/services/memory/episodic/memory.py
|
||||
"""
|
||||
Episodic Memory Implementation.
|
||||
|
||||
Provides experiential memory storage and retrieval for agent learning.
|
||||
Combines episode recording and retrieval into a unified interface.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.memory.types import Episode, EpisodeCreate, Outcome, RetrievalResult
|
||||
|
||||
from .recorder import EpisodeRecorder
|
||||
from .retrieval import EpisodeRetriever, RetrievalStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodicMemory:
|
||||
"""
|
||||
Episodic Memory Service.
|
||||
|
||||
Provides experiential memory for agent learning:
|
||||
- Record task completions with context
|
||||
- Store failures with error context
|
||||
- Retrieve by semantic similarity
|
||||
- Retrieve by recency, outcome, task type
|
||||
- Track importance scores
|
||||
- Extract lessons learned
|
||||
|
||||
Performance target: <100ms P95 for retrieval
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize episodic memory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic search
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._recorder = EpisodeRecorder(session, embedding_generator)
|
||||
self._retriever = EpisodeRetriever(session, embedding_generator)
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> "EpisodicMemory":
|
||||
"""
|
||||
Factory method to create EpisodicMemory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
|
||||
Returns:
|
||||
Configured EpisodicMemory instance
|
||||
"""
|
||||
return cls(session=session, embedding_generator=embedding_generator)
|
||||
|
||||
# =========================================================================
|
||||
# Recording Operations
|
||||
# =========================================================================
|
||||
|
||||
async def record_episode(self, episode: EpisodeCreate) -> Episode:
|
||||
"""
|
||||
Record a new episode.
|
||||
|
||||
Args:
|
||||
episode: Episode data to record
|
||||
|
||||
Returns:
|
||||
The created episode with assigned ID
|
||||
"""
|
||||
return await self._recorder.record(episode)
|
||||
|
||||
async def record_success(
|
||||
self,
|
||||
project_id: UUID,
|
||||
session_id: str,
|
||||
task_type: str,
|
||||
task_description: str,
|
||||
actions: list[dict[str, Any]],
|
||||
context_summary: str,
|
||||
outcome_details: str = "",
|
||||
duration_seconds: float = 0.0,
|
||||
tokens_used: int = 0,
|
||||
lessons_learned: list[str] | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> Episode:
|
||||
"""
|
||||
Convenience method to record a successful episode.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
session_id: Session ID
|
||||
task_type: Type of task
|
||||
task_description: Task description
|
||||
actions: Actions taken
|
||||
context_summary: Context summary
|
||||
outcome_details: Optional outcome details
|
||||
duration_seconds: Task duration
|
||||
tokens_used: Tokens consumed
|
||||
lessons_learned: Optional lessons
|
||||
agent_instance_id: Optional agent instance
|
||||
agent_type_id: Optional agent type
|
||||
|
||||
Returns:
|
||||
The created episode
|
||||
"""
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
task_description=task_description,
|
||||
actions=actions,
|
||||
context_summary=context_summary,
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details=outcome_details,
|
||||
duration_seconds=duration_seconds,
|
||||
tokens_used=tokens_used,
|
||||
lessons_learned=lessons_learned or [],
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
return await self.record_episode(episode_data)
|
||||
|
||||
async def record_failure(
|
||||
self,
|
||||
project_id: UUID,
|
||||
session_id: str,
|
||||
task_type: str,
|
||||
task_description: str,
|
||||
actions: list[dict[str, Any]],
|
||||
context_summary: str,
|
||||
error_details: str,
|
||||
duration_seconds: float = 0.0,
|
||||
tokens_used: int = 0,
|
||||
lessons_learned: list[str] | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> Episode:
|
||||
"""
|
||||
Convenience method to record a failed episode.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
session_id: Session ID
|
||||
task_type: Type of task
|
||||
task_description: Task description
|
||||
actions: Actions taken before failure
|
||||
context_summary: Context summary
|
||||
error_details: Error details
|
||||
duration_seconds: Task duration
|
||||
tokens_used: Tokens consumed
|
||||
lessons_learned: Optional lessons from failure
|
||||
agent_instance_id: Optional agent instance
|
||||
agent_type_id: Optional agent type
|
||||
|
||||
Returns:
|
||||
The created episode
|
||||
"""
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
task_description=task_description,
|
||||
actions=actions,
|
||||
context_summary=context_summary,
|
||||
outcome=Outcome.FAILURE,
|
||||
outcome_details=error_details,
|
||||
duration_seconds=duration_seconds,
|
||||
tokens_used=tokens_used,
|
||||
lessons_learned=lessons_learned or [],
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
return await self.record_episode(episode_data)
|
||||
|
||||
# =========================================================================
|
||||
# Retrieval Operations
|
||||
# =========================================================================
|
||||
|
||||
async def search_similar(
|
||||
self,
|
||||
project_id: UUID,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Search for semantically similar episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
query: Search query
|
||||
limit: Maximum results
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of similar episodes
|
||||
"""
|
||||
result = await self._retriever.search_similar(
|
||||
project_id, query, limit, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def get_recent(
|
||||
self,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
since: datetime | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get recent episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
limit: Maximum results
|
||||
since: Optional time filter
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of recent episodes
|
||||
"""
|
||||
result = await self._retriever.get_recent(
|
||||
project_id, limit, since, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def get_by_outcome(
|
||||
self,
|
||||
project_id: UUID,
|
||||
outcome: Outcome,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get episodes by outcome.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
outcome: Outcome to filter by
|
||||
limit: Maximum results
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of episodes with specified outcome
|
||||
"""
|
||||
result = await self._retriever.get_by_outcome(
|
||||
project_id, outcome, limit, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def get_by_task_type(
|
||||
self,
|
||||
project_id: UUID,
|
||||
task_type: str,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get episodes by task type.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
task_type: Task type to filter by
|
||||
limit: Maximum results
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of episodes with specified task type
|
||||
"""
|
||||
result = await self._retriever.get_by_task_type(
|
||||
project_id, task_type, limit, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def get_important(
|
||||
self,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
min_importance: float = 0.7,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get high-importance episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
limit: Maximum results
|
||||
min_importance: Minimum importance score
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of important episodes
|
||||
"""
|
||||
result = await self._retriever.get_important(
|
||||
project_id, limit, min_importance, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
project_id: UUID,
|
||||
strategy: RetrievalStrategy = RetrievalStrategy.RECENCY,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""
|
||||
Retrieve episodes with full result metadata.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
strategy: Retrieval strategy
|
||||
limit: Maximum results
|
||||
**kwargs: Strategy-specific parameters
|
||||
|
||||
Returns:
|
||||
RetrievalResult with episodes and metadata
|
||||
"""
|
||||
return await self._retriever.retrieve(project_id, strategy, limit, **kwargs)
|
||||
|
||||
# =========================================================================
|
||||
# Modification Operations
|
||||
# =========================================================================
|
||||
|
||||
async def get_by_id(self, episode_id: UUID) -> Episode | None:
|
||||
"""Get an episode by ID."""
|
||||
return await self._recorder.get_by_id(episode_id)
|
||||
|
||||
async def update_importance(
|
||||
self,
|
||||
episode_id: UUID,
|
||||
importance_score: float,
|
||||
) -> Episode | None:
|
||||
"""
|
||||
Update an episode's importance score.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to update
|
||||
importance_score: New importance score (0.0 to 1.0)
|
||||
|
||||
Returns:
|
||||
Updated episode or None if not found
|
||||
"""
|
||||
return await self._recorder.update_importance(episode_id, importance_score)
|
||||
|
||||
async def add_lessons(
|
||||
self,
|
||||
episode_id: UUID,
|
||||
lessons: list[str],
|
||||
) -> Episode | None:
|
||||
"""
|
||||
Add lessons learned to an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to update
|
||||
lessons: Lessons to add
|
||||
|
||||
Returns:
|
||||
Updated episode or None if not found
|
||||
"""
|
||||
return await self._recorder.add_lessons(episode_id, lessons)
|
||||
|
||||
async def delete(self, episode_id: UUID) -> bool:
|
||||
"""
|
||||
Delete an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to delete
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
return await self._recorder.delete(episode_id)
|
||||
|
||||
# =========================================================================
|
||||
# Summarization
|
||||
# =========================================================================
|
||||
|
||||
async def summarize_episodes(
|
||||
self,
|
||||
episode_ids: list[UUID],
|
||||
) -> str:
|
||||
"""
|
||||
Summarize multiple episodes into a consolidated view.
|
||||
|
||||
Args:
|
||||
episode_ids: Episodes to summarize
|
||||
|
||||
Returns:
|
||||
Summary text
|
||||
"""
|
||||
if not episode_ids:
|
||||
return "No episodes to summarize."
|
||||
|
||||
episodes: list[Episode] = []
|
||||
for episode_id in episode_ids:
|
||||
episode = await self.get_by_id(episode_id)
|
||||
if episode:
|
||||
episodes.append(episode)
|
||||
|
||||
if not episodes:
|
||||
return "No episodes found."
|
||||
|
||||
# Build summary
|
||||
lines = [f"Summary of {len(episodes)} episodes:", ""]
|
||||
|
||||
# Outcome breakdown
|
||||
success = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
|
||||
failure = sum(1 for e in episodes if e.outcome == Outcome.FAILURE)
|
||||
partial = sum(1 for e in episodes if e.outcome == Outcome.PARTIAL)
|
||||
lines.append(
|
||||
f"Outcomes: {success} success, {failure} failure, {partial} partial"
|
||||
)
|
||||
|
||||
# Task types
|
||||
task_types = {e.task_type for e in episodes}
|
||||
lines.append(f"Task types: {', '.join(sorted(task_types))}")
|
||||
|
||||
# Aggregate lessons
|
||||
all_lessons: list[str] = []
|
||||
for e in episodes:
|
||||
all_lessons.extend(e.lessons_learned)
|
||||
|
||||
if all_lessons:
|
||||
lines.append("")
|
||||
lines.append("Key lessons learned:")
|
||||
# Deduplicate lessons
|
||||
unique_lessons = list(dict.fromkeys(all_lessons))
|
||||
for lesson in unique_lessons[:10]: # Top 10
|
||||
lines.append(f" - {lesson}")
|
||||
|
||||
# Duration and tokens
|
||||
total_duration = sum(e.duration_seconds for e in episodes)
|
||||
total_tokens = sum(e.tokens_used for e in episodes)
|
||||
lines.append("")
|
||||
lines.append(f"Total duration: {total_duration:.1f}s")
|
||||
lines.append(f"Total tokens: {total_tokens:,}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
async def get_stats(self, project_id: UUID) -> dict[str, Any]:
|
||||
"""
|
||||
Get episode statistics for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project to get stats for
|
||||
|
||||
Returns:
|
||||
Dictionary with episode statistics
|
||||
"""
|
||||
return await self._recorder.get_stats(project_id)
|
||||
|
||||
async def count(
|
||||
self,
|
||||
project_id: UUID,
|
||||
since: datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Count episodes for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project to count for
|
||||
since: Optional time filter
|
||||
|
||||
Returns:
|
||||
Number of episodes
|
||||
"""
|
||||
return await self._recorder.count_by_project(project_id, since)
|
||||
357
backend/app/services/memory/episodic/recorder.py
Normal file
357
backend/app/services/memory/episodic/recorder.py
Normal file
@@ -0,0 +1,357 @@
|
||||
# app/services/memory/episodic/recorder.py
|
||||
"""
|
||||
Episode Recording.
|
||||
|
||||
Handles the creation and storage of episodic memories
|
||||
during agent task execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.enums import EpisodeOutcome
|
||||
from app.models.memory.episode import Episode as EpisodeModel
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.types import Episode, EpisodeCreate, Outcome
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _outcome_to_db(outcome: Outcome) -> EpisodeOutcome:
|
||||
"""Convert service Outcome to database EpisodeOutcome."""
|
||||
return EpisodeOutcome(outcome.value)
|
||||
|
||||
|
||||
def _db_to_outcome(db_outcome: EpisodeOutcome) -> Outcome:
|
||||
"""Convert database EpisodeOutcome to service Outcome."""
|
||||
return Outcome(db_outcome.value)
|
||||
|
||||
|
||||
def _model_to_episode(model: EpisodeModel) -> Episode:
|
||||
"""Convert SQLAlchemy model to Episode dataclass."""
|
||||
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
|
||||
# they return actual values. We use type: ignore to handle this mismatch.
|
||||
return Episode(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
agent_instance_id=model.agent_instance_id, # type: ignore[arg-type]
|
||||
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
|
||||
session_id=model.session_id, # type: ignore[arg-type]
|
||||
task_type=model.task_type, # type: ignore[arg-type]
|
||||
task_description=model.task_description, # type: ignore[arg-type]
|
||||
actions=model.actions or [], # type: ignore[arg-type]
|
||||
context_summary=model.context_summary, # type: ignore[arg-type]
|
||||
outcome=_db_to_outcome(model.outcome), # type: ignore[arg-type]
|
||||
outcome_details=model.outcome_details or "", # type: ignore[arg-type]
|
||||
duration_seconds=model.duration_seconds, # type: ignore[arg-type]
|
||||
tokens_used=model.tokens_used, # type: ignore[arg-type]
|
||||
lessons_learned=model.lessons_learned or [], # type: ignore[arg-type]
|
||||
importance_score=model.importance_score, # type: ignore[arg-type]
|
||||
embedding=None, # Don't expose raw embedding
|
||||
occurred_at=model.occurred_at, # type: ignore[arg-type]
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class EpisodeRecorder:
|
||||
"""
|
||||
Records episodes to the database.
|
||||
|
||||
Handles episode creation, importance scoring,
|
||||
and lesson extraction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize recorder.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic indexing
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._settings = get_memory_settings()
|
||||
|
||||
async def record(self, episode_data: EpisodeCreate) -> Episode:
|
||||
"""
|
||||
Record a new episode.
|
||||
|
||||
Args:
|
||||
episode_data: Episode data to record
|
||||
|
||||
Returns:
|
||||
The created episode
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Calculate importance score if not provided
|
||||
importance = episode_data.importance_score
|
||||
if importance == 0.5: # Default value, calculate
|
||||
importance = self._calculate_importance(episode_data)
|
||||
|
||||
# Create the model
|
||||
model = EpisodeModel(
|
||||
id=uuid4(),
|
||||
project_id=episode_data.project_id,
|
||||
agent_instance_id=episode_data.agent_instance_id,
|
||||
agent_type_id=episode_data.agent_type_id,
|
||||
session_id=episode_data.session_id,
|
||||
task_type=episode_data.task_type,
|
||||
task_description=episode_data.task_description,
|
||||
actions=episode_data.actions,
|
||||
context_summary=episode_data.context_summary,
|
||||
outcome=_outcome_to_db(episode_data.outcome),
|
||||
outcome_details=episode_data.outcome_details,
|
||||
duration_seconds=episode_data.duration_seconds,
|
||||
tokens_used=episode_data.tokens_used,
|
||||
lessons_learned=episode_data.lessons_learned,
|
||||
importance_score=importance,
|
||||
occurred_at=now,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
# Generate embedding if generator available
|
||||
if self._embedding_generator is not None:
|
||||
try:
|
||||
text_for_embedding = self._create_embedding_text(episode_data)
|
||||
embedding = await self._embedding_generator.generate(text_for_embedding)
|
||||
model.embedding = embedding
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate embedding: {e}")
|
||||
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
|
||||
logger.debug(f"Recorded episode {model.id} for task {model.task_type}")
|
||||
return _model_to_episode(model)
|
||||
|
||||
def _calculate_importance(self, episode_data: EpisodeCreate) -> float:
|
||||
"""
|
||||
Calculate importance score for an episode.
|
||||
|
||||
Factors:
|
||||
- Outcome: Failures are more important to learn from
|
||||
- Duration: Longer tasks may be more significant
|
||||
- Token usage: Higher usage may indicate complexity
|
||||
- Lessons learned: Episodes with lessons are more valuable
|
||||
"""
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Outcome factor
|
||||
if episode_data.outcome == Outcome.FAILURE:
|
||||
score += 0.2 # Failures are important for learning
|
||||
elif episode_data.outcome == Outcome.PARTIAL:
|
||||
score += 0.1
|
||||
# Success is default, no adjustment
|
||||
|
||||
# Lessons learned factor
|
||||
if episode_data.lessons_learned:
|
||||
score += min(0.15, len(episode_data.lessons_learned) * 0.05)
|
||||
|
||||
# Duration factor (longer tasks may be more significant)
|
||||
if episode_data.duration_seconds > 60:
|
||||
score += 0.05
|
||||
if episode_data.duration_seconds > 300:
|
||||
score += 0.05
|
||||
|
||||
# Token usage factor (complex tasks)
|
||||
if episode_data.tokens_used > 1000:
|
||||
score += 0.05
|
||||
|
||||
# Clamp to valid range
|
||||
return min(1.0, max(0.0, score))
|
||||
|
||||
def _create_embedding_text(self, episode_data: EpisodeCreate) -> str:
|
||||
"""Create text representation for embedding generation."""
|
||||
parts = [
|
||||
f"Task: {episode_data.task_type}",
|
||||
f"Description: {episode_data.task_description}",
|
||||
f"Context: {episode_data.context_summary}",
|
||||
f"Outcome: {episode_data.outcome.value}",
|
||||
]
|
||||
|
||||
if episode_data.outcome_details:
|
||||
parts.append(f"Details: {episode_data.outcome_details}")
|
||||
|
||||
if episode_data.lessons_learned:
|
||||
parts.append(f"Lessons: {', '.join(episode_data.lessons_learned)}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
async def get_by_id(self, episode_id: UUID) -> Episode | None:
|
||||
"""Get an episode by ID."""
|
||||
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
return _model_to_episode(model)
|
||||
|
||||
async def update_importance(
|
||||
self,
|
||||
episode_id: UUID,
|
||||
importance_score: float,
|
||||
) -> Episode | None:
|
||||
"""
|
||||
Update the importance score of an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to update
|
||||
importance_score: New importance score (0.0 to 1.0)
|
||||
|
||||
Returns:
|
||||
Updated episode or None if not found
|
||||
"""
|
||||
# Validate score
|
||||
importance_score = min(1.0, max(0.0, importance_score))
|
||||
|
||||
stmt = (
|
||||
update(EpisodeModel)
|
||||
.where(EpisodeModel.id == episode_id)
|
||||
.values(
|
||||
importance_score=importance_score,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
.returning(EpisodeModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
await self._session.flush()
|
||||
return _model_to_episode(model)
|
||||
|
||||
async def add_lessons(
|
||||
self,
|
||||
episode_id: UUID,
|
||||
lessons: list[str],
|
||||
) -> Episode | None:
|
||||
"""
|
||||
Add lessons learned to an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to update
|
||||
lessons: New lessons to add
|
||||
|
||||
Returns:
|
||||
Updated episode or None if not found
|
||||
"""
|
||||
# Get current episode
|
||||
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
# Append lessons
|
||||
current_lessons: list[str] = model.lessons_learned or [] # type: ignore[assignment]
|
||||
updated_lessons = current_lessons + lessons
|
||||
|
||||
stmt = (
|
||||
update(EpisodeModel)
|
||||
.where(EpisodeModel.id == episode_id)
|
||||
.values(
|
||||
lessons_learned=updated_lessons,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
.returning(EpisodeModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
model = result.scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
|
||||
return _model_to_episode(model) if model else None
|
||||
|
||||
async def delete(self, episode_id: UUID) -> bool:
|
||||
"""
|
||||
Delete an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to delete
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return False
|
||||
|
||||
await self._session.delete(model)
|
||||
await self._session.flush()
|
||||
return True
|
||||
|
||||
async def count_by_project(
|
||||
self,
|
||||
project_id: UUID,
|
||||
since: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count episodes for a project."""
|
||||
query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
if since is not None:
|
||||
query = query.where(EpisodeModel.occurred_at >= since)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return len(list(result.scalars().all()))
|
||||
|
||||
async def get_stats(self, project_id: UUID) -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics for a project's episodes.
|
||||
|
||||
Returns:
|
||||
Dictionary with episode statistics
|
||||
"""
|
||||
query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
result = await self._session.execute(query)
|
||||
episodes = list(result.scalars().all())
|
||||
|
||||
if not episodes:
|
||||
return {
|
||||
"total_count": 0,
|
||||
"success_count": 0,
|
||||
"failure_count": 0,
|
||||
"partial_count": 0,
|
||||
"avg_importance": 0.0,
|
||||
"avg_duration": 0.0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
|
||||
success_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.SUCCESS)
|
||||
failure_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.FAILURE)
|
||||
partial_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.PARTIAL)
|
||||
|
||||
avg_importance = sum(e.importance_score for e in episodes) / len(episodes)
|
||||
avg_duration = sum(e.duration_seconds for e in episodes) / len(episodes)
|
||||
total_tokens = sum(e.tokens_used for e in episodes)
|
||||
|
||||
return {
|
||||
"total_count": len(episodes),
|
||||
"success_count": success_count,
|
||||
"failure_count": failure_count,
|
||||
"partial_count": partial_count,
|
||||
"avg_importance": avg_importance,
|
||||
"avg_duration": avg_duration,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
503
backend/app/services/memory/episodic/retrieval.py
Normal file
503
backend/app/services/memory/episodic/retrieval.py
Normal file
@@ -0,0 +1,503 @@
|
||||
# app/services/memory/episodic/retrieval.py
|
||||
"""
|
||||
Episode Retrieval Strategies.
|
||||
|
||||
Provides different retrieval strategies for finding relevant episodes:
|
||||
- Semantic similarity (vector search)
|
||||
- Recency-based
|
||||
- Outcome-based filtering
|
||||
- Importance-based ranking
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.enums import EpisodeOutcome
|
||||
from app.models.memory.episode import Episode as EpisodeModel
|
||||
from app.services.memory.types import Episode, Outcome, RetrievalResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetrievalStrategy(str, Enum):
|
||||
"""Retrieval strategy types."""
|
||||
|
||||
SEMANTIC = "semantic"
|
||||
RECENCY = "recency"
|
||||
OUTCOME = "outcome"
|
||||
IMPORTANCE = "importance"
|
||||
HYBRID = "hybrid"
|
||||
|
||||
|
||||
def _model_to_episode(model: EpisodeModel) -> Episode:
|
||||
"""Convert SQLAlchemy model to Episode dataclass."""
|
||||
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
|
||||
# they return actual values. We use type: ignore to handle this mismatch.
|
||||
return Episode(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
agent_instance_id=model.agent_instance_id, # type: ignore[arg-type]
|
||||
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
|
||||
session_id=model.session_id, # type: ignore[arg-type]
|
||||
task_type=model.task_type, # type: ignore[arg-type]
|
||||
task_description=model.task_description, # type: ignore[arg-type]
|
||||
actions=model.actions or [], # type: ignore[arg-type]
|
||||
context_summary=model.context_summary, # type: ignore[arg-type]
|
||||
outcome=Outcome(model.outcome.value),
|
||||
outcome_details=model.outcome_details or "", # type: ignore[arg-type]
|
||||
duration_seconds=model.duration_seconds, # type: ignore[arg-type]
|
||||
tokens_used=model.tokens_used, # type: ignore[arg-type]
|
||||
lessons_learned=model.lessons_learned or [], # type: ignore[arg-type]
|
||||
importance_score=model.importance_score, # type: ignore[arg-type]
|
||||
embedding=None, # Don't expose raw embedding
|
||||
occurred_at=model.occurred_at, # type: ignore[arg-type]
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
"""Abstract base class for episode retrieval strategies."""
|
||||
|
||||
@abstractmethod
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes based on the strategy."""
|
||||
...
|
||||
|
||||
|
||||
class RecencyRetriever(BaseRetriever):
|
||||
"""Retrieves episodes by recency (most recent first)."""
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
since: datetime | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve most recent episodes."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
query = (
|
||||
select(EpisodeModel)
|
||||
.where(EpisodeModel.project_id == project_id)
|
||||
.order_by(desc(EpisodeModel.occurred_at))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if since is not None:
|
||||
query = query.where(EpisodeModel.occurred_at >= since)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Get total count
|
||||
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
if since is not None:
|
||||
count_query = count_query.where(EpisodeModel.occurred_at >= since)
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = len(list(count_result.scalars().all()))
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_episode(m) for m in models],
|
||||
total_count=total_count,
|
||||
query="recency",
|
||||
retrieval_type=RetrievalStrategy.RECENCY.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"since": since.isoformat() if since else None},
|
||||
)
|
||||
|
||||
|
||||
class OutcomeRetriever(BaseRetriever):
|
||||
"""Retrieves episodes filtered by outcome."""
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
outcome: Outcome | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes by outcome."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
query = (
|
||||
select(EpisodeModel)
|
||||
.where(EpisodeModel.project_id == project_id)
|
||||
.order_by(desc(EpisodeModel.occurred_at))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if outcome is not None:
|
||||
db_outcome = EpisodeOutcome(outcome.value)
|
||||
query = query.where(EpisodeModel.outcome == db_outcome)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Get total count
|
||||
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
if outcome is not None:
|
||||
count_query = count_query.where(
|
||||
EpisodeModel.outcome == EpisodeOutcome(outcome.value)
|
||||
)
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = len(list(count_result.scalars().all()))
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_episode(m) for m in models],
|
||||
total_count=total_count,
|
||||
query=f"outcome:{outcome.value if outcome else 'all'}",
|
||||
retrieval_type=RetrievalStrategy.OUTCOME.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"outcome": outcome.value if outcome else None},
|
||||
)
|
||||
|
||||
|
||||
class TaskTypeRetriever(BaseRetriever):
|
||||
"""Retrieves episodes filtered by task type."""
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
task_type: str | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes by task type."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
query = (
|
||||
select(EpisodeModel)
|
||||
.where(EpisodeModel.project_id == project_id)
|
||||
.order_by(desc(EpisodeModel.occurred_at))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if task_type is not None:
|
||||
query = query.where(EpisodeModel.task_type == task_type)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Get total count
|
||||
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
if task_type is not None:
|
||||
count_query = count_query.where(EpisodeModel.task_type == task_type)
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = len(list(count_result.scalars().all()))
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_episode(m) for m in models],
|
||||
total_count=total_count,
|
||||
query=f"task_type:{task_type or 'all'}",
|
||||
retrieval_type="task_type",
|
||||
latency_ms=latency_ms,
|
||||
metadata={"task_type": task_type},
|
||||
)
|
||||
|
||||
|
||||
class ImportanceRetriever(BaseRetriever):
|
||||
"""Retrieves episodes ranked by importance score."""
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
min_importance: float = 0.0,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes by importance."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
query = (
|
||||
select(EpisodeModel)
|
||||
.where(
|
||||
and_(
|
||||
EpisodeModel.project_id == project_id,
|
||||
EpisodeModel.importance_score >= min_importance,
|
||||
)
|
||||
)
|
||||
.order_by(desc(EpisodeModel.importance_score))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Get total count
|
||||
count_query = select(EpisodeModel).where(
|
||||
and_(
|
||||
EpisodeModel.project_id == project_id,
|
||||
EpisodeModel.importance_score >= min_importance,
|
||||
)
|
||||
)
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = len(list(count_result.scalars().all()))
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_episode(m) for m in models],
|
||||
total_count=total_count,
|
||||
query=f"importance>={min_importance}",
|
||||
retrieval_type=RetrievalStrategy.IMPORTANCE.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"min_importance": min_importance},
|
||||
)
|
||||
|
||||
|
||||
class SemanticRetriever(BaseRetriever):
|
||||
"""Retrieves episodes by semantic similarity using vector search."""
|
||||
|
||||
def __init__(self, embedding_generator: Any | None = None) -> None:
|
||||
"""Initialize with optional embedding generator."""
|
||||
self._embedding_generator = embedding_generator
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
query_text: str | None = None,
|
||||
query_embedding: list[float] | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes by semantic similarity."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# If no embedding provided, fall back to recency
|
||||
if query_embedding is None and query_text is None:
|
||||
logger.warning(
|
||||
"No query provided for semantic search, falling back to recency"
|
||||
)
|
||||
recency = RecencyRetriever()
|
||||
fallback_result = await recency.retrieve(
|
||||
session, project_id, limit, agent_instance_id=agent_instance_id
|
||||
)
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
return RetrievalResult(
|
||||
items=fallback_result.items,
|
||||
total_count=fallback_result.total_count,
|
||||
query="no_query",
|
||||
retrieval_type=RetrievalStrategy.SEMANTIC.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"fallback": "recency", "reason": "no_query"},
|
||||
)
|
||||
|
||||
# Generate embedding if needed
|
||||
embedding = query_embedding
|
||||
if embedding is None and query_text is not None:
|
||||
if self._embedding_generator is not None:
|
||||
embedding = await self._embedding_generator.generate(query_text)
|
||||
else:
|
||||
logger.warning("No embedding generator, falling back to recency")
|
||||
recency = RecencyRetriever()
|
||||
fallback_result = await recency.retrieve(
|
||||
session, project_id, limit, agent_instance_id=agent_instance_id
|
||||
)
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
return RetrievalResult(
|
||||
items=fallback_result.items,
|
||||
total_count=fallback_result.total_count,
|
||||
query=query_text,
|
||||
retrieval_type=RetrievalStrategy.SEMANTIC.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={
|
||||
"fallback": "recency",
|
||||
"reason": "no_embedding_generator",
|
||||
},
|
||||
)
|
||||
|
||||
# For now, use recency if vector search not available
|
||||
# TODO: Implement proper pgvector similarity search when integrated
|
||||
logger.debug("Vector search not yet implemented, using recency fallback")
|
||||
recency = RecencyRetriever()
|
||||
result = await recency.retrieve(
|
||||
session, project_id, limit, agent_instance_id=agent_instance_id
|
||||
)
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=result.items,
|
||||
total_count=result.total_count,
|
||||
query=query_text or "embedding",
|
||||
retrieval_type=RetrievalStrategy.SEMANTIC.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"fallback": "recency"},
|
||||
)
|
||||
|
||||
|
||||
class EpisodeRetriever:
|
||||
"""
|
||||
Unified episode retrieval service.
|
||||
|
||||
Provides a single interface for all retrieval strategies.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""Initialize retriever with database session."""
|
||||
self._session = session
|
||||
self._retrievers: dict[RetrievalStrategy, BaseRetriever] = {
|
||||
RetrievalStrategy.RECENCY: RecencyRetriever(),
|
||||
RetrievalStrategy.OUTCOME: OutcomeRetriever(),
|
||||
RetrievalStrategy.IMPORTANCE: ImportanceRetriever(),
|
||||
RetrievalStrategy.SEMANTIC: SemanticRetriever(embedding_generator),
|
||||
}
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
project_id: UUID,
|
||||
strategy: RetrievalStrategy = RetrievalStrategy.RECENCY,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""
|
||||
Retrieve episodes using the specified strategy.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
strategy: Retrieval strategy to use
|
||||
limit: Maximum number of episodes to return
|
||||
**kwargs: Strategy-specific parameters
|
||||
|
||||
Returns:
|
||||
RetrievalResult containing matching episodes
|
||||
"""
|
||||
retriever = self._retrievers.get(strategy)
|
||||
if retriever is None:
|
||||
raise ValueError(f"Unknown retrieval strategy: {strategy}")
|
||||
|
||||
return await retriever.retrieve(self._session, project_id, limit, **kwargs)
|
||||
|
||||
async def get_recent(
|
||||
self,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
since: datetime | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Get recent episodes."""
|
||||
return await self.retrieve(
|
||||
project_id,
|
||||
RetrievalStrategy.RECENCY,
|
||||
limit,
|
||||
since=since,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
async def get_by_outcome(
|
||||
self,
|
||||
project_id: UUID,
|
||||
outcome: Outcome,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Get episodes by outcome."""
|
||||
return await self.retrieve(
|
||||
project_id,
|
||||
RetrievalStrategy.OUTCOME,
|
||||
limit,
|
||||
outcome=outcome,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
async def get_by_task_type(
|
||||
self,
|
||||
project_id: UUID,
|
||||
task_type: str,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Get episodes by task type."""
|
||||
retriever = TaskTypeRetriever()
|
||||
return await retriever.retrieve(
|
||||
self._session,
|
||||
project_id,
|
||||
limit,
|
||||
task_type=task_type,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
async def get_important(
|
||||
self,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
min_importance: float = 0.7,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Get high-importance episodes."""
|
||||
return await self.retrieve(
|
||||
project_id,
|
||||
RetrievalStrategy.IMPORTANCE,
|
||||
limit,
|
||||
min_importance=min_importance,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
async def search_similar(
|
||||
self,
|
||||
project_id: UUID,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Search for semantically similar episodes."""
|
||||
return await self.retrieve(
|
||||
project_id,
|
||||
RetrievalStrategy.SEMANTIC,
|
||||
limit,
|
||||
query_text=query,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
222
backend/app/services/memory/exceptions.py
Normal file
222
backend/app/services/memory/exceptions.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
Memory System Exceptions
|
||||
|
||||
Custom exception classes for the Agent Memory System.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class MemoryError(Exception):
|
||||
"""Base exception for all memory-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
memory_type: str | None = None,
|
||||
scope_type: str | None = None,
|
||||
scope_id: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.memory_type = memory_type
|
||||
self.scope_type = scope_type
|
||||
self.scope_id = scope_id
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class MemoryNotFoundError(MemoryError):
|
||||
"""Raised when a memory item is not found."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory not found",
|
||||
*,
|
||||
memory_id: UUID | str | None = None,
|
||||
key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.memory_id = memory_id
|
||||
self.key = key
|
||||
|
||||
|
||||
class MemoryCapacityError(MemoryError):
|
||||
"""Raised when memory capacity limits are exceeded."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory capacity exceeded",
|
||||
*,
|
||||
current_size: int = 0,
|
||||
max_size: int = 0,
|
||||
item_count: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.current_size = current_size
|
||||
self.max_size = max_size
|
||||
self.item_count = item_count
|
||||
|
||||
|
||||
class MemoryExpiredError(MemoryError):
|
||||
"""Raised when attempting to access expired memory."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory has expired",
|
||||
*,
|
||||
key: str | None = None,
|
||||
expired_at: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.key = key
|
||||
self.expired_at = expired_at
|
||||
|
||||
|
||||
class MemoryStorageError(MemoryError):
|
||||
"""Raised when memory storage operations fail."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory storage operation failed",
|
||||
*,
|
||||
operation: str | None = None,
|
||||
backend: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.operation = operation
|
||||
self.backend = backend
|
||||
|
||||
|
||||
class MemoryConnectionError(MemoryError):
|
||||
"""Raised when memory storage connection fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory connection failed",
|
||||
*,
|
||||
backend: str | None = None,
|
||||
host: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.backend = backend
|
||||
self.host = host
|
||||
|
||||
|
||||
class MemorySerializationError(MemoryError):
|
||||
"""Raised when memory serialization/deserialization fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory serialization failed",
|
||||
*,
|
||||
content_type: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.content_type = content_type
|
||||
|
||||
|
||||
class MemoryScopeError(MemoryError):
|
||||
"""Raised when memory scope operations fail."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory scope error",
|
||||
*,
|
||||
requested_scope: str | None = None,
|
||||
allowed_scopes: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.requested_scope = requested_scope
|
||||
self.allowed_scopes = allowed_scopes or []
|
||||
|
||||
|
||||
class MemoryConsolidationError(MemoryError):
|
||||
"""Raised when memory consolidation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory consolidation failed",
|
||||
*,
|
||||
source_type: str | None = None,
|
||||
target_type: str | None = None,
|
||||
items_processed: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.source_type = source_type
|
||||
self.target_type = target_type
|
||||
self.items_processed = items_processed
|
||||
|
||||
|
||||
class MemoryRetrievalError(MemoryError):
|
||||
"""Raised when memory retrieval fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory retrieval failed",
|
||||
*,
|
||||
query: str | None = None,
|
||||
retrieval_type: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.query = query
|
||||
self.retrieval_type = retrieval_type
|
||||
|
||||
|
||||
class EmbeddingError(MemoryError):
|
||||
"""Raised when embedding generation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Embedding generation failed",
|
||||
*,
|
||||
content_length: int = 0,
|
||||
model: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.content_length = content_length
|
||||
self.model = model
|
||||
|
||||
|
||||
class CheckpointError(MemoryError):
|
||||
"""Raised when checkpoint operations fail."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Checkpoint operation failed",
|
||||
*,
|
||||
checkpoint_id: str | None = None,
|
||||
operation: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.checkpoint_id = checkpoint_id
|
||||
self.operation = operation
|
||||
|
||||
|
||||
class MemoryConflictError(MemoryError):
|
||||
"""Raised when there's a conflict in memory (e.g., contradictory facts)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory conflict detected",
|
||||
*,
|
||||
conflicting_ids: list[str | UUID] | None = None,
|
||||
conflict_type: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.conflicting_ids = conflicting_ids or []
|
||||
self.conflict_type = conflict_type
|
||||
56
backend/app/services/memory/indexing/__init__.py
Normal file
56
backend/app/services/memory/indexing/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# app/services/memory/indexing/__init__.py
|
||||
"""
|
||||
Memory Indexing & Retrieval.
|
||||
|
||||
Provides vector embeddings and multiple index types for efficient memory search:
|
||||
- Vector index for semantic similarity
|
||||
- Temporal index for time-based queries
|
||||
- Entity index for entity lookups
|
||||
- Outcome index for success/failure filtering
|
||||
"""
|
||||
|
||||
from .index import (
|
||||
EntityIndex,
|
||||
EntityIndexEntry,
|
||||
IndexEntry,
|
||||
MemoryIndex,
|
||||
MemoryIndexer,
|
||||
OutcomeIndex,
|
||||
OutcomeIndexEntry,
|
||||
TemporalIndex,
|
||||
TemporalIndexEntry,
|
||||
VectorIndex,
|
||||
VectorIndexEntry,
|
||||
get_memory_indexer,
|
||||
)
|
||||
from .retrieval import (
|
||||
CacheEntry,
|
||||
RelevanceScorer,
|
||||
RetrievalCache,
|
||||
RetrievalEngine,
|
||||
RetrievalQuery,
|
||||
ScoredResult,
|
||||
get_retrieval_engine,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CacheEntry",
|
||||
"EntityIndex",
|
||||
"EntityIndexEntry",
|
||||
"IndexEntry",
|
||||
"MemoryIndex",
|
||||
"MemoryIndexer",
|
||||
"OutcomeIndex",
|
||||
"OutcomeIndexEntry",
|
||||
"RelevanceScorer",
|
||||
"RetrievalCache",
|
||||
"RetrievalEngine",
|
||||
"RetrievalQuery",
|
||||
"ScoredResult",
|
||||
"TemporalIndex",
|
||||
"TemporalIndexEntry",
|
||||
"VectorIndex",
|
||||
"VectorIndexEntry",
|
||||
"get_memory_indexer",
|
||||
"get_retrieval_engine",
|
||||
]
|
||||
858
backend/app/services/memory/indexing/index.py
Normal file
858
backend/app/services/memory/indexing/index.py
Normal file
@@ -0,0 +1,858 @@
|
||||
# app/services/memory/indexing/index.py
|
||||
"""
|
||||
Memory Indexing.
|
||||
|
||||
Provides multiple indexing strategies for efficient memory retrieval:
|
||||
- Vector embeddings for semantic search
|
||||
- Temporal index for time-based queries
|
||||
- Entity index for entity-based lookups
|
||||
- Outcome index for success/failure filtering
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", Episode, Fact, Procedure)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexEntry:
|
||||
"""A single entry in an index."""
|
||||
|
||||
memory_id: UUID
|
||||
memory_type: MemoryType
|
||||
indexed_at: datetime = field(default_factory=_utcnow)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorIndexEntry(IndexEntry):
|
||||
"""An entry with vector embedding."""
|
||||
|
||||
embedding: list[float] = field(default_factory=list)
|
||||
dimension: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set dimension from embedding."""
|
||||
if self.embedding:
|
||||
self.dimension = len(self.embedding)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TemporalIndexEntry(IndexEntry):
|
||||
"""An entry indexed by time."""
|
||||
|
||||
timestamp: datetime = field(default_factory=_utcnow)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityIndexEntry(IndexEntry):
|
||||
"""An entry indexed by entity."""
|
||||
|
||||
entity_type: str = ""
|
||||
entity_value: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutcomeIndexEntry(IndexEntry):
|
||||
"""An entry indexed by outcome."""
|
||||
|
||||
outcome: Outcome = Outcome.SUCCESS
|
||||
|
||||
|
||||
class MemoryIndex[T](ABC):
|
||||
"""Abstract base class for memory indices."""
|
||||
|
||||
@abstractmethod
|
||||
async def add(self, item: T) -> IndexEntry:
|
||||
"""Add an item to the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> list[IndexEntry]:
|
||||
"""Search the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
...
|
||||
|
||||
|
||||
class VectorIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Vector-based index using embeddings for semantic similarity search.
|
||||
|
||||
Uses cosine similarity for matching.
|
||||
"""
|
||||
|
||||
def __init__(self, dimension: int = 1536) -> None:
|
||||
"""
|
||||
Initialize the vector index.
|
||||
|
||||
Args:
|
||||
dimension: Embedding dimension (default 1536 for OpenAI)
|
||||
"""
|
||||
self._dimension = dimension
|
||||
self._entries: dict[UUID, VectorIndexEntry] = {}
|
||||
logger.info(f"Initialized VectorIndex with dimension={dimension}")
|
||||
|
||||
async def add(self, item: T) -> VectorIndexEntry:
|
||||
"""
|
||||
Add an item to the vector index.
|
||||
|
||||
Args:
|
||||
item: Memory item with embedding
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
embedding = getattr(item, "embedding", None) or []
|
||||
|
||||
entry = VectorIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
embedding=embedding,
|
||||
dimension=len(embedding),
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
logger.debug(f"Added {item.id} to vector index")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the vector index."""
|
||||
if memory_id in self._entries:
|
||||
del self._entries[memory_id]
|
||||
logger.debug(f"Removed {memory_id} from vector index")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
min_similarity: float = 0.0,
|
||||
**kwargs: Any,
|
||||
) -> list[VectorIndexEntry]:
|
||||
"""
|
||||
Search for similar items using vector similarity.
|
||||
|
||||
Args:
|
||||
query: Query embedding vector
|
||||
limit: Maximum results to return
|
||||
min_similarity: Minimum similarity threshold (0-1)
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries sorted by similarity
|
||||
"""
|
||||
if not isinstance(query, list) or not query:
|
||||
return []
|
||||
|
||||
results: list[tuple[float, VectorIndexEntry]] = []
|
||||
|
||||
for entry in self._entries.values():
|
||||
if not entry.embedding:
|
||||
continue
|
||||
|
||||
similarity = self._cosine_similarity(query, entry.embedding)
|
||||
if similarity >= min_similarity:
|
||||
results.append((similarity, entry))
|
||||
|
||||
# Sort by similarity descending
|
||||
results.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
if memory_type:
|
||||
results = [(s, e) for s, e in results if e.memory_type == memory_type]
|
||||
|
||||
# Store similarity in metadata for the returned entries
|
||||
# Use a copy of metadata to avoid mutating cached entries
|
||||
output = []
|
||||
for similarity, entry in results[:limit]:
|
||||
# Create a shallow copy of the entry with updated metadata
|
||||
entry_with_score = VectorIndexEntry(
|
||||
memory_id=entry.memory_id,
|
||||
memory_type=entry.memory_type,
|
||||
embedding=entry.embedding,
|
||||
metadata={**entry.metadata, "similarity": similarity},
|
||||
)
|
||||
output.append(entry_with_score)
|
||||
|
||||
logger.debug(f"Vector search returned {len(output)} results")
|
||||
return output
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
logger.info(f"Cleared {count} entries from vector index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
|
||||
"""Calculate cosine similarity between two vectors."""
|
||||
if len(a) != len(b) or len(a) == 0:
|
||||
return 0.0
|
||||
|
||||
dot_product = sum(x * y for x, y in zip(a, b, strict=True))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm_a * norm_b)
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
class TemporalIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Time-based index for efficient temporal queries.
|
||||
|
||||
Supports:
|
||||
- Range queries (between timestamps)
|
||||
- Recent items (within last N seconds/hours/days)
|
||||
- Oldest/newest sorting
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the temporal index."""
|
||||
self._entries: dict[UUID, TemporalIndexEntry] = {}
|
||||
# Sorted list for efficient range queries
|
||||
self._sorted_entries: list[tuple[datetime, UUID]] = []
|
||||
logger.info("Initialized TemporalIndex")
|
||||
|
||||
async def add(self, item: T) -> TemporalIndexEntry:
|
||||
"""
|
||||
Add an item to the temporal index.
|
||||
|
||||
Args:
|
||||
item: Memory item with timestamp
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
# Get timestamp from various possible fields
|
||||
timestamp = self._get_timestamp(item)
|
||||
|
||||
entry = TemporalIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
self._insert_sorted(timestamp, item.id)
|
||||
|
||||
logger.debug(f"Added {item.id} to temporal index at {timestamp}")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the temporal index."""
|
||||
if memory_id not in self._entries:
|
||||
return False
|
||||
|
||||
self._entries.pop(memory_id)
|
||||
self._sorted_entries = [
|
||||
(ts, mid) for ts, mid in self._sorted_entries if mid != memory_id
|
||||
]
|
||||
|
||||
logger.debug(f"Removed {memory_id} from temporal index")
|
||||
return True
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
recent_seconds: float | None = None,
|
||||
order: str = "desc",
|
||||
**kwargs: Any,
|
||||
) -> list[TemporalIndexEntry]:
|
||||
"""
|
||||
Search for items by time.
|
||||
|
||||
Args:
|
||||
query: Ignored for temporal search
|
||||
limit: Maximum results to return
|
||||
start_time: Start of time range
|
||||
end_time: End of time range
|
||||
recent_seconds: Get items from last N seconds
|
||||
order: Sort order ("asc" or "desc")
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries sorted by time
|
||||
"""
|
||||
if recent_seconds is not None:
|
||||
start_time = _utcnow() - timedelta(seconds=recent_seconds)
|
||||
end_time = _utcnow()
|
||||
|
||||
# Filter by time range
|
||||
results: list[TemporalIndexEntry] = []
|
||||
for entry in self._entries.values():
|
||||
if start_time and entry.timestamp < start_time:
|
||||
continue
|
||||
if end_time and entry.timestamp > end_time:
|
||||
continue
|
||||
results.append(entry)
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
if memory_type:
|
||||
results = [e for e in results if e.memory_type == memory_type]
|
||||
|
||||
# Sort by timestamp
|
||||
results.sort(key=lambda e: e.timestamp, reverse=(order == "desc"))
|
||||
|
||||
logger.debug(f"Temporal search returned {min(len(results), limit)} results")
|
||||
return results[:limit]
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
self._sorted_entries.clear()
|
||||
logger.info(f"Cleared {count} entries from temporal index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
def _insert_sorted(self, timestamp: datetime, memory_id: UUID) -> None:
|
||||
"""Insert entry maintaining sorted order."""
|
||||
# Binary search insert for efficiency
|
||||
low, high = 0, len(self._sorted_entries)
|
||||
while low < high:
|
||||
mid = (low + high) // 2
|
||||
if self._sorted_entries[mid][0] < timestamp:
|
||||
low = mid + 1
|
||||
else:
|
||||
high = mid
|
||||
self._sorted_entries.insert(low, (timestamp, memory_id))
|
||||
|
||||
def _get_timestamp(self, item: T) -> datetime:
|
||||
"""Get the relevant timestamp for an item."""
|
||||
if hasattr(item, "occurred_at"):
|
||||
return item.occurred_at
|
||||
if hasattr(item, "first_learned"):
|
||||
return item.first_learned
|
||||
if hasattr(item, "last_used") and item.last_used:
|
||||
return item.last_used
|
||||
if hasattr(item, "created_at"):
|
||||
return item.created_at
|
||||
return _utcnow()
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
class EntityIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Entity-based index for lookups by entities mentioned in memories.
|
||||
|
||||
Supports:
|
||||
- Single entity lookup
|
||||
- Multi-entity intersection
|
||||
- Entity type filtering
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the entity index."""
|
||||
# Main storage
|
||||
self._entries: dict[UUID, EntityIndexEntry] = {}
|
||||
# Inverted index: entity -> set of memory IDs
|
||||
self._entity_to_memories: dict[str, set[UUID]] = {}
|
||||
# Memory to entities mapping
|
||||
self._memory_to_entities: dict[UUID, set[str]] = {}
|
||||
logger.info("Initialized EntityIndex")
|
||||
|
||||
async def add(self, item: T) -> EntityIndexEntry:
|
||||
"""
|
||||
Add an item to the entity index.
|
||||
|
||||
Args:
|
||||
item: Memory item with entity information
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
entities = self._extract_entities(item)
|
||||
|
||||
# Create entry for the primary entity (or first one)
|
||||
primary_entity = entities[0] if entities else ("unknown", "unknown")
|
||||
|
||||
entry = EntityIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
entity_type=primary_entity[0],
|
||||
entity_value=primary_entity[1],
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
|
||||
# Update inverted indices
|
||||
entity_keys = {f"{etype}:{evalue}" for etype, evalue in entities}
|
||||
self._memory_to_entities[item.id] = entity_keys
|
||||
|
||||
for entity_key in entity_keys:
|
||||
if entity_key not in self._entity_to_memories:
|
||||
self._entity_to_memories[entity_key] = set()
|
||||
self._entity_to_memories[entity_key].add(item.id)
|
||||
|
||||
logger.debug(f"Added {item.id} to entity index with {len(entities)} entities")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the entity index."""
|
||||
if memory_id not in self._entries:
|
||||
return False
|
||||
|
||||
# Remove from inverted index
|
||||
if memory_id in self._memory_to_entities:
|
||||
for entity_key in self._memory_to_entities[memory_id]:
|
||||
if entity_key in self._entity_to_memories:
|
||||
self._entity_to_memories[entity_key].discard(memory_id)
|
||||
if not self._entity_to_memories[entity_key]:
|
||||
del self._entity_to_memories[entity_key]
|
||||
del self._memory_to_entities[memory_id]
|
||||
|
||||
del self._entries[memory_id]
|
||||
logger.debug(f"Removed {memory_id} from entity index")
|
||||
return True
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
entity_type: str | None = None,
|
||||
entity_value: str | None = None,
|
||||
entities: list[tuple[str, str]] | None = None,
|
||||
match_all: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> list[EntityIndexEntry]:
|
||||
"""
|
||||
Search for items by entity.
|
||||
|
||||
Args:
|
||||
query: Entity value to search (if entity_type not specified)
|
||||
limit: Maximum results to return
|
||||
entity_type: Type of entity to filter
|
||||
entity_value: Specific entity value
|
||||
entities: List of (type, value) tuples to match
|
||||
match_all: If True, require all entities to match
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries
|
||||
"""
|
||||
matching_ids: set[UUID] | None = None
|
||||
|
||||
# Handle single entity query
|
||||
if entity_type and entity_value:
|
||||
entities = [(entity_type, entity_value)]
|
||||
elif entity_value is None and isinstance(query, str):
|
||||
# Search across all entity types
|
||||
entity_value = query
|
||||
|
||||
if entities:
|
||||
for etype, evalue in entities:
|
||||
entity_key = f"{etype}:{evalue}"
|
||||
if entity_key in self._entity_to_memories:
|
||||
ids = self._entity_to_memories[entity_key]
|
||||
if matching_ids is None:
|
||||
matching_ids = ids.copy()
|
||||
elif match_all:
|
||||
matching_ids &= ids
|
||||
else:
|
||||
matching_ids |= ids
|
||||
elif match_all:
|
||||
# Required entity not found
|
||||
matching_ids = set()
|
||||
break
|
||||
elif entity_value:
|
||||
# Search for value across all types
|
||||
matching_ids = set()
|
||||
for entity_key, ids in self._entity_to_memories.items():
|
||||
if entity_value.lower() in entity_key.lower():
|
||||
matching_ids |= ids
|
||||
|
||||
if matching_ids is None:
|
||||
matching_ids = set(self._entries.keys())
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
results = []
|
||||
for mid in matching_ids:
|
||||
if mid in self._entries:
|
||||
entry = self._entries[mid]
|
||||
if memory_type and entry.memory_type != memory_type:
|
||||
continue
|
||||
results.append(entry)
|
||||
|
||||
logger.debug(f"Entity search returned {min(len(results), limit)} results")
|
||||
return results[:limit]
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
self._entity_to_memories.clear()
|
||||
self._memory_to_entities.clear()
|
||||
logger.info(f"Cleared {count} entries from entity index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
async def get_entities(self, memory_id: UUID) -> list[tuple[str, str]]:
|
||||
"""Get all entities for a memory item."""
|
||||
if memory_id not in self._memory_to_entities:
|
||||
return []
|
||||
|
||||
entities = []
|
||||
for entity_key in self._memory_to_entities[memory_id]:
|
||||
if ":" in entity_key:
|
||||
etype, evalue = entity_key.split(":", 1)
|
||||
entities.append((etype, evalue))
|
||||
return entities
|
||||
|
||||
def _extract_entities(self, item: T) -> list[tuple[str, str]]:
|
||||
"""Extract entities from a memory item."""
|
||||
entities: list[tuple[str, str]] = []
|
||||
|
||||
if isinstance(item, Episode):
|
||||
# Extract from task type and context
|
||||
entities.append(("task_type", item.task_type))
|
||||
if item.project_id:
|
||||
entities.append(("project", str(item.project_id)))
|
||||
if item.agent_instance_id:
|
||||
entities.append(("agent_instance", str(item.agent_instance_id)))
|
||||
if item.agent_type_id:
|
||||
entities.append(("agent_type", str(item.agent_type_id)))
|
||||
|
||||
elif isinstance(item, Fact):
|
||||
# Subject and object are entities
|
||||
entities.append(("subject", item.subject))
|
||||
entities.append(("object", item.object))
|
||||
if item.project_id:
|
||||
entities.append(("project", str(item.project_id)))
|
||||
|
||||
elif isinstance(item, Procedure):
|
||||
entities.append(("procedure", item.name))
|
||||
if item.project_id:
|
||||
entities.append(("project", str(item.project_id)))
|
||||
if item.agent_type_id:
|
||||
entities.append(("agent_type", str(item.agent_type_id)))
|
||||
|
||||
return entities
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
class OutcomeIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Outcome-based index for filtering by success/failure.
|
||||
|
||||
Primarily used for episodes and procedures.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the outcome index."""
|
||||
self._entries: dict[UUID, OutcomeIndexEntry] = {}
|
||||
# Inverted index by outcome
|
||||
self._outcome_to_memories: dict[Outcome, set[UUID]] = {
|
||||
Outcome.SUCCESS: set(),
|
||||
Outcome.FAILURE: set(),
|
||||
Outcome.PARTIAL: set(),
|
||||
}
|
||||
logger.info("Initialized OutcomeIndex")
|
||||
|
||||
async def add(self, item: T) -> OutcomeIndexEntry:
|
||||
"""
|
||||
Add an item to the outcome index.
|
||||
|
||||
Args:
|
||||
item: Memory item with outcome information
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
outcome = self._get_outcome(item)
|
||||
|
||||
entry = OutcomeIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
outcome=outcome,
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
self._outcome_to_memories[outcome].add(item.id)
|
||||
|
||||
logger.debug(f"Added {item.id} to outcome index with {outcome.value}")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the outcome index."""
|
||||
if memory_id not in self._entries:
|
||||
return False
|
||||
|
||||
entry = self._entries.pop(memory_id)
|
||||
self._outcome_to_memories[entry.outcome].discard(memory_id)
|
||||
|
||||
logger.debug(f"Removed {memory_id} from outcome index")
|
||||
return True
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
outcome: Outcome | None = None,
|
||||
outcomes: list[Outcome] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[OutcomeIndexEntry]:
|
||||
"""
|
||||
Search for items by outcome.
|
||||
|
||||
Args:
|
||||
query: Ignored for outcome search
|
||||
limit: Maximum results to return
|
||||
outcome: Single outcome to filter
|
||||
outcomes: Multiple outcomes to filter (OR)
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries
|
||||
"""
|
||||
if outcome:
|
||||
outcomes = [outcome]
|
||||
|
||||
if outcomes:
|
||||
matching_ids: set[UUID] = set()
|
||||
for o in outcomes:
|
||||
matching_ids |= self._outcome_to_memories.get(o, set())
|
||||
else:
|
||||
matching_ids = set(self._entries.keys())
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
results = []
|
||||
for mid in matching_ids:
|
||||
if mid in self._entries:
|
||||
entry = self._entries[mid]
|
||||
if memory_type and entry.memory_type != memory_type:
|
||||
continue
|
||||
results.append(entry)
|
||||
|
||||
logger.debug(f"Outcome search returned {min(len(results), limit)} results")
|
||||
return results[:limit]
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
for outcome in self._outcome_to_memories:
|
||||
self._outcome_to_memories[outcome].clear()
|
||||
logger.info(f"Cleared {count} entries from outcome index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
async def get_outcome_stats(self) -> dict[Outcome, int]:
|
||||
"""Get statistics on outcomes."""
|
||||
return {outcome: len(ids) for outcome, ids in self._outcome_to_memories.items()}
|
||||
|
||||
def _get_outcome(self, item: T) -> Outcome:
|
||||
"""Get the outcome for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return item.outcome
|
||||
elif isinstance(item, Procedure):
|
||||
# Derive from success rate
|
||||
if item.success_rate >= 0.8:
|
||||
return Outcome.SUCCESS
|
||||
elif item.success_rate <= 0.2:
|
||||
return Outcome.FAILURE
|
||||
return Outcome.PARTIAL
|
||||
return Outcome.SUCCESS
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryIndexer:
|
||||
"""
|
||||
Unified indexer that manages all index types.
|
||||
|
||||
Provides a single interface for indexing and searching across
|
||||
multiple index types.
|
||||
"""
|
||||
|
||||
vector_index: VectorIndex[Any] = field(default_factory=VectorIndex)
|
||||
temporal_index: TemporalIndex[Any] = field(default_factory=TemporalIndex)
|
||||
entity_index: EntityIndex[Any] = field(default_factory=EntityIndex)
|
||||
outcome_index: OutcomeIndex[Any] = field(default_factory=OutcomeIndex)
|
||||
|
||||
async def index(self, item: Episode | Fact | Procedure) -> dict[str, IndexEntry]:
|
||||
"""
|
||||
Index an item across all applicable indices.
|
||||
|
||||
Args:
|
||||
item: Memory item to index
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to entry
|
||||
"""
|
||||
results: dict[str, IndexEntry] = {}
|
||||
|
||||
# Vector index (if embedding present)
|
||||
if getattr(item, "embedding", None):
|
||||
results["vector"] = await self.vector_index.add(item)
|
||||
|
||||
# Temporal index
|
||||
results["temporal"] = await self.temporal_index.add(item)
|
||||
|
||||
# Entity index
|
||||
results["entity"] = await self.entity_index.add(item)
|
||||
|
||||
# Outcome index (for episodes and procedures)
|
||||
if isinstance(item, (Episode, Procedure)):
|
||||
results["outcome"] = await self.outcome_index.add(item)
|
||||
|
||||
logger.info(
|
||||
f"Indexed {item.id} across {len(results)} indices: {list(results.keys())}"
|
||||
)
|
||||
return results
|
||||
|
||||
async def remove(self, memory_id: UUID) -> dict[str, bool]:
|
||||
"""
|
||||
Remove an item from all indices.
|
||||
|
||||
Args:
|
||||
memory_id: ID of the memory to remove
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to removal success
|
||||
"""
|
||||
results = {
|
||||
"vector": await self.vector_index.remove(memory_id),
|
||||
"temporal": await self.temporal_index.remove(memory_id),
|
||||
"entity": await self.entity_index.remove(memory_id),
|
||||
"outcome": await self.outcome_index.remove(memory_id),
|
||||
}
|
||||
|
||||
removed_from = [k for k, v in results.items() if v]
|
||||
if removed_from:
|
||||
logger.info(f"Removed {memory_id} from indices: {removed_from}")
|
||||
|
||||
return results
|
||||
|
||||
async def clear_all(self) -> dict[str, int]:
|
||||
"""
|
||||
Clear all indices.
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to count cleared
|
||||
"""
|
||||
return {
|
||||
"vector": await self.vector_index.clear(),
|
||||
"temporal": await self.temporal_index.clear(),
|
||||
"entity": await self.entity_index.clear(),
|
||||
"outcome": await self.outcome_index.clear(),
|
||||
}
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get statistics for all indices.
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to entry count
|
||||
"""
|
||||
return {
|
||||
"vector": await self.vector_index.count(),
|
||||
"temporal": await self.temporal_index.count(),
|
||||
"entity": await self.entity_index.count(),
|
||||
"outcome": await self.outcome_index.count(),
|
||||
}
|
||||
|
||||
|
||||
# Singleton indexer instance
|
||||
_indexer: MemoryIndexer | None = None
|
||||
|
||||
|
||||
def get_memory_indexer() -> MemoryIndexer:
|
||||
"""Get the singleton memory indexer instance."""
|
||||
global _indexer
|
||||
if _indexer is None:
|
||||
_indexer = MemoryIndexer()
|
||||
return _indexer
|
||||
742
backend/app/services/memory/indexing/retrieval.py
Normal file
742
backend/app/services/memory/indexing/retrieval.py
Normal file
@@ -0,0 +1,742 @@
|
||||
# app/services/memory/indexing/retrieval.py
|
||||
"""
|
||||
Memory Retrieval Engine.
|
||||
|
||||
Provides hybrid retrieval capabilities combining:
|
||||
- Vector similarity search
|
||||
- Temporal filtering
|
||||
- Entity filtering
|
||||
- Outcome filtering
|
||||
- Relevance scoring
|
||||
- Result caching
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.types import (
|
||||
Episode,
|
||||
Fact,
|
||||
MemoryType,
|
||||
Outcome,
|
||||
Procedure,
|
||||
RetrievalResult,
|
||||
)
|
||||
|
||||
from .index import (
|
||||
MemoryIndexer,
|
||||
get_memory_indexer,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", Episode, Fact, Procedure)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalQuery:
|
||||
"""Query parameters for memory retrieval."""
|
||||
|
||||
# Text/semantic query
|
||||
query_text: str | None = None
|
||||
query_embedding: list[float] | None = None
|
||||
|
||||
# Temporal filters
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
recent_seconds: float | None = None
|
||||
|
||||
# Entity filters
|
||||
entities: list[tuple[str, str]] | None = None
|
||||
entity_match_all: bool = False
|
||||
|
||||
# Outcome filters
|
||||
outcomes: list[Outcome] | None = None
|
||||
|
||||
# Memory type filter
|
||||
memory_types: list[MemoryType] | None = None
|
||||
|
||||
# Result options
|
||||
limit: int = 10
|
||||
min_relevance: float = 0.0
|
||||
|
||||
# Retrieval mode
|
||||
use_vector: bool = True
|
||||
use_temporal: bool = True
|
||||
use_entity: bool = True
|
||||
use_outcome: bool = True
|
||||
|
||||
def to_cache_key(self) -> str:
|
||||
"""Generate a cache key for this query."""
|
||||
key_parts = [
|
||||
self.query_text or "",
|
||||
str(self.start_time),
|
||||
str(self.end_time),
|
||||
str(self.recent_seconds),
|
||||
str(self.entities),
|
||||
str(self.outcomes),
|
||||
str(self.memory_types),
|
||||
str(self.limit),
|
||||
str(self.min_relevance),
|
||||
]
|
||||
key_string = "|".join(key_parts)
|
||||
return hashlib.sha256(key_string.encode()).hexdigest()[:32]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoredResult:
|
||||
"""A retrieval result with relevance score."""
|
||||
|
||||
memory_id: UUID
|
||||
memory_type: MemoryType
|
||||
relevance_score: float
|
||||
score_breakdown: dict[str, float] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""A cached retrieval result."""
|
||||
|
||||
results: list[ScoredResult]
|
||||
created_at: datetime
|
||||
ttl_seconds: float
|
||||
query_key: str
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this cache entry has expired."""
|
||||
age = (_utcnow() - self.created_at).total_seconds()
|
||||
return age > self.ttl_seconds
|
||||
|
||||
|
||||
class RelevanceScorer:
|
||||
"""
|
||||
Calculates relevance scores for retrieved memories.
|
||||
|
||||
Combines multiple signals:
|
||||
- Vector similarity (if available)
|
||||
- Temporal recency
|
||||
- Entity match count
|
||||
- Outcome preference
|
||||
- Importance/confidence
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_weight: float = 0.4,
|
||||
recency_weight: float = 0.2,
|
||||
entity_weight: float = 0.2,
|
||||
outcome_weight: float = 0.1,
|
||||
importance_weight: float = 0.1,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the relevance scorer.
|
||||
|
||||
Args:
|
||||
vector_weight: Weight for vector similarity (0-1)
|
||||
recency_weight: Weight for temporal recency (0-1)
|
||||
entity_weight: Weight for entity matches (0-1)
|
||||
outcome_weight: Weight for outcome preference (0-1)
|
||||
importance_weight: Weight for importance score (0-1)
|
||||
"""
|
||||
total = (
|
||||
vector_weight
|
||||
+ recency_weight
|
||||
+ entity_weight
|
||||
+ outcome_weight
|
||||
+ importance_weight
|
||||
)
|
||||
# Normalize weights
|
||||
self.vector_weight = vector_weight / total
|
||||
self.recency_weight = recency_weight / total
|
||||
self.entity_weight = entity_weight / total
|
||||
self.outcome_weight = outcome_weight / total
|
||||
self.importance_weight = importance_weight / total
|
||||
|
||||
def score(
|
||||
self,
|
||||
memory_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
vector_similarity: float | None = None,
|
||||
timestamp: datetime | None = None,
|
||||
entity_match_count: int = 0,
|
||||
entity_total: int = 1,
|
||||
outcome: Outcome | None = None,
|
||||
importance: float = 0.5,
|
||||
preferred_outcomes: list[Outcome] | None = None,
|
||||
) -> ScoredResult:
|
||||
"""
|
||||
Calculate a relevance score for a memory.
|
||||
|
||||
Args:
|
||||
memory_id: ID of the memory
|
||||
memory_type: Type of memory
|
||||
vector_similarity: Similarity score from vector search (0-1)
|
||||
timestamp: Timestamp of the memory
|
||||
entity_match_count: Number of matching entities
|
||||
entity_total: Total entities in query
|
||||
outcome: Outcome of the memory
|
||||
importance: Importance score of the memory (0-1)
|
||||
preferred_outcomes: Outcomes to prefer
|
||||
|
||||
Returns:
|
||||
Scored result with breakdown
|
||||
"""
|
||||
breakdown: dict[str, float] = {}
|
||||
|
||||
# Vector similarity score
|
||||
if vector_similarity is not None:
|
||||
breakdown["vector"] = vector_similarity
|
||||
else:
|
||||
breakdown["vector"] = 0.5 # Neutral if no vector
|
||||
|
||||
# Recency score (exponential decay)
|
||||
if timestamp:
|
||||
age_hours = (_utcnow() - timestamp).total_seconds() / 3600
|
||||
# Decay with half-life of 24 hours
|
||||
breakdown["recency"] = 2 ** (-age_hours / 24)
|
||||
else:
|
||||
breakdown["recency"] = 0.5
|
||||
|
||||
# Entity match score
|
||||
if entity_total > 0:
|
||||
breakdown["entity"] = entity_match_count / entity_total
|
||||
else:
|
||||
breakdown["entity"] = 1.0 # No entity filter = full score
|
||||
|
||||
# Outcome score
|
||||
if preferred_outcomes and outcome:
|
||||
breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0
|
||||
else:
|
||||
breakdown["outcome"] = 0.5 # Neutral if no preference
|
||||
|
||||
# Importance score
|
||||
breakdown["importance"] = importance
|
||||
|
||||
# Calculate weighted sum
|
||||
total_score = (
|
||||
breakdown["vector"] * self.vector_weight
|
||||
+ breakdown["recency"] * self.recency_weight
|
||||
+ breakdown["entity"] * self.entity_weight
|
||||
+ breakdown["outcome"] * self.outcome_weight
|
||||
+ breakdown["importance"] * self.importance_weight
|
||||
)
|
||||
|
||||
return ScoredResult(
|
||||
memory_id=memory_id,
|
||||
memory_type=memory_type,
|
||||
relevance_score=total_score,
|
||||
score_breakdown=breakdown,
|
||||
)
|
||||
|
||||
|
||||
class RetrievalCache:
|
||||
"""
|
||||
In-memory cache for retrieval results.
|
||||
|
||||
Supports TTL-based expiration and LRU eviction with O(1) operations.
|
||||
Uses OrderedDict for efficient LRU tracking.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_entries: int = 1000,
|
||||
default_ttl_seconds: float = 300,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the cache.
|
||||
|
||||
Args:
|
||||
max_entries: Maximum cache entries
|
||||
default_ttl_seconds: Default TTL for entries
|
||||
"""
|
||||
# OrderedDict maintains insertion order; we use move_to_end for O(1) LRU
|
||||
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
|
||||
self._max_entries = max_entries
|
||||
self._default_ttl = default_ttl_seconds
|
||||
logger.info(
|
||||
f"Initialized RetrievalCache with max_entries={max_entries}, "
|
||||
f"ttl={default_ttl_seconds}s"
|
||||
)
|
||||
|
||||
def get(self, query_key: str) -> list[ScoredResult] | None:
|
||||
"""
|
||||
Get cached results for a query.
|
||||
|
||||
Args:
|
||||
query_key: Cache key for the query
|
||||
|
||||
Returns:
|
||||
Cached results or None if not found/expired
|
||||
"""
|
||||
if query_key not in self._cache:
|
||||
return None
|
||||
|
||||
entry = self._cache[query_key]
|
||||
if entry.is_expired():
|
||||
del self._cache[query_key]
|
||||
return None
|
||||
|
||||
# Update access order (LRU) - O(1) with OrderedDict
|
||||
self._cache.move_to_end(query_key)
|
||||
|
||||
logger.debug(f"Cache hit for {query_key}")
|
||||
return entry.results
|
||||
|
||||
def put(
|
||||
self,
|
||||
query_key: str,
|
||||
results: list[ScoredResult],
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache results for a query.
|
||||
|
||||
Args:
|
||||
query_key: Cache key for the query
|
||||
results: Results to cache
|
||||
ttl_seconds: TTL for this entry (or default)
|
||||
"""
|
||||
# Evict oldest entries if at capacity - O(1) with popitem(last=False)
|
||||
while len(self._cache) >= self._max_entries:
|
||||
self._cache.popitem(last=False)
|
||||
|
||||
entry = CacheEntry(
|
||||
results=results,
|
||||
created_at=_utcnow(),
|
||||
ttl_seconds=ttl_seconds or self._default_ttl,
|
||||
query_key=query_key,
|
||||
)
|
||||
|
||||
self._cache[query_key] = entry
|
||||
logger.debug(f"Cached {len(results)} results for {query_key}")
|
||||
|
||||
def invalidate(self, query_key: str) -> bool:
|
||||
"""
|
||||
Invalidate a specific cache entry.
|
||||
|
||||
Args:
|
||||
query_key: Cache key to invalidate
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
if query_key in self._cache:
|
||||
del self._cache[query_key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def invalidate_by_memory(self, memory_id: UUID) -> int:
|
||||
"""
|
||||
Invalidate all cache entries containing a specific memory.
|
||||
|
||||
Args:
|
||||
memory_id: Memory ID to invalidate
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
keys_to_remove = []
|
||||
for key, entry in self._cache.items():
|
||||
if any(r.memory_id == memory_id for r in entry.results):
|
||||
keys_to_remove.append(key)
|
||||
|
||||
for key in keys_to_remove:
|
||||
self.invalidate(key)
|
||||
|
||||
if keys_to_remove:
|
||||
logger.debug(
|
||||
f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}"
|
||||
)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def clear(self) -> int:
|
||||
"""
|
||||
Clear all cache entries.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
logger.info(f"Cleared {count} cache entries")
|
||||
return count
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
expired_count = sum(1 for e in self._cache.values() if e.is_expired())
|
||||
return {
|
||||
"total_entries": len(self._cache),
|
||||
"expired_entries": expired_count,
|
||||
"max_entries": self._max_entries,
|
||||
"default_ttl_seconds": self._default_ttl,
|
||||
}
|
||||
|
||||
|
||||
class RetrievalEngine:
|
||||
"""
|
||||
Hybrid retrieval engine for memory search.
|
||||
|
||||
Combines multiple index types for comprehensive retrieval:
|
||||
- Vector search for semantic similarity
|
||||
- Temporal index for time-based filtering
|
||||
- Entity index for entity-based lookups
|
||||
- Outcome index for success/failure filtering
|
||||
|
||||
Results are scored and ranked using relevance scoring.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
indexer: MemoryIndexer | None = None,
|
||||
scorer: RelevanceScorer | None = None,
|
||||
cache: RetrievalCache | None = None,
|
||||
enable_cache: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the retrieval engine.
|
||||
|
||||
Args:
|
||||
indexer: Memory indexer (defaults to singleton)
|
||||
scorer: Relevance scorer (defaults to new instance)
|
||||
cache: Retrieval cache (defaults to new instance)
|
||||
enable_cache: Whether to enable result caching
|
||||
"""
|
||||
self._indexer = indexer or get_memory_indexer()
|
||||
self._scorer = scorer or RelevanceScorer()
|
||||
self._cache = cache or RetrievalCache() if enable_cache else None
|
||||
self._enable_cache = enable_cache
|
||||
logger.info(f"Initialized RetrievalEngine with cache={enable_cache}")
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: RetrievalQuery,
|
||||
use_cache: bool = True,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve relevant memories using hybrid search.
|
||||
|
||||
Args:
|
||||
query: Retrieval query parameters
|
||||
use_cache: Whether to use cached results
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
start_time = _utcnow()
|
||||
|
||||
# Check cache
|
||||
cache_key = query.to_cache_key()
|
||||
if use_cache and self._cache:
|
||||
cached = self._cache.get(cache_key)
|
||||
if cached:
|
||||
latency = (_utcnow() - start_time).total_seconds() * 1000
|
||||
return RetrievalResult(
|
||||
items=cached,
|
||||
total_count=len(cached),
|
||||
query=query.query_text or "",
|
||||
retrieval_type="cached",
|
||||
latency_ms=latency,
|
||||
metadata={"cache_hit": True},
|
||||
)
|
||||
|
||||
# Collect candidates from each index
|
||||
candidates: dict[UUID, dict[str, Any]] = {}
|
||||
|
||||
# Vector search
|
||||
if query.use_vector and query.query_embedding:
|
||||
vector_results = await self._indexer.vector_index.search(
|
||||
query=query.query_embedding,
|
||||
limit=query.limit * 3, # Get more for filtering
|
||||
min_similarity=query.min_relevance,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for entry in vector_results:
|
||||
if entry.memory_id not in candidates:
|
||||
candidates[entry.memory_id] = {
|
||||
"memory_type": entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
candidates[entry.memory_id]["vector_similarity"] = entry.metadata.get(
|
||||
"similarity", 0.5
|
||||
)
|
||||
candidates[entry.memory_id]["sources"].append("vector")
|
||||
|
||||
# Temporal search
|
||||
if query.use_temporal and (
|
||||
query.start_time or query.end_time or query.recent_seconds
|
||||
):
|
||||
temporal_results = await self._indexer.temporal_index.search(
|
||||
query=None,
|
||||
limit=query.limit * 3,
|
||||
start_time=query.start_time,
|
||||
end_time=query.end_time,
|
||||
recent_seconds=query.recent_seconds,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for temporal_entry in temporal_results:
|
||||
if temporal_entry.memory_id not in candidates:
|
||||
candidates[temporal_entry.memory_id] = {
|
||||
"memory_type": temporal_entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
candidates[temporal_entry.memory_id]["timestamp"] = (
|
||||
temporal_entry.timestamp
|
||||
)
|
||||
candidates[temporal_entry.memory_id]["sources"].append("temporal")
|
||||
|
||||
# Entity search
|
||||
if query.use_entity and query.entities:
|
||||
entity_results = await self._indexer.entity_index.search(
|
||||
query=None,
|
||||
limit=query.limit * 3,
|
||||
entities=query.entities,
|
||||
match_all=query.entity_match_all,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for entity_entry in entity_results:
|
||||
if entity_entry.memory_id not in candidates:
|
||||
candidates[entity_entry.memory_id] = {
|
||||
"memory_type": entity_entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
# Count entity matches
|
||||
entity_count = candidates[entity_entry.memory_id].get(
|
||||
"entity_match_count", 0
|
||||
)
|
||||
candidates[entity_entry.memory_id]["entity_match_count"] = (
|
||||
entity_count + 1
|
||||
)
|
||||
candidates[entity_entry.memory_id]["sources"].append("entity")
|
||||
|
||||
# Outcome search
|
||||
if query.use_outcome and query.outcomes:
|
||||
outcome_results = await self._indexer.outcome_index.search(
|
||||
query=None,
|
||||
limit=query.limit * 3,
|
||||
outcomes=query.outcomes,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for outcome_entry in outcome_results:
|
||||
if outcome_entry.memory_id not in candidates:
|
||||
candidates[outcome_entry.memory_id] = {
|
||||
"memory_type": outcome_entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
candidates[outcome_entry.memory_id]["outcome"] = outcome_entry.outcome
|
||||
candidates[outcome_entry.memory_id]["sources"].append("outcome")
|
||||
|
||||
# Score and rank candidates
|
||||
scored_results: list[ScoredResult] = []
|
||||
entity_total = len(query.entities) if query.entities else 1
|
||||
|
||||
for memory_id, data in candidates.items():
|
||||
scored = self._scorer.score(
|
||||
memory_id=memory_id,
|
||||
memory_type=data["memory_type"],
|
||||
vector_similarity=data.get("vector_similarity"),
|
||||
timestamp=data.get("timestamp"),
|
||||
entity_match_count=data.get("entity_match_count", 0),
|
||||
entity_total=entity_total,
|
||||
outcome=data.get("outcome"),
|
||||
preferred_outcomes=query.outcomes,
|
||||
)
|
||||
scored.metadata["sources"] = data.get("sources", [])
|
||||
|
||||
# Filter by minimum relevance
|
||||
if scored.relevance_score >= query.min_relevance:
|
||||
scored_results.append(scored)
|
||||
|
||||
# Sort by relevance score
|
||||
scored_results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
|
||||
# Apply limit
|
||||
final_results = scored_results[: query.limit]
|
||||
|
||||
# Cache results
|
||||
if use_cache and self._cache and final_results:
|
||||
self._cache.put(cache_key, final_results)
|
||||
|
||||
latency = (_utcnow() - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Retrieved {len(final_results)} results from {len(candidates)} candidates "
|
||||
f"in {latency:.2f}ms"
|
||||
)
|
||||
|
||||
return RetrievalResult(
|
||||
items=final_results,
|
||||
total_count=len(candidates),
|
||||
query=query.query_text or "",
|
||||
retrieval_type="hybrid",
|
||||
latency_ms=latency,
|
||||
metadata={
|
||||
"cache_hit": False,
|
||||
"candidates_count": len(candidates),
|
||||
"filtered_count": len(scored_results),
|
||||
},
|
||||
)
|
||||
|
||||
async def retrieve_similar(
|
||||
self,
|
||||
embedding: list[float],
|
||||
limit: int = 10,
|
||||
min_similarity: float = 0.5,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve memories similar to a given embedding.
|
||||
|
||||
Args:
|
||||
embedding: Query embedding
|
||||
limit: Maximum results
|
||||
min_similarity: Minimum similarity threshold
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
query_embedding=embedding,
|
||||
limit=limit,
|
||||
min_relevance=min_similarity,
|
||||
memory_types=memory_types,
|
||||
use_temporal=False,
|
||||
use_entity=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
async def retrieve_recent(
|
||||
self,
|
||||
hours: float = 24,
|
||||
limit: int = 10,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve recent memories.
|
||||
|
||||
Args:
|
||||
hours: Number of hours to look back
|
||||
limit: Maximum results
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
recent_seconds=hours * 3600,
|
||||
limit=limit,
|
||||
memory_types=memory_types,
|
||||
use_vector=False,
|
||||
use_entity=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
async def retrieve_by_entity(
|
||||
self,
|
||||
entity_type: str,
|
||||
entity_value: str,
|
||||
limit: int = 10,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve memories by entity.
|
||||
|
||||
Args:
|
||||
entity_type: Type of entity
|
||||
entity_value: Entity value
|
||||
limit: Maximum results
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
entities=[(entity_type, entity_value)],
|
||||
limit=limit,
|
||||
memory_types=memory_types,
|
||||
use_vector=False,
|
||||
use_temporal=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
async def retrieve_successful(
|
||||
self,
|
||||
limit: int = 10,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve successful memories.
|
||||
|
||||
Args:
|
||||
limit: Maximum results
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
outcomes=[Outcome.SUCCESS],
|
||||
limit=limit,
|
||||
memory_types=memory_types,
|
||||
use_vector=False,
|
||||
use_temporal=False,
|
||||
use_entity=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
def invalidate_cache(self) -> int:
|
||||
"""
|
||||
Invalidate all cached results.
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if self._cache:
|
||||
return self._cache.clear()
|
||||
return 0
|
||||
|
||||
def invalidate_cache_for_memory(self, memory_id: UUID) -> int:
|
||||
"""
|
||||
Invalidate cache entries containing a specific memory.
|
||||
|
||||
Args:
|
||||
memory_id: Memory ID to invalidate
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if self._cache:
|
||||
return self._cache.invalidate_by_memory(memory_id)
|
||||
return 0
|
||||
|
||||
def get_cache_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
if self._cache:
|
||||
return self._cache.get_stats()
|
||||
return {"enabled": False}
|
||||
|
||||
|
||||
# Singleton retrieval engine instance
|
||||
_engine: RetrievalEngine | None = None
|
||||
|
||||
|
||||
def get_retrieval_engine() -> RetrievalEngine:
|
||||
"""Get the singleton retrieval engine instance."""
|
||||
global _engine
|
||||
if _engine is None:
|
||||
_engine = RetrievalEngine()
|
||||
return _engine
|
||||
19
backend/app/services/memory/integration/__init__.py
Normal file
19
backend/app/services/memory/integration/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# app/services/memory/integration/__init__.py
|
||||
"""
|
||||
Memory Integration Module.
|
||||
|
||||
Provides integration between the agent memory system and other Syndarix components:
|
||||
- Context Engine: Memory as context source
|
||||
- Agent Lifecycle: Spawn, pause, resume, terminate hooks
|
||||
"""
|
||||
|
||||
from .context_source import MemoryContextSource, get_memory_context_source
|
||||
from .lifecycle import AgentLifecycleManager, LifecycleHooks, get_lifecycle_manager
|
||||
|
||||
__all__ = [
|
||||
"AgentLifecycleManager",
|
||||
"LifecycleHooks",
|
||||
"MemoryContextSource",
|
||||
"get_lifecycle_manager",
|
||||
"get_memory_context_source",
|
||||
]
|
||||
399
backend/app/services/memory/integration/context_source.py
Normal file
399
backend/app/services/memory/integration/context_source.py
Normal file
@@ -0,0 +1,399 @@
|
||||
# app/services/memory/integration/context_source.py
|
||||
"""
|
||||
Memory Context Source.
|
||||
|
||||
Provides agent memory as a context source for the Context Engine.
|
||||
Retrieves relevant memories based on query and converts them to MemoryContext objects.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.context.types.memory import MemoryContext
|
||||
from app.services.memory.episodic import EpisodicMemory
|
||||
from app.services.memory.procedural import ProceduralMemory
|
||||
from app.services.memory.semantic import SemanticMemory
|
||||
from app.services.memory.working import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryFetchConfig:
|
||||
"""Configuration for memory fetching."""
|
||||
|
||||
# Limits per memory type
|
||||
working_limit: int = 10
|
||||
episodic_limit: int = 10
|
||||
semantic_limit: int = 15
|
||||
procedural_limit: int = 5
|
||||
|
||||
# Time ranges
|
||||
episodic_days_back: int = 30
|
||||
min_relevance: float = 0.3
|
||||
|
||||
# Which memory types to include
|
||||
include_working: bool = True
|
||||
include_episodic: bool = True
|
||||
include_semantic: bool = True
|
||||
include_procedural: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryFetchResult:
|
||||
"""Result of memory fetch operation."""
|
||||
|
||||
contexts: list[MemoryContext]
|
||||
by_type: dict[str, int]
|
||||
fetch_time_ms: float
|
||||
query: str
|
||||
|
||||
|
||||
class MemoryContextSource:
|
||||
"""
|
||||
Source for memory context in the Context Engine.
|
||||
|
||||
This service retrieves relevant memories based on a query and
|
||||
converts them to MemoryContext objects for context assembly.
|
||||
It coordinates between all memory types (working, episodic,
|
||||
semantic, procedural) to provide a comprehensive memory context.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the memory context source.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic search
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
|
||||
# Lazy-initialized memory services
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
self._semantic: SemanticMemory | None = None
|
||||
self._procedural: ProceduralMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
async def _get_semantic(self) -> SemanticMemory:
|
||||
"""Get or create semantic memory service."""
|
||||
if self._semantic is None:
|
||||
self._semantic = await SemanticMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._semantic
|
||||
|
||||
async def _get_procedural(self) -> ProceduralMemory:
|
||||
"""Get or create procedural memory service."""
|
||||
if self._procedural is None:
|
||||
self._procedural = await ProceduralMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._procedural
|
||||
|
||||
async def fetch_context(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
session_id: str | None = None,
|
||||
config: MemoryFetchConfig | None = None,
|
||||
) -> MemoryFetchResult:
|
||||
"""
|
||||
Fetch relevant memories as context.
|
||||
|
||||
This is the main entry point for the Context Engine integration.
|
||||
It searches across all memory types and returns relevant memories
|
||||
as MemoryContext objects.
|
||||
|
||||
Args:
|
||||
query: Search query for finding relevant memories
|
||||
project_id: Project scope
|
||||
agent_instance_id: Optional agent instance scope
|
||||
agent_type_id: Optional agent type scope (for procedural)
|
||||
session_id: Optional session ID (for working memory)
|
||||
config: Optional fetch configuration
|
||||
|
||||
Returns:
|
||||
MemoryFetchResult with contexts and metadata
|
||||
"""
|
||||
config = config or MemoryFetchConfig()
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
contexts: list[MemoryContext] = []
|
||||
by_type: dict[str, int] = {
|
||||
"working": 0,
|
||||
"episodic": 0,
|
||||
"semantic": 0,
|
||||
"procedural": 0,
|
||||
}
|
||||
|
||||
# Fetch from working memory (session-scoped)
|
||||
if config.include_working and session_id:
|
||||
try:
|
||||
working_contexts = await self._fetch_working(
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
limit=config.working_limit,
|
||||
)
|
||||
contexts.extend(working_contexts)
|
||||
by_type["working"] = len(working_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch working memory: {e}")
|
||||
|
||||
# Fetch from episodic memory
|
||||
if config.include_episodic:
|
||||
try:
|
||||
episodic_contexts = await self._fetch_episodic(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
limit=config.episodic_limit,
|
||||
days_back=config.episodic_days_back,
|
||||
)
|
||||
contexts.extend(episodic_contexts)
|
||||
by_type["episodic"] = len(episodic_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch episodic memory: {e}")
|
||||
|
||||
# Fetch from semantic memory
|
||||
if config.include_semantic:
|
||||
try:
|
||||
semantic_contexts = await self._fetch_semantic(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
limit=config.semantic_limit,
|
||||
min_relevance=config.min_relevance,
|
||||
)
|
||||
contexts.extend(semantic_contexts)
|
||||
by_type["semantic"] = len(semantic_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch semantic memory: {e}")
|
||||
|
||||
# Fetch from procedural memory
|
||||
if config.include_procedural:
|
||||
try:
|
||||
procedural_contexts = await self._fetch_procedural(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=config.procedural_limit,
|
||||
)
|
||||
contexts.extend(procedural_contexts)
|
||||
by_type["procedural"] = len(procedural_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch procedural memory: {e}")
|
||||
|
||||
# Sort by relevance
|
||||
contexts.sort(key=lambda c: c.relevance_score, reverse=True)
|
||||
|
||||
fetch_time = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.debug(
|
||||
f"Fetched {len(contexts)} memory contexts for query '{query[:50]}...' "
|
||||
f"in {fetch_time:.1f}ms"
|
||||
)
|
||||
|
||||
return MemoryFetchResult(
|
||||
contexts=contexts,
|
||||
by_type=by_type,
|
||||
fetch_time_ms=fetch_time,
|
||||
query=query,
|
||||
)
|
||||
|
||||
async def _fetch_working(
|
||||
self,
|
||||
query: str,
|
||||
session_id: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None,
|
||||
limit: int,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from working memory."""
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
|
||||
)
|
||||
|
||||
contexts: list[MemoryContext] = []
|
||||
all_keys = await working.list_keys()
|
||||
|
||||
# Filter keys by query (simple substring match)
|
||||
query_lower = query.lower()
|
||||
matched_keys = [k for k in all_keys if query_lower in k.lower()]
|
||||
|
||||
# If no query match, include all keys (working memory is always relevant)
|
||||
if not matched_keys and query:
|
||||
matched_keys = all_keys
|
||||
|
||||
for key in matched_keys[:limit]:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
contexts.append(
|
||||
MemoryContext.from_working_memory(
|
||||
key=key,
|
||||
value=value,
|
||||
source=f"working:{session_id}",
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
return contexts
|
||||
|
||||
async def _fetch_episodic(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None,
|
||||
limit: int,
|
||||
days_back: int,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from episodic memory."""
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
# Search for similar episodes
|
||||
episodes = await episodic.search_similar(
|
||||
project_id=project_id,
|
||||
query=query,
|
||||
limit=limit,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
# Also get recent episodes if we didn't find enough
|
||||
if len(episodes) < limit // 2:
|
||||
since = datetime.now(UTC) - timedelta(days=days_back)
|
||||
recent = await episodic.get_recent(
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
since=since,
|
||||
)
|
||||
# Deduplicate by ID
|
||||
existing_ids = {e.id for e in episodes}
|
||||
for ep in recent:
|
||||
if ep.id not in existing_ids:
|
||||
episodes.append(ep)
|
||||
if len(episodes) >= limit:
|
||||
break
|
||||
|
||||
return [
|
||||
MemoryContext.from_episodic_memory(ep, query=query)
|
||||
for ep in episodes[:limit]
|
||||
]
|
||||
|
||||
async def _fetch_semantic(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
limit: int,
|
||||
min_relevance: float,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from semantic memory."""
|
||||
semantic = await self._get_semantic()
|
||||
|
||||
facts = await semantic.search_facts(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
min_confidence=min_relevance,
|
||||
)
|
||||
|
||||
return [MemoryContext.from_semantic_memory(fact, query=query) for fact in facts]
|
||||
|
||||
async def _fetch_procedural(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None,
|
||||
limit: int,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from procedural memory."""
|
||||
procedural = await self._get_procedural()
|
||||
|
||||
procedures = await procedural.find_matching(
|
||||
context=query,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
MemoryContext.from_procedural_memory(proc, query=query)
|
||||
for proc in procedures
|
||||
]
|
||||
|
||||
async def fetch_all_working(
|
||||
self,
|
||||
session_id: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[MemoryContext]:
|
||||
"""
|
||||
Fetch all working memory for a session.
|
||||
|
||||
Useful for including entire session state in context.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
project_id: Project scope
|
||||
agent_instance_id: Optional agent instance scope
|
||||
|
||||
Returns:
|
||||
List of MemoryContext for all working memory items
|
||||
"""
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
|
||||
)
|
||||
|
||||
contexts: list[MemoryContext] = []
|
||||
all_keys = await working.list_keys()
|
||||
|
||||
for key in all_keys:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
contexts.append(
|
||||
MemoryContext.from_working_memory(
|
||||
key=key,
|
||||
value=value,
|
||||
source=f"working:{session_id}",
|
||||
)
|
||||
)
|
||||
|
||||
return contexts
|
||||
|
||||
|
||||
# Factory function
|
||||
async def get_memory_context_source(
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> MemoryContextSource:
|
||||
"""Create a memory context source instance."""
|
||||
return MemoryContextSource(
|
||||
session=session,
|
||||
embedding_generator=embedding_generator,
|
||||
)
|
||||
635
backend/app/services/memory/integration/lifecycle.py
Normal file
635
backend/app/services/memory/integration/lifecycle.py
Normal file
@@ -0,0 +1,635 @@
|
||||
# app/services/memory/integration/lifecycle.py
|
||||
"""
|
||||
Agent Lifecycle Hooks for Memory System.
|
||||
|
||||
Provides memory management hooks for agent lifecycle events:
|
||||
- spawn: Initialize working memory for new agent instance
|
||||
- pause: Checkpoint working memory state
|
||||
- resume: Restore working memory from checkpoint
|
||||
- terminate: Consolidate session to episodic memory
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.memory.episodic import EpisodicMemory
|
||||
from app.services.memory.types import EpisodeCreate, Outcome
|
||||
from app.services.memory.working import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LifecycleEvent:
|
||||
"""Event data for lifecycle hooks."""
|
||||
|
||||
event_type: str # spawn, pause, resume, terminate
|
||||
project_id: UUID
|
||||
agent_instance_id: UUID
|
||||
agent_type_id: UUID | None = None
|
||||
session_id: str | None = None
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LifecycleResult:
|
||||
"""Result of a lifecycle operation."""
|
||||
|
||||
success: bool
|
||||
event_type: str
|
||||
message: str | None = None
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
duration_ms: float = 0.0
|
||||
|
||||
|
||||
# Type alias for lifecycle hooks
|
||||
LifecycleHook = Callable[[LifecycleEvent], Coroutine[Any, Any, None]]
|
||||
|
||||
|
||||
class LifecycleHooks:
|
||||
"""
|
||||
Collection of lifecycle hooks.
|
||||
|
||||
Allows registration of custom hooks for lifecycle events.
|
||||
Hooks are called after the core memory operations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize lifecycle hooks."""
|
||||
self._spawn_hooks: list[LifecycleHook] = []
|
||||
self._pause_hooks: list[LifecycleHook] = []
|
||||
self._resume_hooks: list[LifecycleHook] = []
|
||||
self._terminate_hooks: list[LifecycleHook] = []
|
||||
|
||||
def on_spawn(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a spawn hook."""
|
||||
self._spawn_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
def on_pause(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a pause hook."""
|
||||
self._pause_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
def on_resume(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a resume hook."""
|
||||
self._resume_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
def on_terminate(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a terminate hook."""
|
||||
self._terminate_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
async def run_spawn_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all spawn hooks."""
|
||||
for hook in self._spawn_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Spawn hook failed: {e}")
|
||||
|
||||
async def run_pause_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all pause hooks."""
|
||||
for hook in self._pause_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Pause hook failed: {e}")
|
||||
|
||||
async def run_resume_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all resume hooks."""
|
||||
for hook in self._resume_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Resume hook failed: {e}")
|
||||
|
||||
async def run_terminate_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all terminate hooks."""
|
||||
for hook in self._terminate_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Terminate hook failed: {e}")
|
||||
|
||||
|
||||
class AgentLifecycleManager:
|
||||
"""
|
||||
Manager for agent lifecycle and memory integration.
|
||||
|
||||
Handles memory operations during agent lifecycle events:
|
||||
- spawn: Creates new working memory for the session
|
||||
- pause: Saves working memory state to checkpoint
|
||||
- resume: Restores working memory from checkpoint
|
||||
- terminate: Consolidates working memory to episodic memory
|
||||
"""
|
||||
|
||||
# Key prefix for checkpoint storage
|
||||
CHECKPOINT_PREFIX = "__checkpoint__"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
hooks: LifecycleHooks | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the lifecycle manager.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
hooks: Optional lifecycle hooks
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._hooks = hooks or LifecycleHooks()
|
||||
|
||||
# Lazy-initialized services
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
@property
|
||||
def hooks(self) -> LifecycleHooks:
|
||||
"""Get the lifecycle hooks."""
|
||||
return self._hooks
|
||||
|
||||
async def spawn(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
agent_type_id: UUID | None = None,
|
||||
initial_state: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent spawn - initialize working memory.
|
||||
|
||||
Creates a new working memory instance for the agent session
|
||||
and optionally populates it with initial state.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID for working memory
|
||||
agent_type_id: Optional agent type ID
|
||||
initial_state: Optional initial state to populate
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with spawn outcome
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
# Create working memory for the session
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Populate initial state if provided
|
||||
items_set = 0
|
||||
if initial_state:
|
||||
for key, value in initial_state.items():
|
||||
await working.set(key, value)
|
||||
items_set += 1
|
||||
|
||||
# Create and run event hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="spawn",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
await self._hooks.run_spawn_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} spawned with session {session_id}, "
|
||||
f"initial state: {items_set} items"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="spawn",
|
||||
message="Agent spawned successfully",
|
||||
data={
|
||||
"session_id": session_id,
|
||||
"initial_items": items_set,
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Spawn failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="spawn",
|
||||
message=f"Spawn failed: {e}",
|
||||
)
|
||||
|
||||
async def pause(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
checkpoint_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent pause - checkpoint working memory.
|
||||
|
||||
Saves the current working memory state to a checkpoint
|
||||
that can be restored later with resume().
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
checkpoint_id: Optional checkpoint identifier
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with checkpoint data
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
checkpoint_id = checkpoint_id or f"checkpoint_{int(start_time.timestamp())}"
|
||||
|
||||
try:
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Get all current state
|
||||
all_keys = await working.list_keys()
|
||||
# Filter out checkpoint keys
|
||||
state_keys = [
|
||||
k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)
|
||||
]
|
||||
|
||||
state: dict[str, Any] = {}
|
||||
for key in state_keys:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
state[key] = value
|
||||
|
||||
# Store checkpoint
|
||||
checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
await working.set(
|
||||
checkpoint_key,
|
||||
{
|
||||
"state": state,
|
||||
"timestamp": start_time.isoformat(),
|
||||
"keys_count": len(state),
|
||||
},
|
||||
ttl_seconds=86400 * 7, # Keep checkpoint for 7 days
|
||||
)
|
||||
|
||||
# Run hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="pause",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
|
||||
)
|
||||
await self._hooks.run_pause_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} paused, checkpoint {checkpoint_id} "
|
||||
f"saved with {len(state)} items"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="pause",
|
||||
message="Agent paused successfully",
|
||||
data={
|
||||
"checkpoint_id": checkpoint_id,
|
||||
"items_saved": len(state),
|
||||
"timestamp": start_time.isoformat(),
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pause failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="pause",
|
||||
message=f"Pause failed: {e}",
|
||||
)
|
||||
|
||||
async def resume(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
checkpoint_id: str,
|
||||
clear_current: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent resume - restore from checkpoint.
|
||||
|
||||
Restores working memory state from a previously saved checkpoint.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
checkpoint_id: Checkpoint to restore from
|
||||
clear_current: Whether to clear current state before restoring
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with restore outcome
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Get checkpoint
|
||||
checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
checkpoint = await working.get(checkpoint_key)
|
||||
|
||||
if checkpoint is None:
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="resume",
|
||||
message=f"Checkpoint '{checkpoint_id}' not found",
|
||||
)
|
||||
|
||||
# Clear current state if requested
|
||||
if clear_current:
|
||||
all_keys = await working.list_keys()
|
||||
for key in all_keys:
|
||||
if not key.startswith(self.CHECKPOINT_PREFIX):
|
||||
await working.delete(key)
|
||||
|
||||
# Restore state from checkpoint
|
||||
state = checkpoint.get("state", {})
|
||||
items_restored = 0
|
||||
for key, value in state.items():
|
||||
await working.set(key, value)
|
||||
items_restored += 1
|
||||
|
||||
# Run hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="resume",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
|
||||
)
|
||||
await self._hooks.run_resume_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} resumed from checkpoint {checkpoint_id}, "
|
||||
f"restored {items_restored} items"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="resume",
|
||||
message="Agent resumed successfully",
|
||||
data={
|
||||
"checkpoint_id": checkpoint_id,
|
||||
"items_restored": items_restored,
|
||||
"checkpoint_timestamp": checkpoint.get("timestamp"),
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Resume failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="resume",
|
||||
message=f"Resume failed: {e}",
|
||||
)
|
||||
|
||||
async def terminate(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
task_description: str | None = None,
|
||||
outcome: Outcome = Outcome.SUCCESS,
|
||||
lessons_learned: list[str] | None = None,
|
||||
consolidate_to_episodic: bool = True,
|
||||
cleanup_working: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent termination - consolidate to episodic memory.
|
||||
|
||||
Consolidates the session's working memory into an episodic memory
|
||||
entry, then optionally cleans up the working memory.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
task_description: Description of what was accomplished
|
||||
outcome: Task outcome (SUCCESS, FAILURE, PARTIAL)
|
||||
lessons_learned: Optional list of lessons learned
|
||||
consolidate_to_episodic: Whether to create episodic entry
|
||||
cleanup_working: Whether to clear working memory
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with termination outcome
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Gather session state for consolidation
|
||||
all_keys = await working.list_keys()
|
||||
state_keys = [
|
||||
k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)
|
||||
]
|
||||
|
||||
session_state: dict[str, Any] = {}
|
||||
for key in state_keys:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
session_state[key] = value
|
||||
|
||||
episode_id: str | None = None
|
||||
|
||||
# Consolidate to episodic memory
|
||||
if consolidate_to_episodic:
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
description = task_description or f"Session {session_id} completed"
|
||||
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
task_type="session_completion",
|
||||
task_description=description[:500],
|
||||
outcome=outcome,
|
||||
outcome_details=f"Session terminated with {len(session_state)} state items",
|
||||
actions=[
|
||||
{
|
||||
"type": "session_terminate",
|
||||
"state_keys": list(session_state.keys()),
|
||||
"outcome": outcome.value,
|
||||
}
|
||||
],
|
||||
context_summary=str(session_state)[:1000] if session_state else "",
|
||||
lessons_learned=lessons_learned or [],
|
||||
duration_seconds=0.0, # Unknown at this point
|
||||
tokens_used=0,
|
||||
importance_score=0.6, # Moderate importance for session ends
|
||||
)
|
||||
|
||||
episode = await episodic.record_episode(episode_data)
|
||||
episode_id = str(episode.id)
|
||||
|
||||
# Clean up working memory
|
||||
items_cleared = 0
|
||||
if cleanup_working:
|
||||
for key in all_keys:
|
||||
await working.delete(key)
|
||||
items_cleared += 1
|
||||
|
||||
# Run hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="terminate",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
metadata={**(metadata or {}), "episode_id": episode_id},
|
||||
)
|
||||
await self._hooks.run_terminate_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} terminated, session {session_id} "
|
||||
f"consolidated to episode {episode_id}"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="terminate",
|
||||
message="Agent terminated successfully",
|
||||
data={
|
||||
"episode_id": episode_id,
|
||||
"state_items_consolidated": len(session_state),
|
||||
"items_cleared": items_cleared,
|
||||
"outcome": outcome.value,
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Terminate failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="terminate",
|
||||
message=f"Terminate failed: {e}",
|
||||
)
|
||||
|
||||
async def list_checkpoints(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
List available checkpoints for a session.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
List of checkpoint metadata dicts
|
||||
"""
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
all_keys = await working.list_keys()
|
||||
checkpoints: list[dict[str, Any]] = []
|
||||
|
||||
for key in all_keys:
|
||||
if key.startswith(self.CHECKPOINT_PREFIX):
|
||||
checkpoint_id = key[len(self.CHECKPOINT_PREFIX) :]
|
||||
checkpoint = await working.get(key)
|
||||
if checkpoint:
|
||||
checkpoints.append(
|
||||
{
|
||||
"checkpoint_id": checkpoint_id,
|
||||
"timestamp": checkpoint.get("timestamp"),
|
||||
"keys_count": checkpoint.get("keys_count", 0),
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
checkpoints.sort(
|
||||
key=lambda c: c.get("timestamp", ""),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
return checkpoints
|
||||
|
||||
|
||||
# Factory function
|
||||
async def get_lifecycle_manager(
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
hooks: LifecycleHooks | None = None,
|
||||
) -> AgentLifecycleManager:
|
||||
"""Create a lifecycle manager instance."""
|
||||
return AgentLifecycleManager(
|
||||
session=session,
|
||||
embedding_generator=embedding_generator,
|
||||
hooks=hooks,
|
||||
)
|
||||
606
backend/app/services/memory/manager.py
Normal file
606
backend/app/services/memory/manager.py
Normal file
@@ -0,0 +1,606 @@
|
||||
"""
|
||||
Memory Manager
|
||||
|
||||
Facade for the Agent Memory System providing unified access
|
||||
to all memory types and operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from .config import MemorySettings, get_memory_settings
|
||||
from .types import (
|
||||
Episode,
|
||||
EpisodeCreate,
|
||||
Fact,
|
||||
FactCreate,
|
||||
MemoryStats,
|
||||
MemoryType,
|
||||
Outcome,
|
||||
Procedure,
|
||||
ProcedureCreate,
|
||||
RetrievalResult,
|
||||
ScopeContext,
|
||||
ScopeLevel,
|
||||
TaskState,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""
|
||||
Unified facade for the Agent Memory System.
|
||||
|
||||
Provides a single entry point for all memory operations across
|
||||
working, episodic, semantic, and procedural memory types.
|
||||
|
||||
Usage:
|
||||
manager = MemoryManager.create()
|
||||
|
||||
# Working memory
|
||||
await manager.set_working("key", {"data": "value"})
|
||||
value = await manager.get_working("key")
|
||||
|
||||
# Episodic memory
|
||||
episode = await manager.record_episode(episode_data)
|
||||
similar = await manager.search_episodes("query")
|
||||
|
||||
# Semantic memory
|
||||
fact = await manager.store_fact(fact_data)
|
||||
facts = await manager.search_facts("query")
|
||||
|
||||
# Procedural memory
|
||||
procedure = await manager.record_procedure(procedure_data)
|
||||
procedures = await manager.find_procedures("context")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: MemorySettings,
|
||||
scope: ScopeContext,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the MemoryManager.
|
||||
|
||||
Args:
|
||||
settings: Memory configuration settings
|
||||
scope: The scope context for this manager instance
|
||||
"""
|
||||
self._settings = settings
|
||||
self._scope = scope
|
||||
self._initialized = False
|
||||
|
||||
# These will be initialized when the respective sub-modules are implemented
|
||||
self._working_memory: Any | None = None
|
||||
self._episodic_memory: Any | None = None
|
||||
self._semantic_memory: Any | None = None
|
||||
self._procedural_memory: Any | None = None
|
||||
|
||||
logger.debug(
|
||||
"MemoryManager created for scope %s:%s",
|
||||
scope.scope_type.value,
|
||||
scope.scope_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
scope_type: ScopeLevel = ScopeLevel.SESSION,
|
||||
scope_id: str = "default",
|
||||
parent_scope: ScopeContext | None = None,
|
||||
settings: MemorySettings | None = None,
|
||||
) -> "MemoryManager":
|
||||
"""
|
||||
Create a new MemoryManager instance.
|
||||
|
||||
Args:
|
||||
scope_type: The scope level for this manager
|
||||
scope_id: The scope identifier
|
||||
parent_scope: Optional parent scope for inheritance
|
||||
settings: Optional custom settings (uses global if not provided)
|
||||
|
||||
Returns:
|
||||
A new MemoryManager instance
|
||||
"""
|
||||
if settings is None:
|
||||
settings = get_memory_settings()
|
||||
|
||||
scope = ScopeContext(
|
||||
scope_type=scope_type,
|
||||
scope_id=scope_id,
|
||||
parent=parent_scope,
|
||||
)
|
||||
|
||||
return cls(settings=settings, scope=scope)
|
||||
|
||||
@classmethod
|
||||
def for_session(
|
||||
cls,
|
||||
session_id: str,
|
||||
agent_instance_id: UUID | None = None,
|
||||
project_id: UUID | None = None,
|
||||
) -> "MemoryManager":
|
||||
"""
|
||||
Create a MemoryManager for a specific session.
|
||||
|
||||
Builds the appropriate scope hierarchy based on provided IDs.
|
||||
|
||||
Args:
|
||||
session_id: The session identifier
|
||||
agent_instance_id: Optional agent instance ID
|
||||
project_id: Optional project ID
|
||||
|
||||
Returns:
|
||||
A MemoryManager configured for the session scope
|
||||
"""
|
||||
settings = get_memory_settings()
|
||||
|
||||
# Build scope hierarchy
|
||||
parent: ScopeContext | None = None
|
||||
|
||||
if project_id:
|
||||
parent = ScopeContext(
|
||||
scope_type=ScopeLevel.PROJECT,
|
||||
scope_id=str(project_id),
|
||||
parent=ScopeContext(
|
||||
scope_type=ScopeLevel.GLOBAL,
|
||||
scope_id="global",
|
||||
),
|
||||
)
|
||||
|
||||
if agent_instance_id:
|
||||
parent = ScopeContext(
|
||||
scope_type=ScopeLevel.AGENT_INSTANCE,
|
||||
scope_id=str(agent_instance_id),
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
scope = ScopeContext(
|
||||
scope_type=ScopeLevel.SESSION,
|
||||
scope_id=session_id,
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
return cls(settings=settings, scope=scope)
|
||||
|
||||
@property
|
||||
def scope(self) -> ScopeContext:
|
||||
"""Get the current scope context."""
|
||||
return self._scope
|
||||
|
||||
@property
|
||||
def settings(self) -> MemorySettings:
|
||||
"""Get the memory settings."""
|
||||
return self._settings
|
||||
|
||||
# =========================================================================
|
||||
# Working Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def set_working(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Set a value in working memory.
|
||||
|
||||
Args:
|
||||
key: The key to store the value under
|
||||
value: The value to store (must be JSON serializable)
|
||||
ttl_seconds: Optional TTL (uses default if not provided)
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("set_working called for key=%s (not yet implemented)", key)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def get_working(
|
||||
self,
|
||||
key: str,
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Get a value from working memory.
|
||||
|
||||
Args:
|
||||
key: The key to retrieve
|
||||
default: Default value if key not found
|
||||
|
||||
Returns:
|
||||
The stored value or default
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("get_working called for key=%s (not yet implemented)", key)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def delete_working(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a value from working memory.
|
||||
|
||||
Args:
|
||||
key: The key to delete
|
||||
|
||||
Returns:
|
||||
True if the key was deleted, False if not found
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("delete_working called for key=%s (not yet implemented)", key)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def set_task_state(self, state: TaskState) -> None:
|
||||
"""
|
||||
Set the current task state in working memory.
|
||||
|
||||
Args:
|
||||
state: The task state to store
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug(
|
||||
"set_task_state called for task=%s (not yet implemented)",
|
||||
state.task_id,
|
||||
)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def get_task_state(self) -> TaskState | None:
|
||||
"""
|
||||
Get the current task state from working memory.
|
||||
|
||||
Returns:
|
||||
The current task state or None
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("get_task_state called (not yet implemented)")
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def create_checkpoint(self) -> str:
|
||||
"""
|
||||
Create a checkpoint of the current working memory state.
|
||||
|
||||
Returns:
|
||||
The checkpoint ID
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("create_checkpoint called (not yet implemented)")
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def restore_checkpoint(self, checkpoint_id: str) -> None:
|
||||
"""
|
||||
Restore working memory from a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: The checkpoint to restore from
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug(
|
||||
"restore_checkpoint called for id=%s (not yet implemented)",
|
||||
checkpoint_id,
|
||||
)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Episodic Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def record_episode(self, episode: EpisodeCreate) -> Episode:
|
||||
"""
|
||||
Record a new episode in episodic memory.
|
||||
|
||||
Args:
|
||||
episode: The episode data to record
|
||||
|
||||
Returns:
|
||||
The created episode with ID
|
||||
"""
|
||||
# Placeholder - will be implemented in #90
|
||||
logger.debug(
|
||||
"record_episode called for task=%s (not yet implemented)",
|
||||
episode.task_type,
|
||||
)
|
||||
raise NotImplementedError("Episodic memory not yet implemented")
|
||||
|
||||
async def search_episodes(
|
||||
self,
|
||||
query: str,
|
||||
limit: int | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""
|
||||
Search for similar episodes.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
limit: Maximum results to return
|
||||
|
||||
Returns:
|
||||
Retrieval result with matching episodes
|
||||
"""
|
||||
# Placeholder - will be implemented in #90
|
||||
logger.debug(
|
||||
"search_episodes called for query=%s (not yet implemented)",
|
||||
query[:50],
|
||||
)
|
||||
raise NotImplementedError("Episodic memory not yet implemented")
|
||||
|
||||
async def get_recent_episodes(
|
||||
self,
|
||||
limit: int = 10,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get the most recent episodes.
|
||||
|
||||
Args:
|
||||
limit: Maximum episodes to return
|
||||
|
||||
Returns:
|
||||
List of recent episodes
|
||||
"""
|
||||
# Placeholder - will be implemented in #90
|
||||
logger.debug("get_recent_episodes called (not yet implemented)")
|
||||
raise NotImplementedError("Episodic memory not yet implemented")
|
||||
|
||||
async def get_episodes_by_outcome(
|
||||
self,
|
||||
outcome: Outcome,
|
||||
limit: int = 10,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get episodes by outcome.
|
||||
|
||||
Args:
|
||||
outcome: The outcome to filter by
|
||||
limit: Maximum episodes to return
|
||||
|
||||
Returns:
|
||||
List of episodes with the specified outcome
|
||||
"""
|
||||
# Placeholder - will be implemented in #90
|
||||
logger.debug(
|
||||
"get_episodes_by_outcome called for outcome=%s (not yet implemented)",
|
||||
outcome.value,
|
||||
)
|
||||
raise NotImplementedError("Episodic memory not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Semantic Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def store_fact(self, fact: FactCreate) -> Fact:
|
||||
"""
|
||||
Store a new fact in semantic memory.
|
||||
|
||||
Args:
|
||||
fact: The fact data to store
|
||||
|
||||
Returns:
|
||||
The created fact with ID
|
||||
"""
|
||||
# Placeholder - will be implemented in #91
|
||||
logger.debug(
|
||||
"store_fact called for %s %s %s (not yet implemented)",
|
||||
fact.subject,
|
||||
fact.predicate,
|
||||
fact.object,
|
||||
)
|
||||
raise NotImplementedError("Semantic memory not yet implemented")
|
||||
|
||||
async def search_facts(
|
||||
self,
|
||||
query: str,
|
||||
limit: int | None = None,
|
||||
) -> RetrievalResult[Fact]:
|
||||
"""
|
||||
Search for facts matching a query.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
limit: Maximum results to return
|
||||
|
||||
Returns:
|
||||
Retrieval result with matching facts
|
||||
"""
|
||||
# Placeholder - will be implemented in #91
|
||||
logger.debug(
|
||||
"search_facts called for query=%s (not yet implemented)",
|
||||
query[:50],
|
||||
)
|
||||
raise NotImplementedError("Semantic memory not yet implemented")
|
||||
|
||||
async def get_facts_by_entity(
|
||||
self,
|
||||
entity: str,
|
||||
limit: int = 20,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Get facts related to an entity.
|
||||
|
||||
Args:
|
||||
entity: The entity to search for
|
||||
limit: Maximum facts to return
|
||||
|
||||
Returns:
|
||||
List of facts mentioning the entity
|
||||
"""
|
||||
# Placeholder - will be implemented in #91
|
||||
logger.debug(
|
||||
"get_facts_by_entity called for entity=%s (not yet implemented)",
|
||||
entity,
|
||||
)
|
||||
raise NotImplementedError("Semantic memory not yet implemented")
|
||||
|
||||
async def reinforce_fact(self, fact_id: UUID) -> Fact:
|
||||
"""
|
||||
Reinforce a fact (increase confidence from repeated learning).
|
||||
|
||||
Args:
|
||||
fact_id: The fact to reinforce
|
||||
|
||||
Returns:
|
||||
The updated fact
|
||||
"""
|
||||
# Placeholder - will be implemented in #91
|
||||
logger.debug(
|
||||
"reinforce_fact called for id=%s (not yet implemented)",
|
||||
fact_id,
|
||||
)
|
||||
raise NotImplementedError("Semantic memory not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Procedural Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def record_procedure(self, procedure: ProcedureCreate) -> Procedure:
|
||||
"""
|
||||
Record a new procedure.
|
||||
|
||||
Args:
|
||||
procedure: The procedure data to record
|
||||
|
||||
Returns:
|
||||
The created procedure with ID
|
||||
"""
|
||||
# Placeholder - will be implemented in #92
|
||||
logger.debug(
|
||||
"record_procedure called for name=%s (not yet implemented)",
|
||||
procedure.name,
|
||||
)
|
||||
raise NotImplementedError("Procedural memory not yet implemented")
|
||||
|
||||
async def find_procedures(
|
||||
self,
|
||||
context: str,
|
||||
limit: int = 5,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Find procedures matching the current context.
|
||||
|
||||
Args:
|
||||
context: The context to match against
|
||||
limit: Maximum procedures to return
|
||||
|
||||
Returns:
|
||||
List of matching procedures sorted by success rate
|
||||
"""
|
||||
# Placeholder - will be implemented in #92
|
||||
logger.debug(
|
||||
"find_procedures called for context=%s (not yet implemented)",
|
||||
context[:50],
|
||||
)
|
||||
raise NotImplementedError("Procedural memory not yet implemented")
|
||||
|
||||
async def record_procedure_outcome(
|
||||
self,
|
||||
procedure_id: UUID,
|
||||
success: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Record the outcome of using a procedure.
|
||||
|
||||
Args:
|
||||
procedure_id: The procedure that was used
|
||||
success: Whether the procedure succeeded
|
||||
"""
|
||||
# Placeholder - will be implemented in #92
|
||||
logger.debug(
|
||||
"record_procedure_outcome called for id=%s success=%s (not yet implemented)",
|
||||
procedure_id,
|
||||
success,
|
||||
)
|
||||
raise NotImplementedError("Procedural memory not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Cross-Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def recall(
|
||||
self,
|
||||
query: str,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
limit: int = 10,
|
||||
) -> dict[MemoryType, list[Any]]:
|
||||
"""
|
||||
Recall memories across multiple memory types.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
memory_types: Memory types to search (all if not specified)
|
||||
limit: Maximum results per type
|
||||
|
||||
Returns:
|
||||
Dictionary mapping memory types to results
|
||||
"""
|
||||
# Placeholder - will be implemented in #97 (Component Integration)
|
||||
logger.debug("recall called for query=%s (not yet implemented)", query[:50])
|
||||
raise NotImplementedError("Cross-memory recall not yet implemented")
|
||||
|
||||
async def get_stats(
|
||||
self,
|
||||
memory_type: MemoryType | None = None,
|
||||
) -> list[MemoryStats]:
|
||||
"""
|
||||
Get memory statistics.
|
||||
|
||||
Args:
|
||||
memory_type: Specific type or all if not specified
|
||||
|
||||
Returns:
|
||||
List of statistics for requested memory types
|
||||
"""
|
||||
# Placeholder - will be implemented in #100 (Metrics & Observability)
|
||||
logger.debug("get_stats called (not yet implemented)")
|
||||
raise NotImplementedError("Memory stats not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Lifecycle Operations
|
||||
# =========================================================================
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Initialize the memory manager and its backends.
|
||||
|
||||
Should be called before using the manager.
|
||||
"""
|
||||
if self._initialized:
|
||||
logger.debug("MemoryManager already initialized")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Initializing MemoryManager for scope %s:%s",
|
||||
self._scope.scope_type.value,
|
||||
self._scope.scope_id,
|
||||
)
|
||||
|
||||
# TODO: Initialize backends when implemented
|
||||
|
||||
self._initialized = True
|
||||
logger.info("MemoryManager initialized successfully")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the memory manager and release resources.
|
||||
|
||||
Should be called when done using the manager.
|
||||
"""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Closing MemoryManager for scope %s:%s",
|
||||
self._scope.scope_type.value,
|
||||
self._scope.scope_id,
|
||||
)
|
||||
|
||||
# TODO: Close backends when implemented
|
||||
|
||||
self._initialized = False
|
||||
logger.info("MemoryManager closed successfully")
|
||||
|
||||
async def __aenter__(self) -> "MemoryManager":
|
||||
"""Async context manager entry."""
|
||||
await self.initialize()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Async context manager exit."""
|
||||
await self.close()
|
||||
40
backend/app/services/memory/mcp/__init__.py
Normal file
40
backend/app/services/memory/mcp/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# app/services/memory/mcp/__init__.py
|
||||
"""
|
||||
MCP Tools for Agent Memory System.
|
||||
|
||||
Exposes memory operations as MCP-compatible tools that agents can invoke:
|
||||
- remember: Store data in memory
|
||||
- recall: Retrieve from memory
|
||||
- forget: Remove from memory
|
||||
- reflect: Analyze patterns
|
||||
- get_memory_stats: Usage statistics
|
||||
- search_procedures: Find relevant procedures
|
||||
- record_outcome: Record task success/failure
|
||||
"""
|
||||
|
||||
from .service import MemoryToolService, get_memory_tool_service
|
||||
from .tools import (
|
||||
MEMORY_TOOL_DEFINITIONS,
|
||||
ForgetArgs,
|
||||
GetMemoryStatsArgs,
|
||||
MemoryToolDefinition,
|
||||
RecallArgs,
|
||||
RecordOutcomeArgs,
|
||||
ReflectArgs,
|
||||
RememberArgs,
|
||||
SearchProceduresArgs,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MEMORY_TOOL_DEFINITIONS",
|
||||
"ForgetArgs",
|
||||
"GetMemoryStatsArgs",
|
||||
"MemoryToolDefinition",
|
||||
"MemoryToolService",
|
||||
"RecallArgs",
|
||||
"RecordOutcomeArgs",
|
||||
"ReflectArgs",
|
||||
"RememberArgs",
|
||||
"SearchProceduresArgs",
|
||||
"get_memory_tool_service",
|
||||
]
|
||||
1086
backend/app/services/memory/mcp/service.py
Normal file
1086
backend/app/services/memory/mcp/service.py
Normal file
File diff suppressed because it is too large
Load Diff
485
backend/app/services/memory/mcp/tools.py
Normal file
485
backend/app/services/memory/mcp/tools.py
Normal file
@@ -0,0 +1,485 @@
|
||||
# app/services/memory/mcp/tools.py
|
||||
"""
|
||||
MCP Tool Definitions for Agent Memory System.
|
||||
|
||||
Defines the schema and metadata for memory-related MCP tools.
|
||||
These tools are invoked by AI agents to interact with the memory system.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# OutcomeType alias - uses core Outcome enum from types module for consistency
|
||||
from app.services.memory.types import Outcome as OutcomeType
|
||||
|
||||
|
||||
class MemoryType(str, Enum):
|
||||
"""Types of memory for storage operations."""
|
||||
|
||||
WORKING = "working"
|
||||
EPISODIC = "episodic"
|
||||
SEMANTIC = "semantic"
|
||||
PROCEDURAL = "procedural"
|
||||
|
||||
|
||||
class AnalysisType(str, Enum):
|
||||
"""Types of pattern analysis for the reflect tool."""
|
||||
|
||||
RECENT_PATTERNS = "recent_patterns"
|
||||
SUCCESS_FACTORS = "success_factors"
|
||||
FAILURE_PATTERNS = "failure_patterns"
|
||||
COMMON_PROCEDURES = "common_procedures"
|
||||
LEARNING_PROGRESS = "learning_progress"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Argument Schemas (Pydantic models for validation)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class RememberArgs(BaseModel):
|
||||
"""Arguments for the 'remember' tool."""
|
||||
|
||||
memory_type: MemoryType = Field(
|
||||
...,
|
||||
description="Type of memory to store in: working, episodic, semantic, or procedural",
|
||||
)
|
||||
content: str = Field(
|
||||
...,
|
||||
description="The content to remember. Can be text, facts, or procedure steps.",
|
||||
min_length=1,
|
||||
max_length=10000,
|
||||
)
|
||||
key: str | None = Field(
|
||||
None,
|
||||
description="Optional key for working memory entries. Required for working memory type.",
|
||||
max_length=256,
|
||||
)
|
||||
importance: float = Field(
|
||||
0.5,
|
||||
description="Importance score from 0.0 (low) to 1.0 (critical)",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
ttl_seconds: int | None = Field(
|
||||
None,
|
||||
description="Time-to-live in seconds for working memory. None for permanent storage.",
|
||||
ge=1,
|
||||
le=86400 * 30, # Max 30 days
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional metadata to store with the memory",
|
||||
)
|
||||
# For semantic memory (facts)
|
||||
subject: str | None = Field(
|
||||
None,
|
||||
description="Subject of the fact (for semantic memory)",
|
||||
max_length=256,
|
||||
)
|
||||
predicate: str | None = Field(
|
||||
None,
|
||||
description="Predicate/relationship (for semantic memory)",
|
||||
max_length=256,
|
||||
)
|
||||
object_value: str | None = Field(
|
||||
None,
|
||||
description="Object of the fact (for semantic memory)",
|
||||
max_length=1000,
|
||||
)
|
||||
# For procedural memory
|
||||
trigger: str | None = Field(
|
||||
None,
|
||||
description="Trigger condition for the procedure (for procedural memory)",
|
||||
max_length=500,
|
||||
)
|
||||
steps: list[dict[str, Any]] | None = Field(
|
||||
None,
|
||||
description="Procedure steps as a list of action dictionaries",
|
||||
)
|
||||
|
||||
|
||||
class RecallArgs(BaseModel):
|
||||
"""Arguments for the 'recall' tool."""
|
||||
|
||||
query: str = Field(
|
||||
...,
|
||||
description="Search query to find relevant memories",
|
||||
min_length=1,
|
||||
max_length=1000,
|
||||
)
|
||||
memory_types: list[MemoryType] = Field(
|
||||
default_factory=lambda: [MemoryType.EPISODIC, MemoryType.SEMANTIC],
|
||||
description="Types of memory to search in",
|
||||
)
|
||||
limit: int = Field(
|
||||
10,
|
||||
description="Maximum number of results to return",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
min_relevance: float = Field(
|
||||
0.0,
|
||||
description="Minimum relevance score (0.0-1.0) for results",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
filters: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional filters (e.g., outcome, task_type, date range)",
|
||||
)
|
||||
include_context: bool = Field(
|
||||
True,
|
||||
description="Whether to include surrounding context in results",
|
||||
)
|
||||
|
||||
|
||||
class ForgetArgs(BaseModel):
|
||||
"""Arguments for the 'forget' tool."""
|
||||
|
||||
memory_type: MemoryType = Field(
|
||||
...,
|
||||
description="Type of memory to remove from",
|
||||
)
|
||||
key: str | None = Field(
|
||||
None,
|
||||
description="Key to remove (for working memory)",
|
||||
max_length=256,
|
||||
)
|
||||
memory_id: str | None = Field(
|
||||
None,
|
||||
description="Specific memory ID to remove (for episodic/semantic/procedural)",
|
||||
)
|
||||
pattern: str | None = Field(
|
||||
None,
|
||||
description="Pattern to match for bulk removal (use with caution)",
|
||||
max_length=500,
|
||||
)
|
||||
confirm_bulk: bool = Field(
|
||||
False,
|
||||
description="Must be True to confirm bulk deletion when using pattern",
|
||||
)
|
||||
|
||||
|
||||
class ReflectArgs(BaseModel):
|
||||
"""Arguments for the 'reflect' tool."""
|
||||
|
||||
analysis_type: AnalysisType = Field(
|
||||
...,
|
||||
description="Type of pattern analysis to perform",
|
||||
)
|
||||
scope: str | None = Field(
|
||||
None,
|
||||
description="Optional scope to limit analysis (e.g., task_type, time range)",
|
||||
max_length=500,
|
||||
)
|
||||
depth: int = Field(
|
||||
3,
|
||||
description="Depth of analysis (1=surface, 5=deep)",
|
||||
ge=1,
|
||||
le=5,
|
||||
)
|
||||
include_examples: bool = Field(
|
||||
True,
|
||||
description="Whether to include example memories in the analysis",
|
||||
)
|
||||
max_items: int = Field(
|
||||
10,
|
||||
description="Maximum number of patterns/examples to analyze",
|
||||
ge=1,
|
||||
le=50,
|
||||
)
|
||||
|
||||
|
||||
class GetMemoryStatsArgs(BaseModel):
|
||||
"""Arguments for the 'get_memory_stats' tool."""
|
||||
|
||||
include_breakdown: bool = Field(
|
||||
True,
|
||||
description="Include breakdown by memory type",
|
||||
)
|
||||
include_recent_activity: bool = Field(
|
||||
True,
|
||||
description="Include recent memory activity summary",
|
||||
)
|
||||
time_range_days: int = Field(
|
||||
7,
|
||||
description="Time range for activity analysis in days",
|
||||
ge=1,
|
||||
le=90,
|
||||
)
|
||||
|
||||
|
||||
class SearchProceduresArgs(BaseModel):
|
||||
"""Arguments for the 'search_procedures' tool."""
|
||||
|
||||
trigger: str = Field(
|
||||
...,
|
||||
description="Trigger or situation to find procedures for",
|
||||
min_length=1,
|
||||
max_length=500,
|
||||
)
|
||||
task_type: str | None = Field(
|
||||
None,
|
||||
description="Optional task type to filter procedures",
|
||||
max_length=100,
|
||||
)
|
||||
min_success_rate: float = Field(
|
||||
0.5,
|
||||
description="Minimum success rate (0.0-1.0) for returned procedures",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
limit: int = Field(
|
||||
5,
|
||||
description="Maximum number of procedures to return",
|
||||
ge=1,
|
||||
le=20,
|
||||
)
|
||||
include_steps: bool = Field(
|
||||
True,
|
||||
description="Whether to include detailed steps in the response",
|
||||
)
|
||||
|
||||
|
||||
class RecordOutcomeArgs(BaseModel):
|
||||
"""Arguments for the 'record_outcome' tool."""
|
||||
|
||||
task_type: str = Field(
|
||||
...,
|
||||
description="Type of task that was executed",
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
)
|
||||
outcome: OutcomeType = Field(
|
||||
...,
|
||||
description="Outcome of the task execution",
|
||||
)
|
||||
procedure_id: str | None = Field(
|
||||
None,
|
||||
description="ID of the procedure that was followed (if any)",
|
||||
)
|
||||
context: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Context in which the task was executed",
|
||||
)
|
||||
lessons_learned: str | None = Field(
|
||||
None,
|
||||
description="What was learned from this execution",
|
||||
max_length=2000,
|
||||
)
|
||||
duration_seconds: float | None = Field(
|
||||
None,
|
||||
description="How long the task took to execute",
|
||||
ge=0.0,
|
||||
)
|
||||
error_details: str | None = Field(
|
||||
None,
|
||||
description="Details about any errors encountered (for failures)",
|
||||
max_length=2000,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Definition Structure
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryToolDefinition:
|
||||
"""Definition of an MCP tool for the memory system."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
args_schema: type[BaseModel]
|
||||
input_schema: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Generate input schema from Pydantic model."""
|
||||
if not self.input_schema:
|
||||
self.input_schema = self.args_schema.model_json_schema()
|
||||
|
||||
def to_mcp_format(self) -> dict[str, Any]:
|
||||
"""Convert to MCP tool format."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"inputSchema": self.input_schema,
|
||||
}
|
||||
|
||||
def validate_args(self, args: dict[str, Any]) -> BaseModel:
|
||||
"""Validate and parse arguments."""
|
||||
return self.args_schema.model_validate(args)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Definitions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
REMEMBER_TOOL = MemoryToolDefinition(
|
||||
name="remember",
|
||||
description="""Store information in the agent's memory system.
|
||||
|
||||
Use this tool to:
|
||||
- Store temporary data in working memory (key-value with optional TTL)
|
||||
- Record important events in episodic memory (automatically done on session end)
|
||||
- Store facts/knowledge in semantic memory (subject-predicate-object triples)
|
||||
- Save procedures in procedural memory (trigger conditions and steps)
|
||||
|
||||
Examples:
|
||||
- Working memory: {"memory_type": "working", "key": "current_task", "content": "Implementing auth", "ttl_seconds": 3600}
|
||||
- Semantic fact: {"memory_type": "semantic", "subject": "User", "predicate": "prefers", "object_value": "dark mode", "content": "User preference noted"}
|
||||
- Procedure: {"memory_type": "procedural", "trigger": "When creating a new file", "steps": [{"action": "check_exists"}, {"action": "create"}], "content": "File creation procedure"}
|
||||
""",
|
||||
args_schema=RememberArgs,
|
||||
)
|
||||
|
||||
|
||||
RECALL_TOOL = MemoryToolDefinition(
|
||||
name="recall",
|
||||
description="""Retrieve information from the agent's memory system.
|
||||
|
||||
Use this tool to:
|
||||
- Search for relevant past experiences (episodic)
|
||||
- Look up known facts and knowledge (semantic)
|
||||
- Find applicable procedures for current task (procedural)
|
||||
- Get current session state (working)
|
||||
|
||||
The query supports semantic search - describe what you're looking for in natural language.
|
||||
|
||||
Examples:
|
||||
- {"query": "How did I handle authentication errors before?", "memory_types": ["episodic"]}
|
||||
- {"query": "What are the user's preferences?", "memory_types": ["semantic"], "limit": 5}
|
||||
- {"query": "database connection", "memory_types": ["episodic", "semantic", "procedural"], "filters": {"outcome": "success"}}
|
||||
""",
|
||||
args_schema=RecallArgs,
|
||||
)
|
||||
|
||||
|
||||
FORGET_TOOL = MemoryToolDefinition(
|
||||
name="forget",
|
||||
description="""Remove information from the agent's memory system.
|
||||
|
||||
Use this tool to:
|
||||
- Clear temporary working memory entries
|
||||
- Remove specific memories by ID
|
||||
- Bulk remove memories matching a pattern (requires confirmation)
|
||||
|
||||
WARNING: Deletion is permanent. Use with caution.
|
||||
|
||||
Examples:
|
||||
- Working memory: {"memory_type": "working", "key": "temp_calculation"}
|
||||
- Specific memory: {"memory_type": "episodic", "memory_id": "ep-123"}
|
||||
- Bulk (requires confirm): {"memory_type": "working", "pattern": "cache_*", "confirm_bulk": true}
|
||||
""",
|
||||
args_schema=ForgetArgs,
|
||||
)
|
||||
|
||||
|
||||
REFLECT_TOOL = MemoryToolDefinition(
|
||||
name="reflect",
|
||||
description="""Analyze patterns in the agent's memory to gain insights.
|
||||
|
||||
Use this tool to:
|
||||
- Identify patterns in recent work
|
||||
- Understand what leads to success/failure
|
||||
- Learn from past experiences
|
||||
- Track learning progress over time
|
||||
|
||||
Analysis types:
|
||||
- recent_patterns: What patterns appear in recent work
|
||||
- success_factors: What conditions lead to success
|
||||
- failure_patterns: What causes failures and how to avoid them
|
||||
- common_procedures: Most frequently used procedures
|
||||
- learning_progress: How knowledge has grown over time
|
||||
|
||||
Examples:
|
||||
- {"analysis_type": "success_factors", "scope": "code_review", "depth": 3}
|
||||
- {"analysis_type": "failure_patterns", "include_examples": true, "max_items": 5}
|
||||
""",
|
||||
args_schema=ReflectArgs,
|
||||
)
|
||||
|
||||
|
||||
GET_MEMORY_STATS_TOOL = MemoryToolDefinition(
|
||||
name="get_memory_stats",
|
||||
description="""Get statistics about the agent's memory usage.
|
||||
|
||||
Returns information about:
|
||||
- Total memories stored by type
|
||||
- Storage utilization
|
||||
- Recent activity summary
|
||||
- Memory health indicators
|
||||
|
||||
Use this to understand memory capacity and usage patterns.
|
||||
|
||||
Examples:
|
||||
- {"include_breakdown": true, "include_recent_activity": true}
|
||||
- {"time_range_days": 30, "include_breakdown": true}
|
||||
""",
|
||||
args_schema=GetMemoryStatsArgs,
|
||||
)
|
||||
|
||||
|
||||
SEARCH_PROCEDURES_TOOL = MemoryToolDefinition(
|
||||
name="search_procedures",
|
||||
description="""Find relevant procedures for a given situation.
|
||||
|
||||
Use this tool when you need to:
|
||||
- Find the best way to handle a situation
|
||||
- Look up proven approaches to problems
|
||||
- Get step-by-step guidance for tasks
|
||||
|
||||
Returns procedures ranked by relevance and success rate.
|
||||
|
||||
Examples:
|
||||
- {"trigger": "Deploying to production", "min_success_rate": 0.8}
|
||||
- {"trigger": "Handling merge conflicts", "task_type": "git_operations", "limit": 3}
|
||||
""",
|
||||
args_schema=SearchProceduresArgs,
|
||||
)
|
||||
|
||||
|
||||
RECORD_OUTCOME_TOOL = MemoryToolDefinition(
|
||||
name="record_outcome",
|
||||
description="""Record the outcome of a task execution.
|
||||
|
||||
Use this tool after completing a task to:
|
||||
- Update procedure success/failure rates
|
||||
- Store lessons learned for future reference
|
||||
- Improve procedure recommendations
|
||||
|
||||
This helps the memory system learn from experience.
|
||||
|
||||
Examples:
|
||||
- {"task_type": "code_review", "outcome": "success", "lessons_learned": "Breaking changes caught early"}
|
||||
- {"task_type": "deployment", "outcome": "failure", "error_details": "Database migration timeout", "lessons_learned": "Need to test migrations locally first"}
|
||||
""",
|
||||
args_schema=RecordOutcomeArgs,
|
||||
)
|
||||
|
||||
|
||||
# All tool definitions in a dictionary for easy lookup
|
||||
MEMORY_TOOL_DEFINITIONS: dict[str, MemoryToolDefinition] = {
|
||||
"remember": REMEMBER_TOOL,
|
||||
"recall": RECALL_TOOL,
|
||||
"forget": FORGET_TOOL,
|
||||
"reflect": REFLECT_TOOL,
|
||||
"get_memory_stats": GET_MEMORY_STATS_TOOL,
|
||||
"search_procedures": SEARCH_PROCEDURES_TOOL,
|
||||
"record_outcome": RECORD_OUTCOME_TOOL,
|
||||
}
|
||||
|
||||
|
||||
def get_all_tool_schemas() -> list[dict[str, Any]]:
|
||||
"""Get MCP-formatted schemas for all memory tools."""
|
||||
return [tool.to_mcp_format() for tool in MEMORY_TOOL_DEFINITIONS.values()]
|
||||
|
||||
|
||||
def get_tool_definition(name: str) -> MemoryToolDefinition | None:
|
||||
"""Get a specific tool definition by name."""
|
||||
return MEMORY_TOOL_DEFINITIONS.get(name)
|
||||
18
backend/app/services/memory/metrics/__init__.py
Normal file
18
backend/app/services/memory/metrics/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# app/services/memory/metrics/__init__.py
|
||||
"""Memory Metrics module."""
|
||||
|
||||
from .collector import (
|
||||
MemoryMetrics,
|
||||
get_memory_metrics,
|
||||
record_memory_operation,
|
||||
record_retrieval,
|
||||
reset_memory_metrics,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MemoryMetrics",
|
||||
"get_memory_metrics",
|
||||
"record_memory_operation",
|
||||
"record_retrieval",
|
||||
"reset_memory_metrics",
|
||||
]
|
||||
542
backend/app/services/memory/metrics/collector.py
Normal file
542
backend/app/services/memory/metrics/collector.py
Normal file
@@ -0,0 +1,542 @@
|
||||
# app/services/memory/metrics/collector.py
|
||||
"""
|
||||
Memory Metrics Collector
|
||||
|
||||
Collects and exposes metrics for the memory system.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import Counter, defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetricType(str, Enum):
|
||||
"""Types of metrics."""
|
||||
|
||||
COUNTER = "counter"
|
||||
GAUGE = "gauge"
|
||||
HISTOGRAM = "histogram"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetricValue:
|
||||
"""A single metric value."""
|
||||
|
||||
name: str
|
||||
metric_type: MetricType
|
||||
value: float
|
||||
labels: dict[str, str] = field(default_factory=dict)
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
|
||||
@dataclass
|
||||
class HistogramBucket:
|
||||
"""Histogram bucket for distribution metrics."""
|
||||
|
||||
le: float # Less than or equal
|
||||
count: int = 0
|
||||
|
||||
|
||||
class MemoryMetrics:
|
||||
"""
|
||||
Collects memory system metrics.
|
||||
|
||||
Metrics tracked:
|
||||
- Memory operations (get/set/delete by type and scope)
|
||||
- Retrieval operations and latencies
|
||||
- Memory item counts by type
|
||||
- Consolidation operations and durations
|
||||
- Cache hit/miss rates
|
||||
- Procedure success rates
|
||||
- Embedding operations
|
||||
"""
|
||||
|
||||
# Maximum samples to keep in histogram (circular buffer)
|
||||
MAX_HISTOGRAM_SAMPLES = 10000
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize MemoryMetrics."""
|
||||
self._counters: dict[str, Counter[str]] = defaultdict(Counter)
|
||||
self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
|
||||
# Use deque with maxlen for bounded memory (circular buffer)
|
||||
self._histograms: dict[str, deque[float]] = defaultdict(
|
||||
lambda: deque(maxlen=self.MAX_HISTOGRAM_SAMPLES)
|
||||
)
|
||||
self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Initialize histogram buckets
|
||||
self._init_histogram_buckets()
|
||||
|
||||
def _init_histogram_buckets(self) -> None:
|
||||
"""Initialize histogram buckets for latency metrics."""
|
||||
# Fast operations (working memory)
|
||||
fast_buckets = [0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, float("inf")]
|
||||
|
||||
# Normal operations (retrieval)
|
||||
normal_buckets = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, float("inf")]
|
||||
|
||||
# Slow operations (consolidation)
|
||||
slow_buckets = [0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, float("inf")]
|
||||
|
||||
self._histogram_buckets["memory_working_latency_seconds"] = [
|
||||
HistogramBucket(le=b) for b in fast_buckets
|
||||
]
|
||||
self._histogram_buckets["memory_retrieval_latency_seconds"] = [
|
||||
HistogramBucket(le=b) for b in normal_buckets
|
||||
]
|
||||
self._histogram_buckets["memory_consolidation_duration_seconds"] = [
|
||||
HistogramBucket(le=b) for b in slow_buckets
|
||||
]
|
||||
self._histogram_buckets["memory_embedding_latency_seconds"] = [
|
||||
HistogramBucket(le=b) for b in normal_buckets
|
||||
]
|
||||
|
||||
# Counter methods - Operations
|
||||
|
||||
async def inc_operations(
|
||||
self,
|
||||
operation: str,
|
||||
memory_type: str,
|
||||
scope: str | None = None,
|
||||
success: bool = True,
|
||||
) -> None:
|
||||
"""Increment memory operation counter."""
|
||||
async with self._lock:
|
||||
labels = f"operation={operation},memory_type={memory_type}"
|
||||
if scope:
|
||||
labels += f",scope={scope}"
|
||||
labels += f",success={str(success).lower()}"
|
||||
self._counters["memory_operations_total"][labels] += 1
|
||||
|
||||
async def inc_retrieval(
|
||||
self,
|
||||
memory_type: str,
|
||||
strategy: str,
|
||||
results_count: int,
|
||||
) -> None:
|
||||
"""Increment retrieval counter."""
|
||||
async with self._lock:
|
||||
labels = f"memory_type={memory_type},strategy={strategy}"
|
||||
self._counters["memory_retrievals_total"][labels] += 1
|
||||
|
||||
# Track result counts as a separate metric
|
||||
self._counters["memory_retrieval_results_total"][labels] += results_count
|
||||
|
||||
async def inc_cache_hit(self, cache_type: str) -> None:
|
||||
"""Increment cache hit counter."""
|
||||
async with self._lock:
|
||||
labels = f"cache_type={cache_type}"
|
||||
self._counters["memory_cache_hits_total"][labels] += 1
|
||||
|
||||
async def inc_cache_miss(self, cache_type: str) -> None:
|
||||
"""Increment cache miss counter."""
|
||||
async with self._lock:
|
||||
labels = f"cache_type={cache_type}"
|
||||
self._counters["memory_cache_misses_total"][labels] += 1
|
||||
|
||||
async def inc_consolidation(
|
||||
self,
|
||||
consolidation_type: str,
|
||||
success: bool = True,
|
||||
) -> None:
|
||||
"""Increment consolidation counter."""
|
||||
async with self._lock:
|
||||
labels = f"type={consolidation_type},success={str(success).lower()}"
|
||||
self._counters["memory_consolidations_total"][labels] += 1
|
||||
|
||||
async def inc_procedure_execution(
|
||||
self,
|
||||
procedure_id: str | None = None,
|
||||
success: bool = True,
|
||||
) -> None:
|
||||
"""Increment procedure execution counter."""
|
||||
async with self._lock:
|
||||
labels = f"success={str(success).lower()}"
|
||||
self._counters["memory_procedure_executions_total"][labels] += 1
|
||||
|
||||
async def inc_embeddings_generated(self, memory_type: str) -> None:
|
||||
"""Increment embeddings generated counter."""
|
||||
async with self._lock:
|
||||
labels = f"memory_type={memory_type}"
|
||||
self._counters["memory_embeddings_generated_total"][labels] += 1
|
||||
|
||||
async def inc_fact_reinforcements(self) -> None:
|
||||
"""Increment fact reinforcement counter."""
|
||||
async with self._lock:
|
||||
self._counters["memory_fact_reinforcements_total"][""] += 1
|
||||
|
||||
async def inc_episodes_recorded(self, outcome: str) -> None:
|
||||
"""Increment episodes recorded counter."""
|
||||
async with self._lock:
|
||||
labels = f"outcome={outcome}"
|
||||
self._counters["memory_episodes_recorded_total"][labels] += 1
|
||||
|
||||
async def inc_anomalies_detected(self, anomaly_type: str) -> None:
|
||||
"""Increment anomaly detection counter."""
|
||||
async with self._lock:
|
||||
labels = f"anomaly_type={anomaly_type}"
|
||||
self._counters["memory_anomalies_detected_total"][labels] += 1
|
||||
|
||||
async def inc_patterns_detected(self, pattern_type: str) -> None:
|
||||
"""Increment pattern detection counter."""
|
||||
async with self._lock:
|
||||
labels = f"pattern_type={pattern_type}"
|
||||
self._counters["memory_patterns_detected_total"][labels] += 1
|
||||
|
||||
async def inc_insights_generated(self, insight_type: str) -> None:
|
||||
"""Increment insight generation counter."""
|
||||
async with self._lock:
|
||||
labels = f"insight_type={insight_type}"
|
||||
self._counters["memory_insights_generated_total"][labels] += 1
|
||||
|
||||
# Gauge methods
|
||||
|
||||
async def set_memory_items_count(
|
||||
self,
|
||||
memory_type: str,
|
||||
scope: str,
|
||||
count: int,
|
||||
) -> None:
|
||||
"""Set memory item count gauge."""
|
||||
async with self._lock:
|
||||
labels = f"memory_type={memory_type},scope={scope}"
|
||||
self._gauges["memory_items_count"][labels] = float(count)
|
||||
|
||||
async def set_memory_size_bytes(
|
||||
self,
|
||||
memory_type: str,
|
||||
scope: str,
|
||||
size_bytes: int,
|
||||
) -> None:
|
||||
"""Set memory size gauge in bytes."""
|
||||
async with self._lock:
|
||||
labels = f"memory_type={memory_type},scope={scope}"
|
||||
self._gauges["memory_size_bytes"][labels] = float(size_bytes)
|
||||
|
||||
async def set_cache_size(self, cache_type: str, size: int) -> None:
|
||||
"""Set cache size gauge."""
|
||||
async with self._lock:
|
||||
labels = f"cache_type={cache_type}"
|
||||
self._gauges["memory_cache_size"][labels] = float(size)
|
||||
|
||||
async def set_procedure_success_rate(
|
||||
self,
|
||||
procedure_name: str,
|
||||
rate: float,
|
||||
) -> None:
|
||||
"""Set procedure success rate gauge (0-1)."""
|
||||
async with self._lock:
|
||||
labels = f"procedure_name={procedure_name}"
|
||||
self._gauges["memory_procedure_success_rate"][labels] = rate
|
||||
|
||||
async def set_active_sessions(self, count: int) -> None:
|
||||
"""Set active working memory sessions gauge."""
|
||||
async with self._lock:
|
||||
self._gauges["memory_active_sessions"][""] = float(count)
|
||||
|
||||
async def set_pending_consolidations(self, count: int) -> None:
|
||||
"""Set pending consolidations gauge."""
|
||||
async with self._lock:
|
||||
self._gauges["memory_pending_consolidations"][""] = float(count)
|
||||
|
||||
# Histogram methods
|
||||
|
||||
async def observe_working_latency(self, latency_seconds: float) -> None:
|
||||
"""Observe working memory operation latency."""
|
||||
async with self._lock:
|
||||
self._observe_histogram("memory_working_latency_seconds", latency_seconds)
|
||||
|
||||
async def observe_retrieval_latency(self, latency_seconds: float) -> None:
|
||||
"""Observe retrieval latency."""
|
||||
async with self._lock:
|
||||
self._observe_histogram("memory_retrieval_latency_seconds", latency_seconds)
|
||||
|
||||
async def observe_consolidation_duration(self, duration_seconds: float) -> None:
|
||||
"""Observe consolidation duration."""
|
||||
async with self._lock:
|
||||
self._observe_histogram(
|
||||
"memory_consolidation_duration_seconds", duration_seconds
|
||||
)
|
||||
|
||||
async def observe_embedding_latency(self, latency_seconds: float) -> None:
|
||||
"""Observe embedding generation latency."""
|
||||
async with self._lock:
|
||||
self._observe_histogram("memory_embedding_latency_seconds", latency_seconds)
|
||||
|
||||
def _observe_histogram(self, name: str, value: float) -> None:
|
||||
"""Record a value in a histogram."""
|
||||
self._histograms[name].append(value)
|
||||
|
||||
# Update buckets
|
||||
if name in self._histogram_buckets:
|
||||
for bucket in self._histogram_buckets[name]:
|
||||
if value <= bucket.le:
|
||||
bucket.count += 1
|
||||
|
||||
# Export methods
|
||||
|
||||
async def get_all_metrics(self) -> list[MetricValue]:
|
||||
"""Get all metrics as MetricValue objects."""
|
||||
metrics: list[MetricValue] = []
|
||||
|
||||
async with self._lock:
|
||||
# Export counters
|
||||
for name, counter in self._counters.items():
|
||||
for labels_str, value in counter.items():
|
||||
labels = self._parse_labels(labels_str)
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=name,
|
||||
metric_type=MetricType.COUNTER,
|
||||
value=float(value),
|
||||
labels=labels,
|
||||
)
|
||||
)
|
||||
|
||||
# Export gauges
|
||||
for name, gauge_dict in self._gauges.items():
|
||||
for labels_str, gauge_value in gauge_dict.items():
|
||||
gauge_labels = self._parse_labels(labels_str)
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=name,
|
||||
metric_type=MetricType.GAUGE,
|
||||
value=gauge_value,
|
||||
labels=gauge_labels,
|
||||
)
|
||||
)
|
||||
|
||||
# Export histogram summaries
|
||||
for name, values in self._histograms.items():
|
||||
if values:
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=f"{name}_count",
|
||||
metric_type=MetricType.COUNTER,
|
||||
value=float(len(values)),
|
||||
)
|
||||
)
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=f"{name}_sum",
|
||||
metric_type=MetricType.COUNTER,
|
||||
value=sum(values),
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
async def get_prometheus_format(self) -> str:
|
||||
"""Export metrics in Prometheus text format."""
|
||||
lines: list[str] = []
|
||||
|
||||
async with self._lock:
|
||||
# Export counters
|
||||
for name, counter in self._counters.items():
|
||||
lines.append(f"# TYPE {name} counter")
|
||||
for labels_str, value in counter.items():
|
||||
if labels_str:
|
||||
lines.append(f"{name}{{{labels_str}}} {value}")
|
||||
else:
|
||||
lines.append(f"{name} {value}")
|
||||
|
||||
# Export gauges
|
||||
for name, gauge_dict in self._gauges.items():
|
||||
lines.append(f"# TYPE {name} gauge")
|
||||
for labels_str, gauge_value in gauge_dict.items():
|
||||
if labels_str:
|
||||
lines.append(f"{name}{{{labels_str}}} {gauge_value}")
|
||||
else:
|
||||
lines.append(f"{name} {gauge_value}")
|
||||
|
||||
# Export histograms
|
||||
for name, buckets in self._histogram_buckets.items():
|
||||
lines.append(f"# TYPE {name} histogram")
|
||||
for bucket in buckets:
|
||||
le_str = "+Inf" if bucket.le == float("inf") else str(bucket.le)
|
||||
lines.append(f'{name}_bucket{{le="{le_str}"}} {bucket.count}')
|
||||
|
||||
if name in self._histograms:
|
||||
values = self._histograms[name]
|
||||
lines.append(f"{name}_count {len(values)}")
|
||||
lines.append(f"{name}_sum {sum(values)}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def get_summary(self) -> dict[str, Any]:
|
||||
"""Get a summary of key metrics."""
|
||||
async with self._lock:
|
||||
total_operations = sum(self._counters["memory_operations_total"].values())
|
||||
successful_operations = sum(
|
||||
v
|
||||
for k, v in self._counters["memory_operations_total"].items()
|
||||
if "success=true" in k
|
||||
)
|
||||
|
||||
total_retrievals = sum(self._counters["memory_retrievals_total"].values())
|
||||
|
||||
total_cache_hits = sum(self._counters["memory_cache_hits_total"].values())
|
||||
total_cache_misses = sum(
|
||||
self._counters["memory_cache_misses_total"].values()
|
||||
)
|
||||
cache_hit_rate = (
|
||||
total_cache_hits / (total_cache_hits + total_cache_misses)
|
||||
if (total_cache_hits + total_cache_misses) > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
total_consolidations = sum(
|
||||
self._counters["memory_consolidations_total"].values()
|
||||
)
|
||||
|
||||
total_episodes = sum(
|
||||
self._counters["memory_episodes_recorded_total"].values()
|
||||
)
|
||||
|
||||
# Calculate average latencies
|
||||
retrieval_latencies = list(
|
||||
self._histograms.get("memory_retrieval_latency_seconds", deque())
|
||||
)
|
||||
avg_retrieval_latency = (
|
||||
sum(retrieval_latencies) / len(retrieval_latencies)
|
||||
if retrieval_latencies
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_operations": total_operations,
|
||||
"successful_operations": successful_operations,
|
||||
"operation_success_rate": (
|
||||
successful_operations / total_operations
|
||||
if total_operations > 0
|
||||
else 1.0
|
||||
),
|
||||
"total_retrievals": total_retrievals,
|
||||
"cache_hit_rate": cache_hit_rate,
|
||||
"total_consolidations": total_consolidations,
|
||||
"total_episodes_recorded": total_episodes,
|
||||
"avg_retrieval_latency_ms": avg_retrieval_latency * 1000,
|
||||
"patterns_detected": sum(
|
||||
self._counters["memory_patterns_detected_total"].values()
|
||||
),
|
||||
"insights_generated": sum(
|
||||
self._counters["memory_insights_generated_total"].values()
|
||||
),
|
||||
"anomalies_detected": sum(
|
||||
self._counters["memory_anomalies_detected_total"].values()
|
||||
),
|
||||
"active_sessions": self._gauges.get("memory_active_sessions", {}).get(
|
||||
"", 0
|
||||
),
|
||||
"pending_consolidations": self._gauges.get(
|
||||
"memory_pending_consolidations", {}
|
||||
).get("", 0),
|
||||
}
|
||||
|
||||
async def get_cache_stats(self) -> dict[str, Any]:
|
||||
"""Get detailed cache statistics."""
|
||||
async with self._lock:
|
||||
stats: dict[str, Any] = {}
|
||||
|
||||
# Get hits/misses by cache type
|
||||
for labels_str, hits in self._counters["memory_cache_hits_total"].items():
|
||||
cache_type = self._parse_labels(labels_str).get("cache_type", "unknown")
|
||||
if cache_type not in stats:
|
||||
stats[cache_type] = {"hits": 0, "misses": 0}
|
||||
stats[cache_type]["hits"] = hits
|
||||
|
||||
for labels_str, misses in self._counters[
|
||||
"memory_cache_misses_total"
|
||||
].items():
|
||||
cache_type = self._parse_labels(labels_str).get("cache_type", "unknown")
|
||||
if cache_type not in stats:
|
||||
stats[cache_type] = {"hits": 0, "misses": 0}
|
||||
stats[cache_type]["misses"] = misses
|
||||
|
||||
# Calculate hit rates
|
||||
for data in stats.values():
|
||||
total = data["hits"] + data["misses"]
|
||||
data["hit_rate"] = data["hits"] / total if total > 0 else 0.0
|
||||
data["total"] = total
|
||||
|
||||
return stats
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset all metrics."""
|
||||
async with self._lock:
|
||||
self._counters.clear()
|
||||
self._gauges.clear()
|
||||
self._histograms.clear()
|
||||
self._init_histogram_buckets()
|
||||
|
||||
def _parse_labels(self, labels_str: str) -> dict[str, str]:
|
||||
"""Parse labels string into dictionary."""
|
||||
if not labels_str:
|
||||
return {}
|
||||
|
||||
labels = {}
|
||||
for pair in labels_str.split(","):
|
||||
if "=" in pair:
|
||||
key, value = pair.split("=", 1)
|
||||
labels[key.strip()] = value.strip()
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_metrics: MemoryMetrics | None = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_memory_metrics() -> MemoryMetrics:
|
||||
"""Get the singleton MemoryMetrics instance."""
|
||||
global _metrics
|
||||
|
||||
async with _lock:
|
||||
if _metrics is None:
|
||||
_metrics = MemoryMetrics()
|
||||
return _metrics
|
||||
|
||||
|
||||
async def reset_memory_metrics() -> None:
|
||||
"""Reset the singleton instance (for testing)."""
|
||||
global _metrics
|
||||
async with _lock:
|
||||
_metrics = None
|
||||
|
||||
|
||||
# Convenience functions
|
||||
|
||||
|
||||
async def record_memory_operation(
|
||||
operation: str,
|
||||
memory_type: str,
|
||||
scope: str | None = None,
|
||||
success: bool = True,
|
||||
latency_ms: float | None = None,
|
||||
) -> None:
|
||||
"""Record a memory operation."""
|
||||
metrics = await get_memory_metrics()
|
||||
await metrics.inc_operations(operation, memory_type, scope, success)
|
||||
|
||||
if latency_ms is not None and memory_type == "working":
|
||||
await metrics.observe_working_latency(latency_ms / 1000)
|
||||
|
||||
|
||||
async def record_retrieval(
|
||||
memory_type: str,
|
||||
strategy: str,
|
||||
results_count: int,
|
||||
latency_ms: float,
|
||||
) -> None:
|
||||
"""Record a retrieval operation."""
|
||||
metrics = await get_memory_metrics()
|
||||
await metrics.inc_retrieval(memory_type, strategy, results_count)
|
||||
await metrics.observe_retrieval_latency(latency_ms / 1000)
|
||||
22
backend/app/services/memory/procedural/__init__.py
Normal file
22
backend/app/services/memory/procedural/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# app/services/memory/procedural/__init__.py
|
||||
"""
|
||||
Procedural Memory
|
||||
|
||||
Learned skills and procedures from successful task patterns.
|
||||
"""
|
||||
|
||||
from .matching import (
|
||||
MatchContext,
|
||||
MatchResult,
|
||||
ProcedureMatcher,
|
||||
get_procedure_matcher,
|
||||
)
|
||||
from .memory import ProceduralMemory
|
||||
|
||||
__all__ = [
|
||||
"MatchContext",
|
||||
"MatchResult",
|
||||
"ProceduralMemory",
|
||||
"ProcedureMatcher",
|
||||
"get_procedure_matcher",
|
||||
]
|
||||
291
backend/app/services/memory/procedural/matching.py
Normal file
291
backend/app/services/memory/procedural/matching.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# app/services/memory/procedural/matching.py
|
||||
"""
|
||||
Procedure Matching.
|
||||
|
||||
Provides utilities for matching procedures to contexts,
|
||||
ranking procedures by relevance, and suggesting procedures.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from app.services.memory.types import Procedure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchResult:
|
||||
"""Result of a procedure match."""
|
||||
|
||||
procedure: Procedure
|
||||
score: float
|
||||
matched_terms: list[str] = field(default_factory=list)
|
||||
match_type: str = "keyword" # keyword, semantic, pattern
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"procedure_id": str(self.procedure.id),
|
||||
"procedure_name": self.procedure.name,
|
||||
"score": self.score,
|
||||
"matched_terms": self.matched_terms,
|
||||
"match_type": self.match_type,
|
||||
"success_rate": self.procedure.success_rate,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchContext:
|
||||
"""Context for procedure matching."""
|
||||
|
||||
query: str
|
||||
task_type: str | None = None
|
||||
project_id: Any | None = None
|
||||
agent_type_id: Any | None = None
|
||||
max_results: int = 5
|
||||
min_score: float = 0.3
|
||||
require_success_rate: float | None = None
|
||||
|
||||
|
||||
class ProcedureMatcher:
|
||||
"""
|
||||
Matches procedures to contexts using multiple strategies.
|
||||
|
||||
Matching strategies:
|
||||
- Keyword matching on trigger pattern and name
|
||||
- Pattern-based matching using regex
|
||||
- Success rate weighting
|
||||
|
||||
In production, this would be augmented with vector similarity search.
|
||||
"""
|
||||
|
||||
# Common task-related keywords for boosting
|
||||
TASK_KEYWORDS: ClassVar[set[str]] = {
|
||||
"create",
|
||||
"update",
|
||||
"delete",
|
||||
"fix",
|
||||
"implement",
|
||||
"add",
|
||||
"remove",
|
||||
"refactor",
|
||||
"test",
|
||||
"deploy",
|
||||
"configure",
|
||||
"setup",
|
||||
"build",
|
||||
"debug",
|
||||
"optimize",
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the matcher."""
|
||||
self._compiled_patterns: dict[str, re.Pattern[str]] = {}
|
||||
|
||||
def match(
|
||||
self,
|
||||
procedures: list[Procedure],
|
||||
context: MatchContext,
|
||||
) -> list[MatchResult]:
|
||||
"""
|
||||
Match procedures against a context.
|
||||
|
||||
Args:
|
||||
procedures: List of procedures to match
|
||||
context: Matching context
|
||||
|
||||
Returns:
|
||||
List of match results, sorted by score (highest first)
|
||||
"""
|
||||
results: list[MatchResult] = []
|
||||
|
||||
query_terms = self._extract_terms(context.query)
|
||||
query_lower = context.query.lower()
|
||||
|
||||
for procedure in procedures:
|
||||
score, matched = self._calculate_match_score(
|
||||
procedure=procedure,
|
||||
query_terms=query_terms,
|
||||
query_lower=query_lower,
|
||||
context=context,
|
||||
)
|
||||
|
||||
if score >= context.min_score:
|
||||
# Apply success rate boost
|
||||
if context.require_success_rate is not None:
|
||||
if procedure.success_rate < context.require_success_rate:
|
||||
continue
|
||||
|
||||
# Boost score based on success rate
|
||||
success_boost = procedure.success_rate * 0.2
|
||||
final_score = min(1.0, score + success_boost)
|
||||
|
||||
results.append(
|
||||
MatchResult(
|
||||
procedure=procedure,
|
||||
score=final_score,
|
||||
matched_terms=matched,
|
||||
match_type="keyword",
|
||||
)
|
||||
)
|
||||
|
||||
# Sort by score descending
|
||||
results.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
return results[: context.max_results]
|
||||
|
||||
def _extract_terms(self, text: str) -> list[str]:
|
||||
"""Extract searchable terms from text."""
|
||||
# Remove special characters and split
|
||||
clean = re.sub(r"[^\w\s-]", " ", text.lower())
|
||||
terms = clean.split()
|
||||
|
||||
# Filter out very short terms
|
||||
return [t for t in terms if len(t) >= 2]
|
||||
|
||||
def _calculate_match_score(
|
||||
self,
|
||||
procedure: Procedure,
|
||||
query_terms: list[str],
|
||||
query_lower: str,
|
||||
context: MatchContext,
|
||||
) -> tuple[float, list[str]]:
|
||||
"""
|
||||
Calculate match score between procedure and query.
|
||||
|
||||
Returns:
|
||||
Tuple of (score, matched_terms)
|
||||
"""
|
||||
score = 0.0
|
||||
matched: list[str] = []
|
||||
|
||||
trigger_lower = procedure.trigger_pattern.lower()
|
||||
name_lower = procedure.name.lower()
|
||||
|
||||
# Exact name match - high score
|
||||
if name_lower in query_lower or query_lower in name_lower:
|
||||
score += 0.5
|
||||
matched.append(f"name:{procedure.name}")
|
||||
|
||||
# Trigger pattern match
|
||||
if trigger_lower in query_lower or query_lower in trigger_lower:
|
||||
score += 0.4
|
||||
matched.append(f"trigger:{procedure.trigger_pattern[:30]}")
|
||||
|
||||
# Term-by-term matching
|
||||
for term in query_terms:
|
||||
if term in trigger_lower:
|
||||
score += 0.1
|
||||
matched.append(term)
|
||||
elif term in name_lower:
|
||||
score += 0.08
|
||||
matched.append(term)
|
||||
|
||||
# Boost for task keywords
|
||||
if term in self.TASK_KEYWORDS:
|
||||
if term in trigger_lower or term in name_lower:
|
||||
score += 0.05
|
||||
|
||||
# Task type match if provided
|
||||
if context.task_type:
|
||||
task_type_lower = context.task_type.lower()
|
||||
if task_type_lower in trigger_lower or task_type_lower in name_lower:
|
||||
score += 0.3
|
||||
matched.append(f"task_type:{context.task_type}")
|
||||
|
||||
# Regex pattern matching on trigger
|
||||
try:
|
||||
pattern = self._get_or_compile_pattern(trigger_lower)
|
||||
if pattern and pattern.search(query_lower):
|
||||
score += 0.25
|
||||
matched.append("pattern_match")
|
||||
except re.error:
|
||||
pass # Invalid regex, skip pattern matching
|
||||
|
||||
return min(1.0, score), matched
|
||||
|
||||
def _get_or_compile_pattern(self, pattern: str) -> re.Pattern[str] | None:
|
||||
"""Get or compile a regex pattern with caching."""
|
||||
if pattern in self._compiled_patterns:
|
||||
return self._compiled_patterns[pattern]
|
||||
|
||||
# Only compile if it looks like a regex pattern
|
||||
if not any(c in pattern for c in r"\.*+?[]{}|()^$"):
|
||||
return None
|
||||
|
||||
try:
|
||||
compiled = re.compile(pattern, re.IGNORECASE)
|
||||
self._compiled_patterns[pattern] = compiled
|
||||
return compiled
|
||||
except re.error:
|
||||
return None
|
||||
|
||||
def rank_by_relevance(
|
||||
self,
|
||||
procedures: list[Procedure],
|
||||
task_type: str,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Rank procedures by relevance to a task type.
|
||||
|
||||
Args:
|
||||
procedures: Procedures to rank
|
||||
task_type: Task type for relevance
|
||||
|
||||
Returns:
|
||||
Procedures sorted by relevance
|
||||
"""
|
||||
context = MatchContext(
|
||||
query=task_type,
|
||||
task_type=task_type,
|
||||
min_score=0.0,
|
||||
max_results=len(procedures),
|
||||
)
|
||||
|
||||
results = self.match(procedures, context)
|
||||
return [r.procedure for r in results]
|
||||
|
||||
def suggest_procedures(
|
||||
self,
|
||||
procedures: list[Procedure],
|
||||
query: str,
|
||||
min_success_rate: float = 0.5,
|
||||
max_suggestions: int = 3,
|
||||
) -> list[MatchResult]:
|
||||
"""
|
||||
Suggest the best procedures for a query.
|
||||
|
||||
Only suggests procedures with sufficient success rate.
|
||||
|
||||
Args:
|
||||
procedures: Available procedures
|
||||
query: Query/context
|
||||
min_success_rate: Minimum success rate to suggest
|
||||
max_suggestions: Maximum suggestions
|
||||
|
||||
Returns:
|
||||
List of procedure suggestions
|
||||
"""
|
||||
context = MatchContext(
|
||||
query=query,
|
||||
max_results=max_suggestions,
|
||||
min_score=0.2,
|
||||
require_success_rate=min_success_rate,
|
||||
)
|
||||
|
||||
return self.match(procedures, context)
|
||||
|
||||
|
||||
# Singleton matcher instance
|
||||
_matcher: ProcedureMatcher | None = None
|
||||
|
||||
|
||||
def get_procedure_matcher() -> ProcedureMatcher:
|
||||
"""Get the singleton procedure matcher instance."""
|
||||
global _matcher
|
||||
if _matcher is None:
|
||||
_matcher = ProcedureMatcher()
|
||||
return _matcher
|
||||
749
backend/app/services/memory/procedural/memory.py
Normal file
749
backend/app/services/memory/procedural/memory.py
Normal file
@@ -0,0 +1,749 @@
|
||||
# app/services/memory/procedural/memory.py
|
||||
"""
|
||||
Procedural Memory Implementation.
|
||||
|
||||
Provides storage and retrieval for learned procedures (skills)
|
||||
derived from successful task execution patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, desc, or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.procedure import Procedure as ProcedureModel
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.types import Procedure, ProcedureCreate, RetrievalResult, Step
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _escape_like_pattern(pattern: str) -> str:
|
||||
"""
|
||||
Escape SQL LIKE/ILIKE special characters to prevent pattern injection.
|
||||
|
||||
Characters escaped:
|
||||
- % (matches zero or more characters)
|
||||
- _ (matches exactly one character)
|
||||
- \\ (escape character itself)
|
||||
|
||||
Args:
|
||||
pattern: Raw search pattern from user input
|
||||
|
||||
Returns:
|
||||
Escaped pattern safe for use in LIKE/ILIKE queries
|
||||
"""
|
||||
# Escape backslash first, then the wildcards
|
||||
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
|
||||
|
||||
def _model_to_procedure(model: ProcedureModel) -> Procedure:
|
||||
"""Convert SQLAlchemy model to Procedure dataclass."""
|
||||
return Procedure(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
|
||||
name=model.name, # type: ignore[arg-type]
|
||||
trigger_pattern=model.trigger_pattern, # type: ignore[arg-type]
|
||||
steps=model.steps or [], # type: ignore[arg-type]
|
||||
success_count=model.success_count, # type: ignore[arg-type]
|
||||
failure_count=model.failure_count, # type: ignore[arg-type]
|
||||
last_used=model.last_used, # type: ignore[arg-type]
|
||||
embedding=None, # Don't expose raw embedding
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class ProceduralMemory:
|
||||
"""
|
||||
Procedural Memory Service.
|
||||
|
||||
Provides procedure storage and retrieval:
|
||||
- Record procedures from successful task patterns
|
||||
- Find matching procedures by trigger pattern
|
||||
- Track success/failure rates
|
||||
- Get best procedure for a task type
|
||||
- Update procedure steps
|
||||
|
||||
Performance target: <50ms P95 for matching
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize procedural memory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic matching
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._settings = get_memory_settings()
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> "ProceduralMemory":
|
||||
"""
|
||||
Factory method to create ProceduralMemory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
|
||||
Returns:
|
||||
Configured ProceduralMemory instance
|
||||
"""
|
||||
return cls(session=session, embedding_generator=embedding_generator)
|
||||
|
||||
# =========================================================================
|
||||
# Procedure Recording
|
||||
# =========================================================================
|
||||
|
||||
async def record_procedure(self, procedure: ProcedureCreate) -> Procedure:
|
||||
"""
|
||||
Record a new procedure or update an existing one.
|
||||
|
||||
If a procedure with the same name exists in the same scope,
|
||||
its steps will be updated and success count incremented.
|
||||
|
||||
Args:
|
||||
procedure: Procedure data to record
|
||||
|
||||
Returns:
|
||||
The created or updated procedure
|
||||
"""
|
||||
# Check for existing procedure with same name
|
||||
existing = await self._find_existing_procedure(
|
||||
project_id=procedure.project_id,
|
||||
agent_type_id=procedure.agent_type_id,
|
||||
name=procedure.name,
|
||||
)
|
||||
|
||||
if existing is not None:
|
||||
# Update existing procedure
|
||||
return await self._update_existing_procedure(
|
||||
existing=existing,
|
||||
new_steps=procedure.steps,
|
||||
new_trigger=procedure.trigger_pattern,
|
||||
)
|
||||
|
||||
# Create new procedure
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Generate embedding if possible
|
||||
embedding = None
|
||||
if self._embedding_generator is not None:
|
||||
embedding_text = self._create_embedding_text(procedure)
|
||||
embedding = await self._embedding_generator.generate(embedding_text)
|
||||
|
||||
model = ProcedureModel(
|
||||
project_id=procedure.project_id,
|
||||
agent_type_id=procedure.agent_type_id,
|
||||
name=procedure.name,
|
||||
trigger_pattern=procedure.trigger_pattern,
|
||||
steps=procedure.steps,
|
||||
success_count=1, # New procedures start with 1 success (they worked)
|
||||
failure_count=0,
|
||||
last_used=now,
|
||||
embedding=embedding,
|
||||
)
|
||||
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
|
||||
logger.info(
|
||||
f"Recorded new procedure: {procedure.name} with {len(procedure.steps)} steps"
|
||||
)
|
||||
|
||||
return _model_to_procedure(model)
|
||||
|
||||
async def _find_existing_procedure(
|
||||
self,
|
||||
project_id: UUID | None,
|
||||
agent_type_id: UUID | None,
|
||||
name: str,
|
||||
) -> ProcedureModel | None:
|
||||
"""Find an existing procedure with the same name in the same scope."""
|
||||
query = select(ProcedureModel).where(ProcedureModel.name == name)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(ProcedureModel.project_id == project_id)
|
||||
else:
|
||||
query = query.where(ProcedureModel.project_id.is_(None))
|
||||
|
||||
if agent_type_id is not None:
|
||||
query = query.where(ProcedureModel.agent_type_id == agent_type_id)
|
||||
else:
|
||||
query = query.where(ProcedureModel.agent_type_id.is_(None))
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _update_existing_procedure(
|
||||
self,
|
||||
existing: ProcedureModel,
|
||||
new_steps: list[dict[str, Any]],
|
||||
new_trigger: str,
|
||||
) -> Procedure:
|
||||
"""Update an existing procedure with new steps."""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Merge steps intelligently - keep existing order, add new steps
|
||||
merged_steps = self._merge_steps(
|
||||
existing.steps or [], # type: ignore[arg-type]
|
||||
new_steps,
|
||||
)
|
||||
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == existing.id)
|
||||
.values(
|
||||
steps=merged_steps,
|
||||
trigger_pattern=new_trigger,
|
||||
success_count=ProcedureModel.success_count + 1,
|
||||
last_used=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Updated existing procedure: {existing.name}")
|
||||
|
||||
return _model_to_procedure(updated_model)
|
||||
|
||||
def _merge_steps(
|
||||
self,
|
||||
existing_steps: list[dict[str, Any]],
|
||||
new_steps: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Merge steps from a new execution with existing steps."""
|
||||
if not existing_steps:
|
||||
return new_steps
|
||||
if not new_steps:
|
||||
return existing_steps
|
||||
|
||||
# For now, use the new steps if they differ significantly
|
||||
# In production, this could use more sophisticated merging
|
||||
if len(new_steps) != len(existing_steps):
|
||||
# If structure changed, prefer newer steps
|
||||
return new_steps
|
||||
|
||||
# Merge step-by-step, preferring new data where available
|
||||
merged = []
|
||||
for i, new_step in enumerate(new_steps):
|
||||
if i < len(existing_steps):
|
||||
# Merge with existing step
|
||||
step = {**existing_steps[i], **new_step}
|
||||
else:
|
||||
step = new_step
|
||||
merged.append(step)
|
||||
|
||||
return merged
|
||||
|
||||
def _create_embedding_text(self, procedure: ProcedureCreate) -> str:
|
||||
"""Create text for embedding from procedure data."""
|
||||
steps_text = " ".join(step.get("action", "") for step in procedure.steps)
|
||||
return f"{procedure.name} {procedure.trigger_pattern} {steps_text}"
|
||||
|
||||
# =========================================================================
|
||||
# Procedure Retrieval
|
||||
# =========================================================================
|
||||
|
||||
async def find_matching(
|
||||
self,
|
||||
context: str,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
limit: int = 5,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Find procedures matching the given context.
|
||||
|
||||
Args:
|
||||
context: Context/trigger to match against
|
||||
project_id: Optional project to search within
|
||||
agent_type_id: Optional agent type filter
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of matching procedures
|
||||
"""
|
||||
result = await self._find_matching_with_metadata(
|
||||
context=context,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=limit,
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def _find_matching_with_metadata(
|
||||
self,
|
||||
context: str,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
limit: int = 5,
|
||||
) -> RetrievalResult[Procedure]:
|
||||
"""Find matching procedures with full result metadata."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Build base query - prioritize by success rate
|
||||
stmt = (
|
||||
select(ProcedureModel)
|
||||
.order_by(
|
||||
desc(
|
||||
ProcedureModel.success_count
|
||||
/ (ProcedureModel.success_count + ProcedureModel.failure_count + 1)
|
||||
),
|
||||
desc(ProcedureModel.last_used),
|
||||
)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
# Apply scope filters
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
# Text-based matching on trigger pattern and name
|
||||
# TODO: Implement proper vector similarity search when pgvector is integrated
|
||||
search_terms = context.lower().split()[:5] # Limit to 5 terms
|
||||
if search_terms:
|
||||
conditions = []
|
||||
for term in search_terms:
|
||||
# Escape SQL wildcards to prevent pattern injection
|
||||
escaped_term = _escape_like_pattern(term)
|
||||
term_pattern = f"%{escaped_term}%"
|
||||
conditions.append(
|
||||
or_(
|
||||
ProcedureModel.trigger_pattern.ilike(term_pattern),
|
||||
ProcedureModel.name.ilike(term_pattern),
|
||||
)
|
||||
)
|
||||
if conditions:
|
||||
stmt = stmt.where(or_(*conditions))
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_procedure(m) for m in models],
|
||||
total_count=len(models),
|
||||
query=context,
|
||||
retrieval_type="procedural",
|
||||
latency_ms=latency_ms,
|
||||
metadata={"project_id": str(project_id) if project_id else None},
|
||||
)
|
||||
|
||||
async def get_best_procedure(
|
||||
self,
|
||||
task_type: str,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
min_success_rate: float = 0.5,
|
||||
min_uses: int = 1,
|
||||
) -> Procedure | None:
|
||||
"""
|
||||
Get the best procedure for a given task type.
|
||||
|
||||
Returns the procedure with the highest success rate that
|
||||
meets the minimum thresholds.
|
||||
|
||||
Args:
|
||||
task_type: Task type to find procedure for
|
||||
project_id: Optional project scope
|
||||
agent_type_id: Optional agent type scope
|
||||
min_success_rate: Minimum required success rate
|
||||
min_uses: Minimum number of uses required
|
||||
|
||||
Returns:
|
||||
Best matching procedure or None
|
||||
"""
|
||||
# Escape SQL wildcards to prevent pattern injection
|
||||
escaped_task_type = _escape_like_pattern(task_type)
|
||||
task_type_pattern = f"%{escaped_task_type}%"
|
||||
|
||||
# Build query for procedures matching task type
|
||||
stmt = (
|
||||
select(ProcedureModel)
|
||||
.where(
|
||||
and_(
|
||||
(ProcedureModel.success_count + ProcedureModel.failure_count)
|
||||
>= min_uses,
|
||||
or_(
|
||||
ProcedureModel.trigger_pattern.ilike(task_type_pattern),
|
||||
ProcedureModel.name.ilike(task_type_pattern),
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(
|
||||
desc(
|
||||
ProcedureModel.success_count
|
||||
/ (ProcedureModel.success_count + ProcedureModel.failure_count + 1)
|
||||
),
|
||||
desc(ProcedureModel.last_used),
|
||||
)
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
# Apply scope filters
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Filter by success rate in Python (SQLAlchemy division in WHERE is complex)
|
||||
for model in models:
|
||||
success = float(model.success_count)
|
||||
failure = float(model.failure_count)
|
||||
total = success + failure
|
||||
if total > 0 and (success / total) >= min_success_rate:
|
||||
logger.debug(
|
||||
f"Found best procedure for '{task_type}': {model.name} "
|
||||
f"(success_rate={success / total:.2%})"
|
||||
)
|
||||
return _model_to_procedure(model)
|
||||
|
||||
return None
|
||||
|
||||
async def get_by_id(self, procedure_id: UUID) -> Procedure | None:
|
||||
"""Get a procedure by ID."""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
return _model_to_procedure(model) if model else None
|
||||
|
||||
# =========================================================================
|
||||
# Outcome Recording
|
||||
# =========================================================================
|
||||
|
||||
async def record_outcome(
|
||||
self,
|
||||
procedure_id: UUID,
|
||||
success: bool,
|
||||
) -> Procedure:
|
||||
"""
|
||||
Record the outcome of using a procedure.
|
||||
|
||||
Updates the success or failure count and last_used timestamp.
|
||||
|
||||
Args:
|
||||
procedure_id: Procedure that was used
|
||||
success: Whether the procedure succeeded
|
||||
|
||||
Returns:
|
||||
Updated procedure
|
||||
|
||||
Raises:
|
||||
ValueError: If procedure not found
|
||||
"""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"Procedure not found: {procedure_id}")
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
if success:
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == procedure_id)
|
||||
.values(
|
||||
success_count=ProcedureModel.success_count + 1,
|
||||
last_used=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
else:
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == procedure_id)
|
||||
.values(
|
||||
failure_count=ProcedureModel.failure_count + 1,
|
||||
last_used=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
outcome = "success" if success else "failure"
|
||||
logger.info(
|
||||
f"Recorded {outcome} for procedure {procedure_id}: "
|
||||
f"success_rate={updated_model.success_rate:.2%}"
|
||||
)
|
||||
|
||||
return _model_to_procedure(updated_model)
|
||||
|
||||
# =========================================================================
|
||||
# Step Management
|
||||
# =========================================================================
|
||||
|
||||
async def update_steps(
|
||||
self,
|
||||
procedure_id: UUID,
|
||||
steps: list[Step],
|
||||
) -> Procedure:
|
||||
"""
|
||||
Update the steps of a procedure.
|
||||
|
||||
Args:
|
||||
procedure_id: Procedure to update
|
||||
steps: New steps
|
||||
|
||||
Returns:
|
||||
Updated procedure
|
||||
|
||||
Raises:
|
||||
ValueError: If procedure not found
|
||||
"""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"Procedure not found: {procedure_id}")
|
||||
|
||||
# Convert Step objects to dictionaries
|
||||
steps_dict = [
|
||||
{
|
||||
"order": step.order,
|
||||
"action": step.action,
|
||||
"parameters": step.parameters,
|
||||
"expected_outcome": step.expected_outcome,
|
||||
"fallback_action": step.fallback_action,
|
||||
}
|
||||
for step in steps
|
||||
]
|
||||
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == procedure_id)
|
||||
.values(
|
||||
steps=steps_dict,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Updated steps for procedure {procedure_id}: {len(steps)} steps")
|
||||
|
||||
return _model_to_procedure(updated_model)
|
||||
|
||||
# =========================================================================
|
||||
# Statistics & Management
|
||||
# =========================================================================
|
||||
|
||||
async def get_stats(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about procedural memory.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to get stats for
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics
|
||||
"""
|
||||
query = select(ProcedureModel)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
if not models:
|
||||
return {
|
||||
"total_procedures": 0,
|
||||
"avg_success_rate": 0.0,
|
||||
"avg_steps_count": 0.0,
|
||||
"total_uses": 0,
|
||||
"high_success_count": 0,
|
||||
"low_success_count": 0,
|
||||
}
|
||||
|
||||
success_rates = [m.success_rate for m in models]
|
||||
step_counts = [len(m.steps or []) for m in models]
|
||||
total_uses = sum(m.total_uses for m in models)
|
||||
|
||||
return {
|
||||
"total_procedures": len(models),
|
||||
"avg_success_rate": sum(success_rates) / len(success_rates),
|
||||
"avg_steps_count": sum(step_counts) / len(step_counts),
|
||||
"total_uses": total_uses,
|
||||
"high_success_count": sum(1 for r in success_rates if r >= 0.8),
|
||||
"low_success_count": sum(1 for r in success_rates if r < 0.5),
|
||||
}
|
||||
|
||||
async def count(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Count procedures in scope.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to count for
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
Number of procedures
|
||||
"""
|
||||
query = select(ProcedureModel)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return len(list(result.scalars().all()))
|
||||
|
||||
async def delete(self, procedure_id: UUID) -> bool:
|
||||
"""
|
||||
Delete a procedure.
|
||||
|
||||
Args:
|
||||
procedure_id: Procedure to delete
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return False
|
||||
|
||||
await self._session.delete(model)
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Deleted procedure {procedure_id}")
|
||||
return True
|
||||
|
||||
async def get_procedures_by_success_rate(
|
||||
self,
|
||||
min_rate: float = 0.0,
|
||||
max_rate: float = 1.0,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Get procedures within a success rate range.
|
||||
|
||||
Args:
|
||||
min_rate: Minimum success rate
|
||||
max_rate: Maximum success rate
|
||||
project_id: Optional project scope
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of procedures
|
||||
"""
|
||||
query = (
|
||||
select(ProcedureModel)
|
||||
.order_by(desc(ProcedureModel.last_used))
|
||||
.limit(limit * 2) # Fetch more since we filter in Python
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Filter by success rate in Python
|
||||
filtered = [m for m in models if min_rate <= m.success_rate <= max_rate][:limit]
|
||||
|
||||
return [_model_to_procedure(m) for m in filtered]
|
||||
38
backend/app/services/memory/reflection/__init__.py
Normal file
38
backend/app/services/memory/reflection/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# app/services/memory/reflection/__init__.py
|
||||
"""
|
||||
Memory Reflection Layer.
|
||||
|
||||
Analyzes patterns in agent experiences to generate actionable insights.
|
||||
"""
|
||||
|
||||
from .service import (
|
||||
MemoryReflection,
|
||||
ReflectionConfig,
|
||||
get_memory_reflection,
|
||||
)
|
||||
from .types import (
|
||||
Anomaly,
|
||||
AnomalyType,
|
||||
Factor,
|
||||
FactorType,
|
||||
Insight,
|
||||
InsightType,
|
||||
Pattern,
|
||||
PatternType,
|
||||
TimeRange,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Anomaly",
|
||||
"AnomalyType",
|
||||
"Factor",
|
||||
"FactorType",
|
||||
"Insight",
|
||||
"InsightType",
|
||||
"MemoryReflection",
|
||||
"Pattern",
|
||||
"PatternType",
|
||||
"ReflectionConfig",
|
||||
"TimeRange",
|
||||
"get_memory_reflection",
|
||||
]
|
||||
1451
backend/app/services/memory/reflection/service.py
Normal file
1451
backend/app/services/memory/reflection/service.py
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user