forked from cardosofelipe/fast-next-template
Compare commits
58 Commits
SHA1: 4ad3d20cf2, 8623eb56f5, 3cb6c8d13b, 8e16e2645e, 82c3a6ba47, b6c38cac88, 51404216ae, 3f23bc3db3, a0ec5fa2cc, f262d08be2,
b3f371e0a3, 93cc37224c, 5717bffd63, 9339ea30a1, 79cb6bfd7b, 45025bb2f1, 3c6b14d2bf, 6b21a6fadd, 600657adc4, c9d0d079b3,
4c8f81368c, efbe91ce14, 5d646779c9, 5a4d93df26, 7ef217be39, 20159c5865, f9a72fcb34, fcb0a5f86a, 92782bcb05, 1dcf99ee38,
70009676a3, 192237e69b, 3edce9cd26, 35aea2d73a, d0f32d04f7, da85a8aba8, f8bd1011e9, f057c2f0b6, 33ec889fc4, 74b8c65741,
b232298c61, cf6291ac8e, e3fe0439fd, 57680c3772, 997cfaa03a, 6954774e36, 30e5c68304, 0b24d4c6cc, 1670e05e0d, 999b7ac03f,
48ecb40f18, b818f17418, e946787a61, 3554efe66a, bd988f76b0, 4974233169, c9d8c0835c, 085a748929
Makefile (24 lines changed)
@@ -1,5 +1,5 @@
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all format-all

VERSION ?= latest
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack

@@ -22,6 +22,9 @@ help:
@echo " make test-cov - Run all tests with coverage reports"
@echo " make test-integration - Run MCP integration tests (requires running stack)"
@echo ""
@echo "Formatting:"
@echo " make format-all - Format code in backend + MCP servers + frontend"
@echo ""
@echo "Validation:"
@echo " make validate - Validate backend + MCP servers (lint, type-check, test)"
@echo " make validate-all - Validate everything including frontend"

@@ -161,6 +164,25 @@ test-integration:
@echo "Note: Requires running stack (make dev first)"
@cd backend && RUN_INTEGRATION_TESTS=true IS_TEST=True uv run pytest tests/integration/ -v

# ============================================================================
# Formatting
# ============================================================================

format-all:
@echo "Formatting backend..."
@cd backend && make format
@echo ""
@echo "Formatting LLM Gateway..."
@cd mcp-servers/llm-gateway && make format
@echo ""
@echo "Formatting Knowledge Base..."
@cd mcp-servers/knowledge-base && make format
@echo ""
@echo "Formatting frontend..."
@cd frontend && npm run format
@echo ""
@echo "All code formatted!"

# ============================================================================
# Validation (lint + type-check + test)
# ============================================================================

@@ -80,7 +80,7 @@ test:

test-cov:
@echo "🧪 Running tests with coverage..."
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 20
@echo "📊 Coverage report generated in htmlcov/index.html"

# ============================================================================
backend/app/alembic/versions/0005_add_memory_system_tables.py (new file, 512 lines)
@@ -0,0 +1,512 @@
|
||||
"""Add Agent Memory System tables
|
||||
|
||||
Revision ID: 0005
|
||||
Revises: 0004
|
||||
Create Date: 2025-01-05
|
||||
|
||||
This migration creates the Agent Memory System tables:
|
||||
- working_memory: Key-value storage with TTL for active sessions
|
||||
- episodes: Experiential memories from task executions
|
||||
- facts: Semantic knowledge triples with confidence scores
|
||||
- procedures: Learned skills and procedures
|
||||
- memory_consolidation_log: Tracks consolidation jobs
|
||||
|
||||
See Issue #88: Database Schema & Storage Layer
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0005"
|
||||
down_revision: str | None = "0004"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create Agent Memory System tables."""
|
||||
|
||||
# =========================================================================
|
||||
# Create ENUM types for memory system
|
||||
# =========================================================================
|
||||
|
||||
# Scope type enum
|
||||
scope_type_enum = postgresql.ENUM(
|
||||
"global",
|
||||
"project",
|
||||
"agent_type",
|
||||
"agent_instance",
|
||||
"session",
|
||||
name="scope_type",
|
||||
create_type=False,
|
||||
)
|
||||
scope_type_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Episode outcome enum
|
||||
episode_outcome_enum = postgresql.ENUM(
|
||||
"success",
|
||||
"failure",
|
||||
"partial",
|
||||
name="episode_outcome",
|
||||
create_type=False,
|
||||
)
|
||||
episode_outcome_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Consolidation type enum
|
||||
consolidation_type_enum = postgresql.ENUM(
|
||||
"working_to_episodic",
|
||||
"episodic_to_semantic",
|
||||
"episodic_to_procedural",
|
||||
"pruning",
|
||||
name="consolidation_type",
|
||||
create_type=False,
|
||||
)
|
||||
consolidation_type_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Consolidation status enum
|
||||
consolidation_status_enum = postgresql.ENUM(
|
||||
"pending",
|
||||
"running",
|
||||
"completed",
|
||||
"failed",
|
||||
name="consolidation_status",
|
||||
create_type=False,
|
||||
)
|
||||
consolidation_status_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# =========================================================================
|
||||
# Create working_memory table
|
||||
# Key-value storage with TTL for active sessions
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"working_memory",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"scope_type",
|
||||
scope_type_enum,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("scope_id", sa.String(255), nullable=False),
|
||||
sa.Column("key", sa.String(255), nullable=False),
|
||||
sa.Column("value", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Working memory indexes
|
||||
op.create_index(
|
||||
"ix_working_memory_scope_type",
|
||||
"working_memory",
|
||||
["scope_type"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_working_memory_scope_id",
|
||||
"working_memory",
|
||||
["scope_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_working_memory_scope_key",
|
||||
"working_memory",
|
||||
["scope_type", "scope_id", "key"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_working_memory_expires",
|
||||
"working_memory",
|
||||
["expires_at"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_working_memory_scope_list",
|
||||
"working_memory",
|
||||
["scope_type", "scope_id"],
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create episodes table
|
||||
# Experiential memories from task executions
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"episodes",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("agent_instance_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("session_id", sa.String(255), nullable=False),
|
||||
sa.Column("task_type", sa.String(100), nullable=False),
|
||||
sa.Column("task_description", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"actions",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
sa.Column("context_summary", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"outcome",
|
||||
episode_outcome_enum,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("outcome_details", sa.Text(), nullable=True),
|
||||
sa.Column("duration_seconds", sa.Float(), nullable=False, server_default="0.0"),
|
||||
sa.Column("tokens_used", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"lessons_learned",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
sa.Column("importance_score", sa.Float(), nullable=False, server_default="0.5"),
|
||||
# Vector embedding - using TEXT as fallback, will be VECTOR(1536) when pgvector is available
|
||||
sa.Column("embedding", sa.Text(), nullable=True),
|
||||
sa.Column("occurred_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["project_id"],
|
||||
["projects.id"],
|
||||
name="fk_episodes_project",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_instance_id"],
|
||||
["agent_instances.id"],
|
||||
name="fk_episodes_agent_instance",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_type_id"],
|
||||
["agent_types.id"],
|
||||
name="fk_episodes_agent_type",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
)
|
||||
|
||||
# Episode indexes
|
||||
op.create_index("ix_episodes_project_id", "episodes", ["project_id"])
|
||||
op.create_index("ix_episodes_agent_instance_id", "episodes", ["agent_instance_id"])
|
||||
op.create_index("ix_episodes_agent_type_id", "episodes", ["agent_type_id"])
|
||||
op.create_index("ix_episodes_session_id", "episodes", ["session_id"])
|
||||
op.create_index("ix_episodes_task_type", "episodes", ["task_type"])
|
||||
op.create_index("ix_episodes_outcome", "episodes", ["outcome"])
|
||||
op.create_index("ix_episodes_importance_score", "episodes", ["importance_score"])
|
||||
op.create_index("ix_episodes_occurred_at", "episodes", ["occurred_at"])
|
||||
op.create_index("ix_episodes_project_task", "episodes", ["project_id", "task_type"])
|
||||
op.create_index(
|
||||
"ix_episodes_project_outcome", "episodes", ["project_id", "outcome"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_episodes_agent_task", "episodes", ["agent_instance_id", "task_type"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_episodes_project_time", "episodes", ["project_id", "occurred_at"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_episodes_importance_time",
|
||||
"episodes",
|
||||
["importance_score", "occurred_at"],
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create facts table
|
||||
# Semantic knowledge triples with confidence scores
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"facts",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"project_id", postgresql.UUID(as_uuid=True), nullable=True
|
||||
), # NULL for global facts
|
||||
sa.Column("subject", sa.String(500), nullable=False),
|
||||
sa.Column("predicate", sa.String(255), nullable=False),
|
||||
sa.Column("object", sa.Text(), nullable=False),
|
||||
sa.Column("confidence", sa.Float(), nullable=False, server_default="0.8"),
|
||||
# Source episode IDs stored as JSON array of UUID strings for cross-db compatibility
|
||||
sa.Column(
|
||||
"source_episode_ids",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
sa.Column("first_learned", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("last_reinforced", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"reinforcement_count", sa.Integer(), nullable=False, server_default="1"
|
||||
),
|
||||
# Vector embedding
|
||||
sa.Column("embedding", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["project_id"],
|
||||
["projects.id"],
|
||||
name="fk_facts_project",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
|
||||
# Fact indexes
|
||||
op.create_index("ix_facts_project_id", "facts", ["project_id"])
|
||||
op.create_index("ix_facts_subject", "facts", ["subject"])
|
||||
op.create_index("ix_facts_predicate", "facts", ["predicate"])
|
||||
op.create_index("ix_facts_confidence", "facts", ["confidence"])
|
||||
op.create_index("ix_facts_subject_predicate", "facts", ["subject", "predicate"])
|
||||
op.create_index("ix_facts_project_subject", "facts", ["project_id", "subject"])
|
||||
op.create_index(
|
||||
"ix_facts_confidence_time", "facts", ["confidence", "last_reinforced"]
|
||||
)
|
||||
# Unique constraint for triples within project scope
|
||||
op.create_index(
|
||||
"ix_facts_unique_triple",
|
||||
"facts",
|
||||
["project_id", "subject", "predicate", "object"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("project_id IS NOT NULL"),
|
||||
)
|
||||
# Unique constraint for global facts (project_id IS NULL)
|
||||
op.create_index(
|
||||
"ix_facts_unique_triple_global",
|
||||
"facts",
|
||||
["subject", "predicate", "object"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("project_id IS NULL"),
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create procedures table
|
||||
# Learned skills and procedures
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"procedures",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("trigger_pattern", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"steps",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
sa.Column("success_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("failure_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("last_used", sa.DateTime(timezone=True), nullable=True),
|
||||
# Vector embedding
|
||||
sa.Column("embedding", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["project_id"],
|
||||
["projects.id"],
|
||||
name="fk_procedures_project",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_type_id"],
|
||||
["agent_types.id"],
|
||||
name="fk_procedures_agent_type",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
)
|
||||
|
||||
# Procedure indexes
|
||||
op.create_index("ix_procedures_project_id", "procedures", ["project_id"])
|
||||
op.create_index("ix_procedures_agent_type_id", "procedures", ["agent_type_id"])
|
||||
op.create_index("ix_procedures_name", "procedures", ["name"])
|
||||
op.create_index("ix_procedures_last_used", "procedures", ["last_used"])
|
||||
op.create_index(
|
||||
"ix_procedures_unique_name",
|
||||
"procedures",
|
||||
["project_id", "agent_type_id", "name"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index("ix_procedures_project_name", "procedures", ["project_id", "name"])
|
||||
# Note: agent_type_id already indexed via ix_procedures_agent_type_id (line 354)
|
||||
op.create_index(
|
||||
"ix_procedures_success_rate",
|
||||
"procedures",
|
||||
["success_count", "failure_count"],
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Add check constraints for data integrity
|
||||
# =========================================================================
|
||||
|
||||
# Episode constraints
|
||||
op.create_check_constraint(
|
||||
"ck_episodes_importance_range",
|
||||
"episodes",
|
||||
"importance_score >= 0.0 AND importance_score <= 1.0",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_episodes_duration_positive",
|
||||
"episodes",
|
||||
"duration_seconds >= 0.0",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_episodes_tokens_positive",
|
||||
"episodes",
|
||||
"tokens_used >= 0",
|
||||
)
|
||||
|
||||
# Fact constraints
|
||||
op.create_check_constraint(
|
||||
"ck_facts_confidence_range",
|
||||
"facts",
|
||||
"confidence >= 0.0 AND confidence <= 1.0",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_facts_reinforcement_positive",
|
||||
"facts",
|
||||
"reinforcement_count >= 1",
|
||||
)
|
||||
|
||||
# Procedure constraints
|
||||
op.create_check_constraint(
|
||||
"ck_procedures_success_positive",
|
||||
"procedures",
|
||||
"success_count >= 0",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_procedures_failure_positive",
|
||||
"procedures",
|
||||
"failure_count >= 0",
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create memory_consolidation_log table
|
||||
# Tracks consolidation jobs
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"memory_consolidation_log",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"consolidation_type",
|
||||
consolidation_type_enum,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("source_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("result_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("started_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
consolidation_status_enum,
|
||||
nullable=False,
|
||||
server_default="pending",
|
||||
),
|
||||
sa.Column("error", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Consolidation log indexes
|
||||
op.create_index(
|
||||
"ix_consolidation_type",
|
||||
"memory_consolidation_log",
|
||||
["consolidation_type"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_consolidation_status",
|
||||
"memory_consolidation_log",
|
||||
["status"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_consolidation_type_status",
|
||||
"memory_consolidation_log",
|
||||
["consolidation_type", "status"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_consolidation_started",
|
||||
"memory_consolidation_log",
|
||||
["started_at"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop Agent Memory System tables."""
|
||||
|
||||
# Drop check constraints first
|
||||
op.drop_constraint("ck_procedures_failure_positive", "procedures", type_="check")
|
||||
op.drop_constraint("ck_procedures_success_positive", "procedures", type_="check")
|
||||
op.drop_constraint("ck_facts_reinforcement_positive", "facts", type_="check")
|
||||
op.drop_constraint("ck_facts_confidence_range", "facts", type_="check")
|
||||
op.drop_constraint("ck_episodes_tokens_positive", "episodes", type_="check")
|
||||
op.drop_constraint("ck_episodes_duration_positive", "episodes", type_="check")
|
||||
op.drop_constraint("ck_episodes_importance_range", "episodes", type_="check")
|
||||
|
||||
# Drop unique indexes for global facts
|
||||
op.drop_index("ix_facts_unique_triple_global", "facts")
|
||||
|
||||
# Drop tables in reverse order (dependencies first)
|
||||
op.drop_table("memory_consolidation_log")
|
||||
op.drop_table("procedures")
|
||||
op.drop_table("facts")
|
||||
op.drop_table("episodes")
|
||||
op.drop_table("working_memory")
|
||||
|
||||
# Drop ENUM types
|
||||
op.execute("DROP TYPE IF EXISTS consolidation_status")
|
||||
op.execute("DROP TYPE IF EXISTS consolidation_type")
|
||||
op.execute("DROP TYPE IF EXISTS episode_outcome")
|
||||
op.execute("DROP TYPE IF EXISTS scope_type")
|
||||
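Migration 0005 gives working_memory a unique (scope_type, scope_id, key) index plus a nullable expires_at column, which is what lets a "set key with TTL" operation be a single idempotent upsert. A minimal sketch of that pattern with SQLAlchemy Core and PostgreSQL ON CONFLICT — the helper, the lightweight table stub (scope_type simplified to String rather than the real scope_type enum), and all names are assumptions for illustration, not code from this changeset:

```python
# Illustrative sketch only: upsert into working_memory (migration 0005),
# relying on the unique (scope_type, scope_id, key) index and expires_at TTL.
import uuid
from datetime import UTC, datetime, timedelta

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.asyncio import AsyncSession

metadata = sa.MetaData()

# Minimal stub of the table created by the migration; the real scope_type
# column is the scope_type ENUM, simplified here to String for brevity.
working_memory = sa.Table(
    "working_memory",
    metadata,
    sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
    sa.Column("scope_type", sa.String(50)),
    sa.Column("scope_id", sa.String(255)),
    sa.Column("key", sa.String(255)),
    sa.Column("value", postgresql.JSONB),
    sa.Column("expires_at", sa.DateTime(timezone=True)),
    sa.Column("updated_at", sa.DateTime(timezone=True)),
)


async def set_working_memory(
    session: AsyncSession,
    scope_type: str,
    scope_id: str,
    key: str,
    value: dict,
    ttl_seconds: int | None = None,
) -> None:
    """Insert or refresh one working-memory entry; the unique index makes it idempotent."""
    expires_at = (
        datetime.now(UTC) + timedelta(seconds=ttl_seconds) if ttl_seconds else None
    )
    stmt = (
        postgresql.insert(working_memory)
        .values(
            id=uuid.uuid4(),
            scope_type=scope_type,
            scope_id=scope_id,
            key=key,
            value=value,
            expires_at=expires_at,
        )
        .on_conflict_do_update(
            index_elements=["scope_type", "scope_id", "key"],
            set_={"value": value, "expires_at": expires_at, "updated_at": sa.func.now()},
        )
    )
    await session.execute(stmt)  # caller commits
```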
backend/app/alembic/versions/0006_add_abandoned_outcome.py (new file, 52 lines)
@@ -0,0 +1,52 @@
"""Add ABANDONED to episode_outcome enum

Revision ID: 0006
Revises: 0005
Create Date: 2025-01-06

This migration adds the 'abandoned' value to the episode_outcome enum type.
This allows episodes to track when a task was abandoned (not completed,
but not necessarily a failure either - e.g., user cancelled, session timeout).
"""

from collections.abc import Sequence

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "0006"
down_revision: str | None = "0005"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Add 'abandoned' value to episode_outcome enum."""
    # PostgreSQL ALTER TYPE ADD VALUE is safe and non-blocking
    op.execute("ALTER TYPE episode_outcome ADD VALUE IF NOT EXISTS 'abandoned'")


def downgrade() -> None:
    """Remove 'abandoned' from episode_outcome enum.

    Note: PostgreSQL doesn't support removing values from enums directly.
    This downgrade converts any 'abandoned' episodes to 'failure' and
    recreates the enum without 'abandoned'.
    """
    # Convert any abandoned episodes to failure first
    op.execute("""
        UPDATE episodes
        SET outcome = 'failure'
        WHERE outcome = 'abandoned'
    """)

    # Recreate the enum without abandoned
    # This is complex in PostgreSQL - requires creating new type, updating columns, dropping old
    op.execute("ALTER TYPE episode_outcome RENAME TO episode_outcome_old")
    op.execute("CREATE TYPE episode_outcome AS ENUM ('success', 'failure', 'partial')")
    op.execute("""
        ALTER TABLE episodes
        ALTER COLUMN outcome TYPE episode_outcome
        USING outcome::text::episode_outcome
    """)
    op.execute("DROP TYPE episode_outcome_old")
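The matching application-side enum is not visible in this compare view; presumably a Python enum mirroring episode_outcome gains an ABANDONED member alongside migration 0006. A minimal sketch — the class name and module are assumptions, only the values come from the migrations above:

```python
# Hypothetical sketch of the Python enum that would mirror the database
# episode_outcome type after migration 0006; name/location are assumed.
from enum import Enum


class EpisodeOutcome(str, Enum):
    SUCCESS = "success"
    FAILURE = "failure"
    PARTIAL = "partial"
    ABANDONED = "abandoned"  # added alongside migration 0006
```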
@@ -0,0 +1,90 @@
"""Add category and display fields to agent_types table

Revision ID: 0007
Revises: 0006
Create Date: 2026-01-06

This migration adds:
- category: String(50) for grouping agents by role type
- icon: String(50) for Lucide icon identifier
- color: String(7) for hex color code
- sort_order: Integer for display ordering within categories
- typical_tasks: JSONB list of tasks this agent excels at
- collaboration_hints: JSONB list of agent slugs that work well together
"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "0007"
down_revision: str | None = "0006"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Add category and display fields to agent_types table."""
    # Add new columns
    op.add_column(
        "agent_types",
        sa.Column("category", sa.String(length=50), nullable=True),
    )
    op.add_column(
        "agent_types",
        sa.Column("icon", sa.String(length=50), nullable=True, server_default="bot"),
    )
    op.add_column(
        "agent_types",
        sa.Column(
            "color", sa.String(length=7), nullable=True, server_default="#3B82F6"
        ),
    )
    op.add_column(
        "agent_types",
        sa.Column("sort_order", sa.Integer(), nullable=False, server_default="0"),
    )
    op.add_column(
        "agent_types",
        sa.Column(
            "typical_tasks",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=False,
            server_default="[]",
        ),
    )
    op.add_column(
        "agent_types",
        sa.Column(
            "collaboration_hints",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=False,
            server_default="[]",
        ),
    )

    # Add indexes for category and sort_order
    op.create_index("ix_agent_types_category", "agent_types", ["category"])
    op.create_index("ix_agent_types_sort_order", "agent_types", ["sort_order"])
    op.create_index(
        "ix_agent_types_category_sort", "agent_types", ["category", "sort_order"]
    )


def downgrade() -> None:
    """Remove category and display fields from agent_types table."""
    # Drop indexes
    op.drop_index("ix_agent_types_category_sort", table_name="agent_types")
    op.drop_index("ix_agent_types_sort_order", table_name="agent_types")
    op.drop_index("ix_agent_types_category", table_name="agent_types")

    # Drop columns
    op.drop_column("agent_types", "collaboration_hints")
    op.drop_column("agent_types", "typical_tasks")
    op.drop_column("agent_types", "sort_order")
    op.drop_column("agent_types", "color")
    op.drop_column("agent_types", "icon")
    op.drop_column("agent_types", "category")
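Migration 0007 adds category as a nullable column with no backfill, so pre-existing agent_types rows keep category NULL and are later grouped under "uncategorized" by the CRUD layer further down. If a default bucket were preferred instead, a small follow-up data migration could backfill it — a hedged sketch, with "general" as a purely illustrative value and the usual revision boilerplate omitted:

```python
# Hypothetical follow-up data migration (not part of this changeset):
# backfill a default category for rows that existed before migration 0007.
import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    op.execute(
        sa.text("UPDATE agent_types SET category = 'general' WHERE category IS NULL")
    )


def downgrade() -> None:
    op.execute(
        sa.text("UPDATE agent_types SET category = NULL WHERE category = 'general'")
    )
```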
@@ -81,6 +81,13 @@ def _build_agent_type_response(
|
||||
mcp_servers=agent_type.mcp_servers,
|
||||
tool_permissions=agent_type.tool_permissions,
|
||||
is_active=agent_type.is_active,
|
||||
# Category and display fields
|
||||
category=agent_type.category,
|
||||
icon=agent_type.icon,
|
||||
color=agent_type.color,
|
||||
sort_order=agent_type.sort_order,
|
||||
typical_tasks=agent_type.typical_tasks or [],
|
||||
collaboration_hints=agent_type.collaboration_hints or [],
|
||||
created_at=agent_type.created_at,
|
||||
updated_at=agent_type.updated_at,
|
||||
instance_count=instance_count,
|
||||
@@ -300,6 +307,7 @@ async def list_agent_types(
|
||||
request: Request,
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: bool = Query(True, description="Filter by active status"),
|
||||
category: str | None = Query(None, description="Filter by category"),
|
||||
search: str | None = Query(None, description="Search by name, slug, description"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -314,6 +322,7 @@ async def list_agent_types(
|
||||
request: FastAPI request object
|
||||
pagination: Pagination parameters (page, limit)
|
||||
is_active: Filter by active status (default: True)
|
||||
category: Filter by category (e.g., "development", "design")
|
||||
search: Optional search term for name, slug, description
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
@@ -328,6 +337,7 @@ async def list_agent_types(
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
is_active=is_active,
|
||||
category=category,
|
||||
search=search,
|
||||
)
|
||||
|
||||
@@ -354,6 +364,51 @@ async def list_agent_types(
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/grouped",
|
||||
response_model=dict[str, list[AgentTypeResponse]],
|
||||
summary="List Agent Types Grouped by Category",
|
||||
description="Get all agent types organized by category",
|
||||
operation_id="list_agent_types_grouped",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def list_agent_types_grouped(
|
||||
request: Request,
|
||||
is_active: bool = Query(True, description="Filter by active status"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get agent types grouped by category.
|
||||
|
||||
Returns a dictionary where keys are category names and values
|
||||
are lists of agent types, sorted by sort_order within each category.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
is_active: Filter by active status (default: True)
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary mapping category to list of agent types
|
||||
"""
|
||||
try:
|
||||
grouped = await agent_type_crud.get_grouped_by_category(db, is_active=is_active)
|
||||
|
||||
# Transform to response objects
|
||||
result: dict[str, list[AgentTypeResponse]] = {}
|
||||
for category, types in grouped.items():
|
||||
result[category] = [
|
||||
_build_agent_type_response(t, instance_count=0) for t in types
|
||||
]
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting grouped agent types: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{agent_type_id}",
|
||||
response_model=AgentTypeResponse,
|
||||
|
||||
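The new "/grouped" route above returns dict[str, list[AgentTypeResponse]]. An illustrative client call follows; the base URL, the /api/v1/agent-types prefix, and the bearer-token auth are assumptions (the router prefix is not shown in this diff), while the response shape comes from the endpoint's response_model:

```python
# Illustrative client sketch for the new grouped agent-types route.
import asyncio

import httpx


async def fetch_grouped_agent_types(token: str) -> dict[str, list[dict]]:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.get(
            "/api/v1/agent-types/grouped",          # prefix assumed
            params={"is_active": True},
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        return resp.json()  # {category: [agent type payloads sorted by sort_order]}


if __name__ == "__main__":
    grouped = asyncio.run(fetch_grouped_agent_types("dev-token"))
    for category, types in grouped.items():
        print(category, [t["name"] for t in types])
```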
@@ -1,366 +0,0 @@
|
||||
{
|
||||
"organizations": [
|
||||
{
|
||||
"name": "Acme Corp",
|
||||
"slug": "acme-corp",
|
||||
"description": "A leading provider of coyote-catching equipment."
|
||||
},
|
||||
{
|
||||
"name": "Globex Corporation",
|
||||
"slug": "globex",
|
||||
"description": "We own the East Coast."
|
||||
},
|
||||
{
|
||||
"name": "Soylent Corp",
|
||||
"slug": "soylent",
|
||||
"description": "Making food for the future."
|
||||
},
|
||||
{
|
||||
"name": "Initech",
|
||||
"slug": "initech",
|
||||
"description": "Software for the soul."
|
||||
},
|
||||
{
|
||||
"name": "Umbrella Corporation",
|
||||
"slug": "umbrella",
|
||||
"description": "Our business is life itself."
|
||||
},
|
||||
{
|
||||
"name": "Massive Dynamic",
|
||||
"slug": "massive-dynamic",
|
||||
"description": "What don't we do?"
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"email": "demo@example.com",
|
||||
"password": "DemoPass1234!",
|
||||
"first_name": "Demo",
|
||||
"last_name": "User",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "alice@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Alice",
|
||||
"last_name": "Smith",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "bob@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Bob",
|
||||
"last_name": "Jones",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "charlie@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Charlie",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "diana@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Diana",
|
||||
"last_name": "Prince",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "carol@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Carol",
|
||||
"last_name": "Williams",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dan@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dan",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ellen@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ellen",
|
||||
"last_name": "Ripley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "fred@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Fred",
|
||||
"last_name": "Flintstone",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dave@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dave",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "gina@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Gina",
|
||||
"last_name": "Torres",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "harry@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Harry",
|
||||
"last_name": "Potter",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "eve@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Eve",
|
||||
"last_name": "Davis",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "iris@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Iris",
|
||||
"last_name": "West",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "jack@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Jack",
|
||||
"last_name": "Sparrow",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "frank@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Frank",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "george@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "George",
|
||||
"last_name": "Costanza",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "kate@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Kate",
|
||||
"last_name": "Bishop",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "leo@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Leo",
|
||||
"last_name": "Messi",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "mary@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Mary",
|
||||
"last_name": "Jane",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "nathan@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Nathan",
|
||||
"last_name": "Drake",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "olivia@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Olivia",
|
||||
"last_name": "Dunham",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "peter@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Peter",
|
||||
"last_name": "Parker",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "quinn@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Quinn",
|
||||
"last_name": "Mallory",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "grace@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Grace",
|
||||
"last_name": "Hopper",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "heidi@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Heidi",
|
||||
"last_name": "Klum",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ivan@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ivan",
|
||||
"last_name": "Drago",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "rachel@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Rachel",
|
||||
"last_name": "Green",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "sam@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Sam",
|
||||
"last_name": "Wilson",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "tony@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Tony",
|
||||
"last_name": "Stark",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "una@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Una",
|
||||
"last_name": "Chin-Riley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "victor@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Victor",
|
||||
"last_name": "Von Doom",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "wanda@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Wanda",
|
||||
"last_name": "Maximoff",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -43,6 +43,13 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
mcp_servers=obj_in.mcp_servers,
|
||||
tool_permissions=obj_in.tool_permissions,
|
||||
is_active=obj_in.is_active,
|
||||
# Category and display fields
|
||||
category=obj_in.category.value if obj_in.category else None,
|
||||
icon=obj_in.icon,
|
||||
color=obj_in.color,
|
||||
sort_order=obj_in.sort_order,
|
||||
typical_tasks=obj_in.typical_tasks,
|
||||
collaboration_hints=obj_in.collaboration_hints,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
@@ -68,6 +75,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
category: str | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
@@ -85,6 +93,9 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
if is_active is not None:
|
||||
query = query.where(AgentType.is_active == is_active)
|
||||
|
||||
if category:
|
||||
query = query.where(AgentType.category == category)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
AgentType.name.ilike(f"%{search}%"),
|
||||
@@ -162,6 +173,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
category: str | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
@@ -177,6 +189,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
category=category,
|
||||
search=search,
|
||||
)
|
||||
|
||||
@@ -260,6 +273,44 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_grouped_by_category(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
is_active: bool = True,
|
||||
) -> dict[str, list[AgentType]]:
|
||||
"""
|
||||
Get agent types grouped by category, sorted by sort_order within each group.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
is_active: Filter by active status (default: True)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping category to list of agent types
|
||||
"""
|
||||
try:
|
||||
query = (
|
||||
select(AgentType)
|
||||
.where(AgentType.is_active == is_active)
|
||||
.order_by(AgentType.category, AgentType.sort_order, AgentType.name)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
agent_types = list(result.scalars().all())
|
||||
|
||||
# Group by category
|
||||
grouped: dict[str, list[AgentType]] = {}
|
||||
for at in agent_types:
|
||||
cat: str = str(at.category) if at.category else "uncategorized"
|
||||
if cat not in grouped:
|
||||
grouped[cat] = []
|
||||
grouped[cat].append(at)
|
||||
|
||||
return grouped
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting grouped agent types: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
agent_type = CRUDAgentType(AgentType)
|
||||
|
||||
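The grouping loop in get_grouped_by_category above builds the dictionary with an explicit membership check. A behaviourally equivalent sketch using collections.defaultdict is shown below purely as a style alternative; the explicit version is what this diff ships:

```python
# Equivalent grouping step using defaultdict (illustrative only).
from collections import defaultdict


def group_by_category(agent_types: list) -> dict[str, list]:
    grouped: defaultdict[str, list] = defaultdict(list)
    for at in agent_types:
        # Fall back to "uncategorized", matching the CRUD method above.
        grouped[str(at.category) if at.category else "uncategorized"].append(at)
    return dict(grouped)
```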
@@ -3,27 +3,48 @@
|
||||
Async database initialization script.
|
||||
|
||||
Creates the first superuser if configured and doesn't already exist.
|
||||
Seeds default agent types (production data) and demo data (when DEMO_MODE is enabled).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import SessionLocal, engine
|
||||
from app.crud.syndarix.agent_type import agent_type as agent_type_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.organization import Organization
|
||||
from app.models.syndarix import AgentInstance, AgentType, Issue, Project, Sprint
|
||||
from app.models.syndarix.enums import (
|
||||
AgentStatus,
|
||||
AutonomyLevel,
|
||||
ClientMode,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
IssueType,
|
||||
ProjectComplexity,
|
||||
ProjectStatus,
|
||||
SprintStatus,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization
|
||||
from app.schemas.syndarix import AgentTypeCreate
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Data file paths
|
||||
DATA_DIR = Path(__file__).parent.parent / "data"
|
||||
DEFAULT_AGENT_TYPES_PATH = DATA_DIR / "default_agent_types.json"
|
||||
DEMO_DATA_PATH = DATA_DIR / "demo_data.json"
|
||||
|
||||
|
||||
async def init_db() -> User | None:
|
||||
"""
|
||||
@@ -54,28 +75,29 @@ async def init_db() -> User | None:
|
||||
|
||||
if existing_user:
|
||||
logger.info(f"Superuser already exists: {existing_user.email}")
|
||||
return existing_user
|
||||
else:
|
||||
# Create superuser if doesn't exist
|
||||
user_in = UserCreate(
|
||||
email=superuser_email,
|
||||
password=superuser_password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True,
|
||||
)
|
||||
|
||||
# Create superuser if doesn't exist
|
||||
user_in = UserCreate(
|
||||
email=superuser_email,
|
||||
password=superuser_password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True,
|
||||
)
|
||||
existing_user = await user_crud.create(session, obj_in=user_in)
|
||||
await session.commit()
|
||||
await session.refresh(existing_user)
|
||||
logger.info(f"Created first superuser: {existing_user.email}")
|
||||
|
||||
user = await user_crud.create(session, obj_in=user_in)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
# ALWAYS load default agent types (production data)
|
||||
await load_default_agent_types(session)
|
||||
|
||||
logger.info(f"Created first superuser: {user.email}")
|
||||
|
||||
# Create demo data if in demo mode
|
||||
# Only load demo data if in demo mode
|
||||
if settings.DEMO_MODE:
|
||||
await load_demo_data(session)
|
||||
|
||||
return user
|
||||
return existing_user
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
@@ -88,26 +110,96 @@ def _load_json_file(path: Path):
|
||||
return json.load(f)
|
||||
|
||||
|
||||
async def load_demo_data(session):
|
||||
"""Load demo data from JSON file."""
|
||||
demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
|
||||
if not demo_data_path.exists():
|
||||
logger.warning(f"Demo data file not found: {demo_data_path}")
|
||||
async def load_default_agent_types(session: AsyncSession) -> None:
|
||||
"""
|
||||
Load default agent types from JSON file.
|
||||
|
||||
These are production defaults - created only if they don't exist, never overwritten.
|
||||
This allows users to customize agent types without worrying about server restarts.
|
||||
"""
|
||||
if not DEFAULT_AGENT_TYPES_PATH.exists():
|
||||
logger.warning(
|
||||
f"Default agent types file not found: {DEFAULT_AGENT_TYPES_PATH}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Use asyncio.to_thread to avoid blocking the event loop
|
||||
data = await asyncio.to_thread(_load_json_file, demo_data_path)
|
||||
data = await asyncio.to_thread(_load_json_file, DEFAULT_AGENT_TYPES_PATH)
|
||||
|
||||
# Create Organizations
|
||||
org_map = {}
|
||||
for org_data in data.get("organizations", []):
|
||||
# Check if org exists
|
||||
result = await session.execute(
|
||||
text("SELECT * FROM organizations WHERE slug = :slug"),
|
||||
{"slug": org_data["slug"]},
|
||||
for agent_type_data in data:
|
||||
slug = agent_type_data["slug"]
|
||||
|
||||
# Check if agent type already exists
|
||||
existing = await agent_type_crud.get_by_slug(session, slug=slug)
|
||||
|
||||
if existing:
|
||||
logger.debug(f"Agent type already exists: {agent_type_data['name']}")
|
||||
continue
|
||||
|
||||
# Create the agent type
|
||||
agent_type_in = AgentTypeCreate(
|
||||
name=agent_type_data["name"],
|
||||
slug=slug,
|
||||
description=agent_type_data.get("description"),
|
||||
expertise=agent_type_data.get("expertise", []),
|
||||
personality_prompt=agent_type_data["personality_prompt"],
|
||||
primary_model=agent_type_data["primary_model"],
|
||||
fallback_models=agent_type_data.get("fallback_models", []),
|
||||
model_params=agent_type_data.get("model_params", {}),
|
||||
mcp_servers=agent_type_data.get("mcp_servers", []),
|
||||
tool_permissions=agent_type_data.get("tool_permissions", {}),
|
||||
is_active=agent_type_data.get("is_active", True),
|
||||
# Category and display fields
|
||||
category=agent_type_data.get("category"),
|
||||
icon=agent_type_data.get("icon", "bot"),
|
||||
color=agent_type_data.get("color", "#3B82F6"),
|
||||
sort_order=agent_type_data.get("sort_order", 0),
|
||||
typical_tasks=agent_type_data.get("typical_tasks", []),
|
||||
collaboration_hints=agent_type_data.get("collaboration_hints", []),
|
||||
)
|
||||
existing_org = result.first()
|
||||
|
||||
await agent_type_crud.create(session, obj_in=agent_type_in)
|
||||
logger.info(f"Created default agent type: {agent_type_data['name']}")
|
||||
|
||||
logger.info("Default agent types loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading default agent types: {e}")
|
||||
raise
|
||||
|
||||
|
||||
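load_default_agent_types() above reads a fixed set of keys from each entry in default_agent_types.json; that seed file is not part of this compare view. A hypothetical entry with illustrative values, limited to the keys the loader actually consumes:

```python
# Hypothetical example entry for default_agent_types.json; every value here is
# illustrative, only the key names are taken from load_default_agent_types().
EXAMPLE_DEFAULT_AGENT_TYPE = {
    "name": "Backend Engineer",
    "slug": "backend-engineer",
    "description": "Designs and implements server-side features.",
    "expertise": ["python", "fastapi", "postgresql"],
    "personality_prompt": "You are a pragmatic backend engineer...",
    "primary_model": "gpt-4o",
    "fallback_models": ["gpt-4o-mini"],
    "model_params": {"temperature": 0.2},
    "mcp_servers": ["knowledge-base"],
    "tool_permissions": {},
    "is_active": True,
    "category": "development",
    "icon": "server",
    "color": "#3B82F6",
    "sort_order": 10,
    "typical_tasks": ["Implement REST endpoints", "Write migrations"],
    "collaboration_hints": ["frontend-engineer", "qa-engineer"],
}
```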
async def load_demo_data(session: AsyncSession) -> None:
|
||||
"""
|
||||
Load demo data from JSON file.
|
||||
|
||||
Only runs when DEMO_MODE is enabled. Creates demo organizations, users,
|
||||
projects, sprints, agent instances, and issues.
|
||||
"""
|
||||
if not DEMO_DATA_PATH.exists():
|
||||
logger.warning(f"Demo data file not found: {DEMO_DATA_PATH}")
|
||||
return
|
||||
|
||||
try:
|
||||
data = await asyncio.to_thread(_load_json_file, DEMO_DATA_PATH)
|
||||
|
||||
# Build lookup maps for FK resolution
|
||||
org_map: dict[str, Organization] = {}
|
||||
user_map: dict[str, User] = {}
|
||||
project_map: dict[str, Project] = {}
|
||||
sprint_map: dict[str, Sprint] = {} # key: "project_slug:sprint_number"
|
||||
agent_type_map: dict[str, AgentType] = {}
|
||||
agent_instance_map: dict[
|
||||
str, AgentInstance
|
||||
] = {} # key: "project_slug:agent_name"
|
||||
|
||||
# ========================
|
||||
# 1. Create Organizations
|
||||
# ========================
|
||||
for org_data in data.get("organizations", []):
|
||||
org_result = await session.execute(
|
||||
select(Organization).where(Organization.slug == org_data["slug"])
|
||||
)
|
||||
existing_org = org_result.scalar_one_or_none()
|
||||
|
||||
if not existing_org:
|
||||
org = Organization(
|
||||
@@ -117,29 +209,20 @@ async def load_demo_data(session):
|
||||
is_active=True,
|
||||
)
|
||||
session.add(org)
|
||||
await session.flush() # Flush to get ID
|
||||
org_map[org.slug] = org
|
||||
await session.flush()
|
||||
org_map[str(org.slug)] = org
|
||||
logger.info(f"Created demo organization: {org.name}")
|
||||
else:
|
||||
# We can't easily get the ORM object from raw SQL result for map without querying again or mapping
|
||||
# So let's just query it properly if we need it for relationships
|
||||
# But for simplicity in this script, let's just assume we created it or it exists.
|
||||
# To properly map for users, we need the ID.
|
||||
# Let's use a simpler approach: just try to create, if slug conflict, skip.
|
||||
pass
|
||||
org_map[str(existing_org.slug)] = existing_org
|
||||
|
||||
# Re-query all orgs to build map for users
|
||||
result = await session.execute(select(Organization))
|
||||
orgs = result.scalars().all()
|
||||
org_map = {org.slug: org for org in orgs}
|
||||
|
||||
# Create Users
|
||||
# ========================
|
||||
# 2. Create Users
|
||||
# ========================
|
||||
for user_data in data.get("users", []):
|
||||
existing_user = await user_crud.get_by_email(
|
||||
session, email=user_data["email"]
|
||||
)
|
||||
if not existing_user:
|
||||
# Create user
|
||||
user_in = UserCreate(
|
||||
email=user_data["email"],
|
||||
password=user_data["password"],
|
||||
@@ -151,17 +234,13 @@ async def load_demo_data(session):
|
||||
user = await user_crud.create(session, obj_in=user_in)
|
||||
|
||||
# Randomize created_at for demo data (last 30 days)
|
||||
# This makes the charts look more realistic
|
||||
days_ago = random.randint(0, 30) # noqa: S311
|
||||
random_time = datetime.now(UTC) - timedelta(days=days_ago)
|
||||
# Add some random hours/minutes variation
|
||||
random_time = random_time.replace(
|
||||
hour=random.randint(0, 23), # noqa: S311
|
||||
minute=random.randint(0, 59), # noqa: S311
|
||||
)
|
||||
|
||||
# Update the timestamp and is_active directly in the database
|
||||
# We do this to ensure the values are persisted correctly
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE users SET created_at = :created_at, is_active = :is_active WHERE id = :user_id"
|
||||
@@ -174,7 +253,7 @@ async def load_demo_data(session):
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created demo user: {user.email} (created {days_ago} days ago, active={user_data.get('is_active', True)})"
|
||||
f"Created demo user: {user.email} (created {days_ago} days ago)"
|
||||
)
|
||||
|
||||
# Add to organization if specified
|
||||
@@ -182,19 +261,228 @@ async def load_demo_data(session):
|
||||
role = user_data.get("role")
|
||||
if org_slug and org_slug in org_map and role:
|
||||
org = org_map[org_slug]
|
||||
# Check if membership exists (it shouldn't for new user)
|
||||
member = UserOrganization(
|
||||
user_id=user.id, organization_id=org.id, role=role
|
||||
)
|
||||
session.add(member)
|
||||
logger.info(f"Added {user.email} to {org.name} as {role}")
|
||||
|
||||
user_map[str(user.email)] = user
|
||||
else:
|
||||
logger.info(f"Demo user already exists: {existing_user.email}")
|
||||
user_map[str(existing_user.email)] = existing_user
|
||||
logger.debug(f"Demo user already exists: {existing_user.email}")
|
||||
|
||||
await session.flush()
|
||||
|
||||
# Add admin user to map with special "__admin__" key
|
||||
# This allows demo data to reference the admin user as owner
|
||||
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
||||
admin_user = await user_crud.get_by_email(session, email=superuser_email)
|
||||
if admin_user:
|
||||
user_map["__admin__"] = admin_user
|
||||
user_map[str(admin_user.email)] = admin_user
|
||||
logger.debug(f"Added admin user to map: {admin_user.email}")
|
||||
|
||||
# ========================
|
||||
# 3. Load Agent Types Map (for FK resolution)
|
||||
# ========================
|
||||
agent_types_result = await session.execute(select(AgentType))
|
||||
for at in agent_types_result.scalars().all():
|
||||
agent_type_map[str(at.slug)] = at
|
||||
|
||||
# ========================
|
||||
# 4. Create Projects
|
||||
# ========================
|
||||
for project_data in data.get("projects", []):
|
||||
project_result = await session.execute(
|
||||
select(Project).where(Project.slug == project_data["slug"])
|
||||
)
|
||||
existing_project = project_result.scalar_one_or_none()
|
||||
|
||||
if not existing_project:
|
||||
# Resolve owner email to user ID
|
||||
owner_id = None
|
||||
owner_email = project_data.get("owner_email")
|
||||
if owner_email and owner_email in user_map:
|
||||
owner_id = user_map[owner_email].id
|
||||
|
||||
project = Project(
|
||||
name=project_data["name"],
|
||||
slug=project_data["slug"],
|
||||
description=project_data.get("description"),
|
||||
owner_id=owner_id,
|
||||
autonomy_level=AutonomyLevel(
|
||||
project_data.get("autonomy_level", "milestone")
|
||||
),
|
||||
status=ProjectStatus(project_data.get("status", "active")),
|
||||
complexity=ProjectComplexity(
|
||||
project_data.get("complexity", "medium")
|
||||
),
|
||||
client_mode=ClientMode(project_data.get("client_mode", "auto")),
|
||||
settings=project_data.get("settings", {}),
|
||||
)
|
||||
session.add(project)
|
||||
await session.flush()
|
||||
project_map[str(project.slug)] = project
|
||||
logger.info(f"Created demo project: {project.name}")
|
||||
else:
|
||||
project_map[str(existing_project.slug)] = existing_project
|
||||
logger.debug(f"Demo project already exists: {existing_project.name}")
|
||||
|
||||
# ========================
|
||||
# 5. Create Sprints
|
||||
# ========================
|
||||
for sprint_data in data.get("sprints", []):
|
||||
project_slug = sprint_data["project_slug"]
|
||||
sprint_number = sprint_data["number"]
|
||||
sprint_key = f"{project_slug}:{sprint_number}"
|
||||
|
||||
if project_slug not in project_map:
|
||||
logger.warning(f"Project not found for sprint: {project_slug}")
|
||||
continue
|
||||
|
||||
sprint_project = project_map[project_slug]
|
||||
|
||||
# Check if sprint exists
|
||||
sprint_result = await session.execute(
|
||||
select(Sprint).where(
|
||||
Sprint.project_id == sprint_project.id,
|
||||
Sprint.number == sprint_number,
|
||||
)
|
||||
)
|
||||
existing_sprint = sprint_result.scalar_one_or_none()
|
||||
|
||||
if not existing_sprint:
|
||||
sprint = Sprint(
|
||||
project_id=sprint_project.id,
|
||||
name=sprint_data["name"],
|
||||
number=sprint_number,
|
||||
goal=sprint_data.get("goal"),
|
||||
start_date=date.fromisoformat(sprint_data["start_date"]),
|
||||
end_date=date.fromisoformat(sprint_data["end_date"]),
|
||||
status=SprintStatus(sprint_data.get("status", "planned")),
|
||||
planned_points=sprint_data.get("planned_points"),
|
||||
)
|
||||
session.add(sprint)
|
||||
await session.flush()
|
||||
sprint_map[sprint_key] = sprint
|
||||
logger.info(
|
||||
f"Created demo sprint: {sprint.name} for {sprint_project.name}"
|
||||
)
|
||||
else:
|
||||
sprint_map[sprint_key] = existing_sprint
|
||||
logger.debug(f"Demo sprint already exists: {existing_sprint.name}")
|
||||
|
||||
# ========================
|
||||
# 6. Create Agent Instances
|
||||
# ========================
|
||||
for agent_data in data.get("agent_instances", []):
|
||||
project_slug = agent_data["project_slug"]
|
||||
agent_type_slug = agent_data["agent_type_slug"]
|
||||
agent_name = agent_data["name"]
|
||||
agent_key = f"{project_slug}:{agent_name}"
|
||||
|
||||
if project_slug not in project_map:
|
||||
logger.warning(f"Project not found for agent: {project_slug}")
|
||||
continue
|
||||
|
||||
if agent_type_slug not in agent_type_map:
|
||||
logger.warning(f"Agent type not found: {agent_type_slug}")
|
||||
continue
|
||||
|
||||
agent_project = project_map[project_slug]
|
||||
agent_type = agent_type_map[agent_type_slug]
|
||||
|
||||
# Check if agent instance exists (by name within project)
|
||||
agent_result = await session.execute(
|
||||
select(AgentInstance).where(
|
||||
AgentInstance.project_id == agent_project.id,
|
||||
AgentInstance.name == agent_name,
|
||||
)
|
||||
)
|
||||
existing_agent = agent_result.scalar_one_or_none()
|
||||
|
||||
if not existing_agent:
|
||||
agent_instance = AgentInstance(
|
||||
project_id=agent_project.id,
|
||||
agent_type_id=agent_type.id,
|
||||
name=agent_name,
|
||||
status=AgentStatus(agent_data.get("status", "idle")),
|
||||
current_task=agent_data.get("current_task"),
|
||||
)
|
||||
session.add(agent_instance)
|
||||
await session.flush()
|
||||
agent_instance_map[agent_key] = agent_instance
|
||||
logger.info(
|
||||
f"Created demo agent: {agent_name} ({agent_type.name}) "
|
||||
f"for {agent_project.name}"
|
||||
)
|
||||
else:
|
||||
agent_instance_map[agent_key] = existing_agent
|
||||
logger.debug(f"Demo agent already exists: {existing_agent.name}")
|
||||
|
||||
# ========================
|
||||
# 7. Create Issues
|
||||
# ========================
|
||||
for issue_data in data.get("issues", []):
|
||||
project_slug = issue_data["project_slug"]
|
||||
|
||||
if project_slug not in project_map:
|
||||
logger.warning(f"Project not found for issue: {project_slug}")
|
||||
continue
|
||||
|
||||
issue_project = project_map[project_slug]
|
||||
|
||||
# Check if issue exists (by title within project - simple heuristic)
|
||||
issue_result = await session.execute(
|
||||
select(Issue).where(
|
||||
Issue.project_id == issue_project.id,
|
||||
Issue.title == issue_data["title"],
|
||||
)
|
||||
)
|
||||
existing_issue = issue_result.scalar_one_or_none()
|
||||
|
||||
if not existing_issue:
|
||||
# Resolve sprint
|
||||
sprint_id = None
|
||||
sprint_number = issue_data.get("sprint_number")
|
||||
if sprint_number:
|
||||
sprint_key = f"{project_slug}:{sprint_number}"
|
||||
if sprint_key in sprint_map:
|
||||
sprint_id = sprint_map[sprint_key].id
|
||||
|
||||
# Resolve assigned agent
|
||||
assigned_agent_id = None
|
||||
assigned_agent_name = issue_data.get("assigned_agent_name")
|
||||
if assigned_agent_name:
|
||||
agent_key = f"{project_slug}:{assigned_agent_name}"
|
||||
if agent_key in agent_instance_map:
|
||||
assigned_agent_id = agent_instance_map[agent_key].id
|
||||
|
||||
issue = Issue(
|
||||
project_id=issue_project.id,
|
||||
sprint_id=sprint_id,
|
||||
type=IssueType(issue_data.get("type", "task")),
|
||||
title=issue_data["title"],
|
||||
body=issue_data.get("body", ""),
|
||||
status=IssueStatus(issue_data.get("status", "open")),
|
||||
priority=IssuePriority(issue_data.get("priority", "medium")),
|
||||
labels=issue_data.get("labels", []),
|
||||
story_points=issue_data.get("story_points"),
|
||||
assigned_agent_id=assigned_agent_id,
|
||||
)
|
||||
session.add(issue)
|
||||
logger.info(f"Created demo issue: {issue.title[:50]}...")
|
||||
else:
|
||||
logger.debug(
|
||||
f"Demo issue already exists: {existing_issue.title[:50]}..."
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
logger.info("Demo data loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Error loading demo data: {e}")
|
||||
raise
|
||||
|
||||
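For reference, the randomized backdating used above can be factored into a small helper. This is a hedged sketch, assuming an AsyncSession and the users table targeted by the UPDATE; the helper name is illustrative.

import random
from datetime import UTC, datetime, timedelta

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession


async def backdate_demo_user(session: AsyncSession, user_id, is_active: bool = True) -> None:
    """Spread demo signups over the past 30 days so dashboards show a realistic trend."""
    created_at = datetime.now(UTC) - timedelta(
        days=random.randint(0, 30),     # noqa: S311 - demo data, not security-sensitive
        hours=random.randint(0, 23),    # noqa: S311
        minutes=random.randint(0, 59),  # noqa: S311
    )
    # Raw SQL mirrors the loader above, which uses it so the backdated values
    # are actually persisted rather than overwritten by ORM-level defaults.
    await session.execute(
        text(
            "UPDATE users SET created_at = :created_at, is_active = :is_active "
            "WHERE id = :user_id"
        ),
        {"created_at": created_at, "is_active": is_active, "user_id": user_id},
    )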
@@ -210,12 +498,12 @@ async def main():
|
||||
try:
|
||||
user = await init_db()
|
||||
if user:
|
||||
print("✓ Database initialized successfully")
|
||||
print(f"✓ Superuser: {user.email}")
|
||||
print("Database initialized successfully")
|
||||
print(f"Superuser: {user.email}")
|
||||
else:
|
||||
print("✗ Failed to initialize database")
|
||||
print("Failed to initialize database")
|
||||
except Exception as e:
|
||||
print(f"✗ Error initializing database: {e}")
|
||||
print(f"Error initializing database: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Close the engine
|
||||
|
||||
@@ -8,6 +8,19 @@ from app.core.database import Base
|
||||
|
||||
from .base import TimestampMixin, UUIDMixin
|
||||
|
||||
# Memory system models
|
||||
from .memory import (
|
||||
ConsolidationStatus,
|
||||
ConsolidationType,
|
||||
Episode,
|
||||
EpisodeOutcome,
|
||||
Fact,
|
||||
MemoryConsolidationLog,
|
||||
Procedure,
|
||||
ScopeType,
|
||||
WorkingMemory,
|
||||
)
|
||||
|
||||
# OAuth models (client mode - authenticate via Google/GitHub)
|
||||
from .oauth_account import OAuthAccount
|
||||
|
||||
@@ -37,7 +50,14 @@ __all__ = [
|
||||
"AgentInstance",
|
||||
"AgentType",
|
||||
"Base",
|
||||
# Memory models
|
||||
"ConsolidationStatus",
|
||||
"ConsolidationType",
|
||||
"Episode",
|
||||
"EpisodeOutcome",
|
||||
"Fact",
|
||||
"Issue",
|
||||
"MemoryConsolidationLog",
|
||||
"OAuthAccount",
|
||||
"OAuthAuthorizationCode",
|
||||
"OAuthClient",
|
||||
@@ -46,11 +66,14 @@ __all__ = [
|
||||
"OAuthState",
|
||||
"Organization",
|
||||
"OrganizationRole",
|
||||
"Procedure",
|
||||
"Project",
|
||||
"ScopeType",
|
||||
"Sprint",
|
||||
"TimestampMixin",
|
||||
"UUIDMixin",
|
||||
"User",
|
||||
"UserOrganization",
|
||||
"UserSession",
|
||||
"WorkingMemory",
|
||||
]
|
||||
|
||||
32
backend/app/models/memory/__init__.py
Normal file
@@ -0,0 +1,32 @@
# app/models/memory/__init__.py
|
||||
"""
|
||||
Memory System Database Models.
|
||||
|
||||
Provides SQLAlchemy models for the Agent Memory System:
|
||||
- WorkingMemory: Key-value storage with TTL
|
||||
- Episode: Experiential memories
|
||||
- Fact: Semantic knowledge triples
|
||||
- Procedure: Learned skills
|
||||
- MemoryConsolidationLog: Consolidation job tracking
|
||||
"""
|
||||
|
||||
from .consolidation import MemoryConsolidationLog
|
||||
from .enums import ConsolidationStatus, ConsolidationType, EpisodeOutcome, ScopeType
|
||||
from .episode import Episode
|
||||
from .fact import Fact
|
||||
from .procedure import Procedure
|
||||
from .working_memory import WorkingMemory
|
||||
|
||||
__all__ = [
|
||||
# Enums
|
||||
"ConsolidationStatus",
|
||||
"ConsolidationType",
|
||||
# Models
|
||||
"Episode",
|
||||
"EpisodeOutcome",
|
||||
"Fact",
|
||||
"MemoryConsolidationLog",
|
||||
"Procedure",
|
||||
"ScopeType",
|
||||
"WorkingMemory",
|
||||
]
|
||||
72
backend/app/models/memory/consolidation.py
Normal file
@@ -0,0 +1,72 @@
# app/models/memory/consolidation.py
|
||||
"""
|
||||
Memory Consolidation Log database model.
|
||||
|
||||
Tracks memory consolidation jobs that transfer knowledge
|
||||
between memory tiers.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum, Index, Integer, Text
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import ConsolidationStatus, ConsolidationType
|
||||
|
||||
|
||||
class MemoryConsolidationLog(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Memory consolidation job log.
|
||||
|
||||
Tracks consolidation operations:
|
||||
- Working -> Episodic (session end)
|
||||
- Episodic -> Semantic (fact extraction)
|
||||
- Episodic -> Procedural (procedure learning)
|
||||
- Pruning (removing low-value memories)
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_consolidation_log"
|
||||
|
||||
# Consolidation type
|
||||
consolidation_type: Column[ConsolidationType] = Column(
|
||||
Enum(ConsolidationType),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Counts
|
||||
source_count = Column(Integer, nullable=False, default=0)
|
||||
result_count = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# Timing
|
||||
started_at = Column(DateTime(timezone=True), nullable=False)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Status
|
||||
status: Column[ConsolidationStatus] = Column(
|
||||
Enum(ConsolidationStatus),
|
||||
nullable=False,
|
||||
default=ConsolidationStatus.PENDING,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Error details if failed
|
||||
error = Column(Text, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
# Query patterns
|
||||
Index("ix_consolidation_type_status", "consolidation_type", "status"),
|
||||
Index("ix_consolidation_started", "started_at"),
|
||||
)
|
||||
|
||||
@property
|
||||
def duration_seconds(self) -> float | None:
|
||||
"""Calculate duration of the consolidation job."""
|
||||
if self.completed_at is None or self.started_at is None:
|
||||
return None
|
||||
return (self.completed_at - self.started_at).total_seconds()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<MemoryConsolidationLog {self.id} "
|
||||
f"type={self.consolidation_type.value} status={self.status.value}>"
|
||||
)
|
||||
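A hedged sketch of how a consolidation job might drive this log's status and timing fields; the job callable and session handling are placeholders.

from datetime import UTC, datetime

from app.models.memory import ConsolidationStatus, ConsolidationType, MemoryConsolidationLog


async def run_consolidation(session, job, consolidation_type: ConsolidationType):
    log = MemoryConsolidationLog(
        consolidation_type=consolidation_type,
        status=ConsolidationStatus.RUNNING,
        started_at=datetime.now(UTC),
    )
    session.add(log)
    try:
        # job() is a placeholder returning (items read, items written)
        log.source_count, log.result_count = await job()
        log.status = ConsolidationStatus.COMPLETED
    except Exception as exc:
        log.status = ConsolidationStatus.FAILED
        log.error = str(exc)
        raise
    finally:
        log.completed_at = datetime.now(UTC)  # makes duration_seconds available
        await session.commit()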
73
backend/app/models/memory/enums.py
Normal file
@@ -0,0 +1,73 @@
# app/models/memory/enums.py
|
||||
"""
|
||||
Enums for Memory System database models.
|
||||
|
||||
These enums define the database-level constraints for memory types
|
||||
and scoping levels.
|
||||
"""
|
||||
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
|
||||
class ScopeType(str, PyEnum):
|
||||
"""
|
||||
Memory scope levels matching the memory service types.
|
||||
|
||||
GLOBAL: System-wide memories accessible by all
|
||||
PROJECT: Project-scoped memories
|
||||
AGENT_TYPE: Type-specific memories (shared by instances of same type)
|
||||
AGENT_INSTANCE: Instance-specific memories
|
||||
SESSION: Session-scoped ephemeral memories
|
||||
"""
|
||||
|
||||
GLOBAL = "global"
|
||||
PROJECT = "project"
|
||||
AGENT_TYPE = "agent_type"
|
||||
AGENT_INSTANCE = "agent_instance"
|
||||
SESSION = "session"
|
||||
|
||||
|
||||
class EpisodeOutcome(str, PyEnum):
|
||||
"""
|
||||
Outcome of an episode (task execution).
|
||||
|
||||
SUCCESS: Task completed successfully
|
||||
FAILURE: Task failed
|
||||
PARTIAL: Task partially completed
|
||||
"""
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
PARTIAL = "partial"
|
||||
|
||||
|
||||
class ConsolidationType(str, PyEnum):
|
||||
"""
|
||||
Types of memory consolidation operations.
|
||||
|
||||
WORKING_TO_EPISODIC: Transfer session state to episodic
|
||||
EPISODIC_TO_SEMANTIC: Extract facts from episodes
|
||||
EPISODIC_TO_PROCEDURAL: Extract procedures from episodes
|
||||
PRUNING: Remove low-value memories
|
||||
"""
|
||||
|
||||
WORKING_TO_EPISODIC = "working_to_episodic"
|
||||
EPISODIC_TO_SEMANTIC = "episodic_to_semantic"
|
||||
EPISODIC_TO_PROCEDURAL = "episodic_to_procedural"
|
||||
PRUNING = "pruning"
|
||||
|
||||
|
||||
class ConsolidationStatus(str, PyEnum):
|
||||
"""
|
||||
Status of a consolidation job.
|
||||
|
||||
PENDING: Job is queued
|
||||
RUNNING: Job is currently executing
|
||||
COMPLETED: Job finished successfully
|
||||
FAILED: Job failed with errors
|
||||
"""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
139
backend/app/models/memory/episode.py
Normal file
@@ -0,0 +1,139 @@
# app/models/memory/episode.py
|
||||
"""
|
||||
Episode database model.
|
||||
|
||||
Stores experiential memories - records of past task executions
|
||||
with context, actions, outcomes, and lessons learned.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
Enum,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Index,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import EpisodeOutcome
|
||||
|
||||
# Import pgvector type - will be available after migration enables extension
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
# Fallback for environments without pgvector
|
||||
Vector = None
|
||||
|
||||
|
||||
class Episode(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Episodic memory model.
|
||||
|
||||
Records experiential memories from agent task execution:
|
||||
- What task was performed
|
||||
- What actions were taken
|
||||
- What was the outcome
|
||||
- What lessons were learned
|
||||
"""
|
||||
|
||||
__tablename__ = "episodes"
|
||||
|
||||
# Foreign keys
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
agent_instance_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_instances.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
agent_type_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_types.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Session reference
|
||||
session_id = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Task information
|
||||
task_type = Column(String(100), nullable=False, index=True)
|
||||
task_description = Column(Text, nullable=False)
|
||||
|
||||
# Actions taken (list of action dictionaries)
|
||||
actions = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Context summary
|
||||
context_summary = Column(Text, nullable=False)
|
||||
|
||||
# Outcome
|
||||
outcome: Column[EpisodeOutcome] = Column(
|
||||
Enum(EpisodeOutcome),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
outcome_details = Column(Text, nullable=True)
|
||||
|
||||
# Metrics
|
||||
duration_seconds = Column(Float, nullable=False, default=0.0)
|
||||
tokens_used = Column(BigInteger, nullable=False, default=0)
|
||||
|
||||
# Learning
|
||||
lessons_learned = Column(JSONB, default=list, nullable=False)
|
||||
importance_score = Column(Float, nullable=False, default=0.5, index=True)
|
||||
|
||||
# Vector embedding for semantic search
|
||||
# Using 1536 dimensions for OpenAI text-embedding-3-small
|
||||
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
|
||||
|
||||
# When the episode occurred
|
||||
occurred_at = Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", foreign_keys=[project_id])
|
||||
agent_instance = relationship("AgentInstance", foreign_keys=[agent_instance_id])
|
||||
agent_type = relationship("AgentType", foreign_keys=[agent_type_id])
|
||||
|
||||
__table_args__ = (
|
||||
# Primary query patterns
|
||||
Index("ix_episodes_project_task", "project_id", "task_type"),
|
||||
Index("ix_episodes_project_outcome", "project_id", "outcome"),
|
||||
Index("ix_episodes_agent_task", "agent_instance_id", "task_type"),
|
||||
Index("ix_episodes_project_time", "project_id", "occurred_at"),
|
||||
# For importance-based pruning
|
||||
Index("ix_episodes_importance_time", "importance_score", "occurred_at"),
|
||||
# Data integrity constraints
|
||||
CheckConstraint(
|
||||
"importance_score >= 0.0 AND importance_score <= 1.0",
|
||||
name="ck_episodes_importance_range",
|
||||
),
|
||||
CheckConstraint(
|
||||
"duration_seconds >= 0.0",
|
||||
name="ck_episodes_duration_positive",
|
||||
),
|
||||
CheckConstraint(
|
||||
"tokens_used >= 0",
|
||||
name="ck_episodes_tokens_positive",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Episode {self.id} task={self.task_type} outcome={self.outcome.value}>"
|
||||
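When pgvector is installed (i.e. the Text fallback above is not taken), the embedding column supports semantic recall. A hedged sketch; cosine_distance is the comparator added by pgvector's SQLAlchemy integration, and the query vector must come from the same 1536-dimension embedding model.

from sqlalchemy import select

from app.models.memory import Episode


async def recall_similar_episodes(session, project_id, query_embedding, limit: int = 5):
    stmt = (
        select(Episode)
        .where(Episode.project_id == project_id, Episode.embedding.is_not(None))
        # Order by ascending cosine distance, i.e. most similar episodes first
        .order_by(Episode.embedding.cosine_distance(query_embedding))
        .limit(limit)
    )
    return list((await session.execute(stmt)).scalars())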
120
backend/app/models/memory/fact.py
Normal file
@@ -0,0 +1,120 @@
# app/models/memory/fact.py
|
||||
"""
|
||||
Fact database model.
|
||||
|
||||
Stores semantic memories - learned facts in subject-predicate-object
|
||||
triple format with confidence scores and source tracking.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
# Import pgvector type
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
Vector = None
|
||||
|
||||
|
||||
class Fact(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Semantic memory model.
|
||||
|
||||
Stores learned facts as subject-predicate-object triples:
|
||||
- "FastAPI" - "uses" - "Starlette framework"
|
||||
- "Project Alpha" - "requires" - "OAuth authentication"
|
||||
|
||||
Facts have confidence scores that decay over time and can be
|
||||
reinforced when the same fact is learned again.
|
||||
"""
|
||||
|
||||
__tablename__ = "facts"
|
||||
|
||||
# Scoping: project_id is NULL for global facts
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Triple format
|
||||
subject = Column(String(500), nullable=False, index=True)
|
||||
predicate = Column(String(255), nullable=False, index=True)
|
||||
object = Column(Text, nullable=False)
|
||||
|
||||
# Confidence score (0.0 to 1.0)
|
||||
confidence = Column(Float, nullable=False, default=0.8, index=True)
|
||||
|
||||
# Source tracking: which episodes contributed to this fact (stored as JSONB array of UUID strings)
|
||||
source_episode_ids: Column[list] = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Learning history
|
||||
first_learned = Column(DateTime(timezone=True), nullable=False)
|
||||
last_reinforced = Column(DateTime(timezone=True), nullable=False)
|
||||
reinforcement_count = Column(Integer, nullable=False, default=1)
|
||||
|
||||
# Vector embedding for semantic search
|
||||
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", foreign_keys=[project_id])
|
||||
|
||||
__table_args__ = (
|
||||
# Unique constraint on triple within project scope
|
||||
Index(
|
||||
"ix_facts_unique_triple",
|
||||
"project_id",
|
||||
"subject",
|
||||
"predicate",
|
||||
"object",
|
||||
unique=True,
|
||||
postgresql_where=text("project_id IS NOT NULL"),
|
||||
),
|
||||
# Unique constraint on triple for global facts (project_id IS NULL)
|
||||
Index(
|
||||
"ix_facts_unique_triple_global",
|
||||
"subject",
|
||||
"predicate",
|
||||
"object",
|
||||
unique=True,
|
||||
postgresql_where=text("project_id IS NULL"),
|
||||
),
|
||||
# Query patterns
|
||||
Index("ix_facts_subject_predicate", "subject", "predicate"),
|
||||
Index("ix_facts_project_subject", "project_id", "subject"),
|
||||
Index("ix_facts_confidence_time", "confidence", "last_reinforced"),
|
||||
# Note: subject already has index=True on Column definition, no need for explicit index
|
||||
# Data integrity constraints
|
||||
CheckConstraint(
|
||||
"confidence >= 0.0 AND confidence <= 1.0",
|
||||
name="ck_facts_confidence_range",
|
||||
),
|
||||
CheckConstraint(
|
||||
"reinforcement_count >= 1",
|
||||
name="ck_facts_reinforcement_positive",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Fact {self.id} '{self.subject}' - '{self.predicate}' - "
|
||||
f"'{self.object[:50]}...' conf={self.confidence:.2f}>"
|
||||
)
|
||||
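The two partial unique indexes above mean a triple can exist at most once per project and once globally. A hedged sketch of a learn-or-reinforce helper built on that; the +0.05 confidence bump is an illustrative policy, not something defined in this diff.

from datetime import UTC, datetime

from sqlalchemy import select

from app.models.memory import Fact


async def learn_fact(session, subject, predicate, obj, project_id=None, confidence=0.8):
    now = datetime.now(UTC)
    stmt = select(Fact).where(
        Fact.project_id == project_id,  # a None project_id matches the global scope
        Fact.subject == subject,
        Fact.predicate == predicate,
        Fact.object == obj,
    )
    fact = (await session.execute(stmt)).scalar_one_or_none()
    if fact is None:
        fact = Fact(
            project_id=project_id,
            subject=subject,
            predicate=predicate,
            object=obj,
            confidence=confidence,
            first_learned=now,
            last_reinforced=now,
        )
        session.add(fact)
    else:
        # Re-learning an existing triple reinforces it instead of duplicating it
        fact.reinforcement_count += 1
        fact.confidence = min(1.0, fact.confidence + 0.05)
        fact.last_reinforced = now
    return fact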
129
backend/app/models/memory/procedure.py
Normal file
@@ -0,0 +1,129 @@
# app/models/memory/procedure.py
|
||||
"""
|
||||
Procedure database model.
|
||||
|
||||
Stores procedural memories - learned skills and procedures
|
||||
derived from successful task execution patterns.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
# Import pgvector type
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
Vector = None
|
||||
|
||||
|
||||
class Procedure(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Procedural memory model.
|
||||
|
||||
Stores learned procedures (skills) extracted from successful
|
||||
task execution patterns:
|
||||
- Name and trigger pattern for matching
|
||||
- Step-by-step actions
|
||||
- Success/failure tracking
|
||||
"""
|
||||
|
||||
__tablename__ = "procedures"
|
||||
|
||||
# Scoping
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
agent_type_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_types.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Procedure identification
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
trigger_pattern = Column(Text, nullable=False)
|
||||
|
||||
# Steps as JSON array of step objects
|
||||
# Each step: {order, action, parameters, expected_outcome, fallback_action}
|
||||
steps = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Success tracking
|
||||
success_count = Column(Integer, nullable=False, default=0)
|
||||
failure_count = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# Usage tracking
|
||||
last_used = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Vector embedding for semantic matching
|
||||
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", foreign_keys=[project_id])
|
||||
agent_type = relationship("AgentType", foreign_keys=[agent_type_id])
|
||||
|
||||
__table_args__ = (
|
||||
# Unique procedure name within scope
|
||||
Index(
|
||||
"ix_procedures_unique_name",
|
||||
"project_id",
|
||||
"agent_type_id",
|
||||
"name",
|
||||
unique=True,
|
||||
),
|
||||
# Query patterns
|
||||
Index("ix_procedures_project_name", "project_id", "name"),
|
||||
# Note: agent_type_id already has index=True on Column definition
|
||||
# For finding best procedures
|
||||
Index("ix_procedures_success_rate", "success_count", "failure_count"),
|
||||
# Data integrity constraints
|
||||
CheckConstraint(
|
||||
"success_count >= 0",
|
||||
name="ck_procedures_success_positive",
|
||||
),
|
||||
CheckConstraint(
|
||||
"failure_count >= 0",
|
||||
name="ck_procedures_failure_positive",
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""Calculate the success rate of this procedure."""
|
||||
# Snapshot values to avoid race conditions in concurrent access
|
||||
success = self.success_count
|
||||
failure = self.failure_count
|
||||
total = success + failure
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return success / total
|
||||
|
||||
@property
|
||||
def total_uses(self) -> int:
|
||||
"""Get total number of times this procedure was used."""
|
||||
# Total uses is simply successes plus failures
|
||||
return self.success_count + self.failure_count
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Procedure {self.name} ({self.id}) success_rate={self.success_rate:.2%}>"
|
||||
)
|
||||
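A hedged sketch of using success_rate and total_uses to pick a procedure for reuse; the minimum-experience threshold and ranking policy are illustrative.

from sqlalchemy import select

from app.models.memory import Procedure


async def best_procedure(session, project_id, agent_type_id, min_uses: int = 3):
    stmt = select(Procedure).where(
        Procedure.project_id == project_id,
        Procedure.agent_type_id == agent_type_id,
    )
    candidates = (await session.execute(stmt)).scalars().all()
    experienced = [p for p in candidates if p.total_uses >= min_uses]
    if not experienced:
        return None
    # success_rate and total_uses are the properties defined on the model above
    return max(experienced, key=lambda p: (p.success_rate, p.total_uses))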
58
backend/app/models/memory/working_memory.py
Normal file
@@ -0,0 +1,58 @@
# app/models/memory/working_memory.py
|
||||
"""
|
||||
Working Memory database model.
|
||||
|
||||
Stores ephemeral key-value data for active sessions with TTL support.
|
||||
Used as database backup when Redis is unavailable.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum, Index, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import ScopeType
|
||||
|
||||
|
||||
class WorkingMemory(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Working memory storage table.
|
||||
|
||||
Provides database-backed working memory as fallback when
|
||||
Redis is unavailable. Supports TTL-based expiration.
|
||||
"""
|
||||
|
||||
__tablename__ = "working_memory"
|
||||
|
||||
# Scoping
|
||||
scope_type: Column[ScopeType] = Column(
|
||||
Enum(ScopeType),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
scope_id = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Key-value storage
|
||||
key = Column(String(255), nullable=False)
|
||||
value = Column(JSONB, nullable=False)
|
||||
|
||||
# TTL support
|
||||
expires_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
# Primary lookup: scope + key
|
||||
Index(
|
||||
"ix_working_memory_scope_key",
|
||||
"scope_type",
|
||||
"scope_id",
|
||||
"key",
|
||||
unique=True,
|
||||
),
|
||||
# For cleanup of expired entries
|
||||
Index("ix_working_memory_expires", "expires_at"),
|
||||
# For listing all keys in a scope
|
||||
Index("ix_working_memory_scope_list", "scope_type", "scope_id"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<WorkingMemory {self.scope_type.value}:{self.scope_id}:{self.key}>"
|
||||
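A hedged sketch of the cleanup job the ix_working_memory_expires index is meant to serve; scheduling (cron, background task) is left to the caller.

from datetime import UTC, datetime

from sqlalchemy import delete

from app.models.memory import WorkingMemory


async def purge_expired_working_memory(session) -> int:
    result = await session.execute(
        delete(WorkingMemory).where(
            WorkingMemory.expires_at.is_not(None),
            WorkingMemory.expires_at < datetime.now(UTC),
        )
    )
    await session.commit()
    return result.rowcount or 0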
@@ -62,7 +62,11 @@ class AgentInstance(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Status tracking
|
||||
status: Column[AgentStatus] = Column(
|
||||
Enum(AgentStatus),
|
||||
Enum(
|
||||
AgentStatus,
|
||||
name="agent_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=AgentStatus.IDLE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
|
||||
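The Enum changes in this and the following hunks all follow the same pattern: name= pins the PostgreSQL enum type name, and values_callable makes SQLAlchemy persist the members' lowercase values (e.g. "idle") instead of their names (e.g. "IDLE"). A minimal, throwaway illustration:

import enum

from sqlalchemy import Column, Enum, Integer
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Status(str, enum.Enum):
    IDLE = "idle"
    BUSY = "busy"


class Demo(Base):
    __tablename__ = "demo_enum_pattern"
    id = Column(Integer, primary_key=True)
    status = Column(
        # Stored in the database as "idle"/"busy" rather than "IDLE"/"BUSY"
        Enum(Status, name="demo_status", values_callable=lambda x: [e.value for e in x]),
        default=Status.IDLE,
        nullable=False,
    )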
@@ -6,7 +6,7 @@ An AgentType is a template that defines the capabilities, personality,
|
||||
and model configuration for agent instances.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Boolean, Column, Index, String, Text
|
||||
from sqlalchemy import Boolean, Column, Index, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -56,6 +56,24 @@ class AgentType(Base, UUIDMixin, TimestampMixin):
|
||||
# Whether this agent type is available for new instances
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Category for grouping agents (development, design, quality, etc.)
|
||||
category = Column(String(50), nullable=True, index=True)
|
||||
|
||||
# Lucide icon identifier for UI display (e.g., "code", "palette", "shield")
|
||||
icon = Column(String(50), nullable=True, default="bot")
|
||||
|
||||
# Hex color code for visual distinction (e.g., "#3B82F6")
|
||||
color = Column(String(7), nullable=True, default="#3B82F6")
|
||||
|
||||
# Display ordering within category (lower = first)
|
||||
sort_order = Column(Integer, nullable=False, default=0, index=True)
|
||||
|
||||
# List of typical tasks this agent excels at
|
||||
typical_tasks = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# List of agent slugs that collaborate well with this type
|
||||
collaboration_hints = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Relationships
|
||||
instances = relationship(
|
||||
"AgentInstance",
|
||||
@@ -66,6 +84,7 @@ class AgentType(Base, UUIDMixin, TimestampMixin):
|
||||
__table_args__ = (
|
||||
Index("ix_agent_types_slug_active", "slug", "is_active"),
|
||||
Index("ix_agent_types_name_active", "name", "is_active"),
|
||||
Index("ix_agent_types_category_sort", "category", "sort_order"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
@@ -167,3 +167,29 @@ class SprintStatus(str, PyEnum):
|
||||
IN_REVIEW = "in_review"
|
||||
COMPLETED = "completed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class AgentTypeCategory(str, PyEnum):
|
||||
"""
|
||||
Category classification for agent types.
|
||||
|
||||
Used for grouping and filtering agents in the UI.
|
||||
|
||||
DEVELOPMENT: Product, project, and engineering roles
|
||||
DESIGN: UI/UX and design research roles
|
||||
QUALITY: QA and security engineering
|
||||
OPERATIONS: DevOps and MLOps
|
||||
AI_ML: Machine learning and AI specialists
|
||||
DATA: Data science and engineering
|
||||
LEADERSHIP: Technical leadership roles
|
||||
DOMAIN_EXPERT: Industry and domain specialists
|
||||
"""
|
||||
|
||||
DEVELOPMENT = "development"
|
||||
DESIGN = "design"
|
||||
QUALITY = "quality"
|
||||
OPERATIONS = "operations"
|
||||
AI_ML = "ai_ml"
|
||||
DATA = "data"
|
||||
LEADERSHIP = "leadership"
|
||||
DOMAIN_EXPERT = "domain_expert"
|
||||
|
||||
@@ -59,7 +59,9 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Issue type (Epic, Story, Task, Bug)
|
||||
type: Column[IssueType] = Column(
|
||||
Enum(IssueType),
|
||||
Enum(
|
||||
IssueType, name="issue_type", values_callable=lambda x: [e.value for e in x]
|
||||
),
|
||||
default=IssueType.TASK,
|
||||
nullable=False,
|
||||
index=True,
|
||||
@@ -78,14 +80,22 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Status and priority
|
||||
status: Column[IssueStatus] = Column(
|
||||
Enum(IssueStatus),
|
||||
Enum(
|
||||
IssueStatus,
|
||||
name="issue_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=IssueStatus.OPEN,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
priority: Column[IssuePriority] = Column(
|
||||
Enum(IssuePriority),
|
||||
Enum(
|
||||
IssuePriority,
|
||||
name="issue_priority",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=IssuePriority.MEDIUM,
|
||||
nullable=False,
|
||||
index=True,
|
||||
@@ -132,7 +142,11 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Sync status with external tracker
|
||||
sync_status: Column[SyncStatus] = Column(
|
||||
Enum(SyncStatus),
|
||||
Enum(
|
||||
SyncStatus,
|
||||
name="sync_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=SyncStatus.SYNCED,
|
||||
nullable=False,
|
||||
# Note: Index defined in __table_args__ as ix_issues_sync_status
|
||||
|
||||
@@ -35,28 +35,44 @@ class Project(Base, UUIDMixin, TimestampMixin):
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
autonomy_level: Column[AutonomyLevel] = Column(
|
||||
Enum(AutonomyLevel),
|
||||
Enum(
|
||||
AutonomyLevel,
|
||||
name="autonomy_level",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=AutonomyLevel.MILESTONE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
status: Column[ProjectStatus] = Column(
|
||||
Enum(ProjectStatus),
|
||||
Enum(
|
||||
ProjectStatus,
|
||||
name="project_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=ProjectStatus.ACTIVE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
complexity: Column[ProjectComplexity] = Column(
|
||||
Enum(ProjectComplexity),
|
||||
Enum(
|
||||
ProjectComplexity,
|
||||
name="project_complexity",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=ProjectComplexity.MEDIUM,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
client_mode: Column[ClientMode] = Column(
|
||||
Enum(ClientMode),
|
||||
Enum(
|
||||
ClientMode,
|
||||
name="client_mode",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=ClientMode.AUTO,
|
||||
nullable=False,
|
||||
index=True,
|
||||
|
||||
@@ -57,7 +57,11 @@ class Sprint(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Status
|
||||
status: Column[SprintStatus] = Column(
|
||||
Enum(SprintStatus),
|
||||
Enum(
|
||||
SprintStatus,
|
||||
name="sprint_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=SprintStatus.PLANNED,
|
||||
nullable=False,
|
||||
index=True,
|
||||
|
||||
@@ -10,6 +10,8 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from app.models.syndarix.enums import AgentTypeCategory
|
||||
|
||||
|
||||
class AgentTypeBase(BaseModel):
|
||||
"""Base agent type schema with common fields."""
|
||||
@@ -26,6 +28,14 @@ class AgentTypeBase(BaseModel):
|
||||
tool_permissions: dict[str, Any] = Field(default_factory=dict)
|
||||
is_active: bool = True
|
||||
|
||||
# Category and display fields
|
||||
category: AgentTypeCategory | None = None
|
||||
icon: str | None = Field(None, max_length=50)
|
||||
color: str | None = Field(None, pattern=r"^#[0-9A-Fa-f]{6}$")
|
||||
sort_order: int = Field(default=0, ge=0, le=1000)
|
||||
typical_tasks: list[str] = Field(default_factory=list)
|
||||
collaboration_hints: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
@@ -62,6 +72,18 @@ class AgentTypeBase(BaseModel):
|
||||
"""Validate MCP server list."""
|
||||
return [s.strip() for s in v if s.strip()]
|
||||
|
||||
@field_validator("typical_tasks")
|
||||
@classmethod
|
||||
def validate_typical_tasks(cls, v: list[str]) -> list[str]:
|
||||
"""Validate and normalize typical tasks list."""
|
||||
return [t.strip() for t in v if t.strip()]
|
||||
|
||||
@field_validator("collaboration_hints")
|
||||
@classmethod
|
||||
def validate_collaboration_hints(cls, v: list[str]) -> list[str]:
|
||||
"""Validate and normalize collaboration hints (agent slugs)."""
|
||||
return [h.strip().lower() for h in v if h.strip()]
|
||||
|
||||
|
||||
class AgentTypeCreate(AgentTypeBase):
|
||||
"""Schema for creating a new agent type."""
|
||||
@@ -87,6 +109,14 @@ class AgentTypeUpdate(BaseModel):
|
||||
tool_permissions: dict[str, Any] | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
# Category and display fields (all optional for updates)
|
||||
category: AgentTypeCategory | None = None
|
||||
icon: str | None = Field(None, max_length=50)
|
||||
color: str | None = Field(None, pattern=r"^#[0-9A-Fa-f]{6}$")
|
||||
sort_order: int | None = Field(None, ge=0, le=1000)
|
||||
typical_tasks: list[str] | None = None
|
||||
collaboration_hints: list[str] | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
@@ -119,6 +149,22 @@ class AgentTypeUpdate(BaseModel):
|
||||
return v
|
||||
return [e.strip().lower() for e in v if e.strip()]
|
||||
|
||||
@field_validator("typical_tasks")
|
||||
@classmethod
|
||||
def validate_typical_tasks(cls, v: list[str] | None) -> list[str] | None:
|
||||
"""Validate and normalize typical tasks list."""
|
||||
if v is None:
|
||||
return v
|
||||
return [t.strip() for t in v if t.strip()]
|
||||
|
||||
@field_validator("collaboration_hints")
|
||||
@classmethod
|
||||
def validate_collaboration_hints(cls, v: list[str] | None) -> list[str] | None:
|
||||
"""Validate and normalize collaboration hints (agent slugs)."""
|
||||
if v is None:
|
||||
return v
|
||||
return [h.strip().lower() for h in v if h.strip()]
|
||||
|
||||
|
||||
class AgentTypeInDB(AgentTypeBase):
|
||||
"""Schema for agent type in database."""
|
||||
|
||||
@@ -114,6 +114,8 @@ from .types import (
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MemoryContext,
|
||||
MemorySubtype,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskComplexity,
|
||||
@@ -149,6 +151,8 @@ __all__ = [
|
||||
"FormattingError",
|
||||
"InvalidContextError",
|
||||
"KnowledgeContext",
|
||||
"MemoryContext",
|
||||
"MemorySubtype",
|
||||
"MessageRole",
|
||||
"ModelAdapter",
|
||||
"OpenAIAdapter",
|
||||
|
||||
@@ -30,6 +30,7 @@ class TokenBudget:
|
||||
knowledge: int = 0
|
||||
conversation: int = 0
|
||||
tools: int = 0
|
||||
memory: int = 0 # Agent memory (working, episodic, semantic, procedural)
|
||||
response_reserve: int = 0
|
||||
buffer: int = 0
|
||||
|
||||
@@ -60,6 +61,7 @@ class TokenBudget:
|
||||
"knowledge": self.knowledge,
|
||||
"conversation": self.conversation,
|
||||
"tool": self.tools,
|
||||
"memory": self.memory,
|
||||
}
|
||||
return allocation_map.get(context_type, 0)
|
||||
|
||||
@@ -211,6 +213,7 @@ class TokenBudget:
|
||||
"knowledge": self.knowledge,
|
||||
"conversation": self.conversation,
|
||||
"tools": self.tools,
|
||||
"memory": self.memory,
|
||||
"response_reserve": self.response_reserve,
|
||||
"buffer": self.buffer,
|
||||
},
|
||||
@@ -264,9 +267,10 @@ class BudgetAllocator:
|
||||
total=total_tokens,
|
||||
system=int(total_tokens * alloc.get("system", 0.05)),
|
||||
task=int(total_tokens * alloc.get("task", 0.10)),
|
||||
knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
|
||||
conversation=int(total_tokens * alloc.get("conversation", 0.20)),
|
||||
knowledge=int(total_tokens * alloc.get("knowledge", 0.30)),
|
||||
conversation=int(total_tokens * alloc.get("conversation", 0.15)),
|
||||
tools=int(total_tokens * alloc.get("tools", 0.05)),
|
||||
memory=int(total_tokens * alloc.get("memory", 0.15)),
|
||||
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
|
||||
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
|
||||
)
|
||||
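The new memory share is carved out of knowledge (0.40 to 0.30) and conversation (0.20 to 0.15), so the default fractions still sum to 1.0. A quick check using the defaults shown above:

defaults = {
    "system": 0.05, "task": 0.10, "knowledge": 0.30, "conversation": 0.15,
    "tools": 0.05, "memory": 0.15, "response": 0.15, "buffer": 0.05,
}
# 0.05 + 0.10 + 0.30 + 0.15 + 0.05 + 0.15 + 0.15 + 0.05 == 1.00
assert abs(sum(defaults.values()) - 1.0) < 1e-9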
@@ -317,6 +321,8 @@ class BudgetAllocator:
|
||||
budget.conversation = max(0, budget.conversation + actual_adjustment)
|
||||
elif context_type == "tool":
|
||||
budget.tools = max(0, budget.tools + actual_adjustment)
|
||||
elif context_type == "memory":
|
||||
budget.memory = max(0, budget.memory + actual_adjustment)
|
||||
|
||||
return budget
|
||||
|
||||
@@ -338,7 +344,12 @@ class BudgetAllocator:
|
||||
Rebalanced budget
|
||||
"""
|
||||
if prioritize is None:
|
||||
prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]
|
||||
prioritize = [
|
||||
ContextType.KNOWLEDGE,
|
||||
ContextType.MEMORY,
|
||||
ContextType.TASK,
|
||||
ContextType.SYSTEM,
|
||||
]
|
||||
|
||||
# Calculate unused tokens per type
|
||||
unused: dict[str, int] = {}
|
||||
|
||||
@@ -7,6 +7,7 @@ Provides a high-level API for assembling optimized context for LLM requests.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from .assembly import ContextPipeline
|
||||
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
@@ -20,6 +21,7 @@ from .types import (
|
||||
BaseContext,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MemoryContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
@@ -30,6 +32,7 @@ if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
from app.services.memory.integration import MemoryContextSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,6 +67,7 @@ class ContextEngine:
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
memory_source: "MemoryContextSource | None" = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the context engine.
|
||||
@@ -72,9 +76,11 @@ class ContextEngine:
|
||||
mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
|
||||
redis: Redis connection for caching
|
||||
settings: Context settings
|
||||
memory_source: Optional memory context source for agent memory
|
||||
"""
|
||||
self._mcp = mcp_manager
|
||||
self._settings = settings or get_context_settings()
|
||||
self._memory_source = memory_source
|
||||
|
||||
# Initialize components
|
||||
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
|
||||
@@ -115,6 +121,15 @@ class ContextEngine:
|
||||
"""
|
||||
self._cache.set_redis(redis)
|
||||
|
||||
def set_memory_source(self, memory_source: "MemoryContextSource") -> None:
|
||||
"""
|
||||
Set memory context source for agent memory integration.
|
||||
|
||||
Args:
|
||||
memory_source: Memory context source
|
||||
"""
|
||||
self._memory_source = memory_source
|
||||
|
||||
async def assemble_context(
|
||||
self,
|
||||
project_id: str,
|
||||
@@ -126,6 +141,10 @@ class ContextEngine:
|
||||
task_description: str | None = None,
|
||||
knowledge_query: str | None = None,
|
||||
knowledge_limit: int = 10,
|
||||
memory_query: str | None = None,
|
||||
memory_limit: int = 20,
|
||||
session_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
conversation_history: list[dict[str, str]] | None = None,
|
||||
tool_results: list[dict[str, Any]] | None = None,
|
||||
custom_contexts: list[BaseContext] | None = None,
|
||||
@@ -151,6 +170,10 @@ class ContextEngine:
|
||||
task_description: Current task description
|
||||
knowledge_query: Query for knowledge base search
|
||||
knowledge_limit: Max number of knowledge results
|
||||
memory_query: Query for agent memory search
|
||||
memory_limit: Max number of memory results
|
||||
session_id: Session ID for working memory access
|
||||
agent_type_id: Agent type ID for procedural memory
|
||||
conversation_history: List of {"role": str, "content": str}
|
||||
tool_results: List of tool results to include
|
||||
custom_contexts: Additional custom contexts
|
||||
@@ -197,15 +220,27 @@ class ContextEngine:
|
||||
)
|
||||
contexts.extend(knowledge_contexts)
|
||||
|
||||
# 4. Conversation history
|
||||
# 4. Memory context from Agent Memory System
|
||||
if memory_query and self._memory_source:
|
||||
memory_contexts = await self._fetch_memory(
|
||||
project_id=project_id,
|
||||
agent_id=agent_id,
|
||||
query=memory_query,
|
||||
limit=memory_limit,
|
||||
session_id=session_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
contexts.extend(memory_contexts)
|
||||
|
||||
# 5. Conversation history
|
||||
if conversation_history:
|
||||
contexts.extend(self._convert_conversation(conversation_history))
|
||||
|
||||
# 5. Tool results
|
||||
# 6. Tool results
|
||||
if tool_results:
|
||||
contexts.extend(self._convert_tool_results(tool_results))
|
||||
|
||||
# 6. Custom contexts
|
||||
# 7. Custom contexts
|
||||
if custom_contexts:
|
||||
contexts.extend(custom_contexts)
|
||||
|
||||
@@ -308,6 +343,65 @@ class ContextEngine:
|
||||
logger.warning(f"Failed to fetch knowledge: {e}")
|
||||
return []
|
||||
|
||||
async def _fetch_memory(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
limit: int = 20,
|
||||
session_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
) -> list[MemoryContext]:
|
||||
"""
|
||||
Fetch relevant memories from Agent Memory System.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
agent_id: Agent identifier
|
||||
query: Search query
|
||||
limit: Maximum results
|
||||
session_id: Session ID for working memory
|
||||
agent_type_id: Agent type ID for procedural memory
|
||||
|
||||
Returns:
|
||||
List of MemoryContext instances
|
||||
"""
|
||||
if not self._memory_source:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
|
||||
# Configure fetch limits
|
||||
from app.services.memory.integration.context_source import MemoryFetchConfig
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
working_limit=min(limit // 4, 5),
|
||||
episodic_limit=min(limit // 2, 10),
|
||||
semantic_limit=min(limit // 2, 10),
|
||||
procedural_limit=min(limit // 4, 5),
|
||||
include_working=session_id is not None,
|
||||
)
|
||||
|
||||
result = await self._memory_source.fetch_context(
|
||||
query=query,
|
||||
project_id=UUID(project_id),
|
||||
agent_instance_id=UUID(agent_id) if agent_id else None,
|
||||
agent_type_id=UUID(agent_type_id) if agent_type_id else None,
|
||||
session_id=session_id,
|
||||
config=config,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Fetched {len(result.contexts)} memory contexts for query: {query}, "
|
||||
f"by_type: {result.by_type}"
|
||||
)
|
||||
return result.contexts[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch memory: {e}")
|
||||
return []
|
||||
|
||||
def _convert_conversation(
|
||||
self,
|
||||
history: list[dict[str, str]],
|
||||
@@ -466,6 +560,7 @@ def create_context_engine(
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
memory_source: "MemoryContextSource | None" = None,
|
||||
) -> ContextEngine:
|
||||
"""
|
||||
Create a context engine instance.
|
||||
@@ -474,6 +569,7 @@ def create_context_engine(
|
||||
mcp_manager: MCP client manager
|
||||
redis: Redis connection
|
||||
settings: Context settings
|
||||
memory_source: Optional memory context source
|
||||
|
||||
Returns:
|
||||
Configured ContextEngine instance
|
||||
@@ -482,4 +578,5 @@ def create_context_engine(
|
||||
mcp_manager=mcp_manager,
|
||||
redis=redis,
|
||||
settings=settings,
|
||||
memory_source=memory_source,
|
||||
)
|
||||
|
||||
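A hedged usage sketch of the memory-aware assembly path. The import path and the agent_id parameter are inferred from the hunks above rather than shown in full, and all IDs are placeholders.

from app.services.context import create_context_engine  # module path assumed


async def build_context(mcp_manager, redis, memory_source, ids: dict):
    engine = create_context_engine(
        mcp_manager=mcp_manager, redis=redis, memory_source=memory_source
    )
    return await engine.assemble_context(
        project_id=ids["project_id"],
        agent_id=ids["agent_id"],
        task_description="Implement OAuth login",
        knowledge_query="OAuth login flow",
        memory_query="OAuth login flow",         # enables the new memory fetch step
        memory_limit=20,
        session_id=ids.get("session_id"),        # also pulls working memory when set
        agent_type_id=ids.get("agent_type_id"),  # enables procedural memory lookup
    )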
@@ -15,6 +15,10 @@ from .conversation import (
|
||||
MessageRole,
|
||||
)
|
||||
from .knowledge import KnowledgeContext
|
||||
from .memory import (
|
||||
MemoryContext,
|
||||
MemorySubtype,
|
||||
)
|
||||
from .system import SystemContext
|
||||
from .task import (
|
||||
TaskComplexity,
|
||||
@@ -33,6 +37,8 @@ __all__ = [
|
||||
"ContextType",
|
||||
"ConversationContext",
|
||||
"KnowledgeContext",
|
||||
"MemoryContext",
|
||||
"MemorySubtype",
|
||||
"MessageRole",
|
||||
"SystemContext",
|
||||
"TaskComplexity",
|
||||
|
||||
@@ -26,6 +26,7 @@ class ContextType(str, Enum):
|
||||
KNOWLEDGE = "knowledge"
|
||||
CONVERSATION = "conversation"
|
||||
TOOL = "tool"
|
||||
MEMORY = "memory" # Agent memory (working, episodic, semantic, procedural)
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str) -> "ContextType":
|
||||
|
||||
282
backend/app/services/context/types/memory.py
Normal file
@@ -0,0 +1,282 @@
"""
|
||||
Memory Context Type.
|
||||
|
||||
Represents agent memory as context for LLM requests.
|
||||
Includes working, episodic, semantic, and procedural memories.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class MemorySubtype(str, Enum):
|
||||
"""Types of agent memory."""
|
||||
|
||||
WORKING = "working" # Session-scoped temporary data
|
||||
EPISODIC = "episodic" # Task history and outcomes
|
||||
SEMANTIC = "semantic" # Facts and knowledge
|
||||
PROCEDURAL = "procedural" # Learned procedures
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MemoryContext(BaseContext):
|
||||
"""
|
||||
Context from agent memory system.
|
||||
|
||||
Memory context represents data retrieved from the agent
|
||||
memory system, including:
|
||||
- Working memory: Current session state
|
||||
- Episodic memory: Past task experiences
|
||||
- Semantic memory: Learned facts and knowledge
|
||||
- Procedural memory: Known procedures and workflows
|
||||
|
||||
Each memory item includes relevance scoring from search.
|
||||
"""
|
||||
|
||||
# Memory-specific fields
|
||||
memory_subtype: MemorySubtype = field(default=MemorySubtype.EPISODIC)
|
||||
memory_id: str | None = field(default=None)
|
||||
relevance_score: float = field(default=0.0)
|
||||
importance: float = field(default=0.5)
|
||||
search_query: str = field(default="")
|
||||
|
||||
# Type-specific fields (populated based on memory_subtype)
|
||||
key: str | None = field(default=None) # For working memory
|
||||
task_type: str | None = field(default=None) # For episodic
|
||||
outcome: str | None = field(default=None) # For episodic
|
||||
subject: str | None = field(default=None) # For semantic
|
||||
predicate: str | None = field(default=None) # For semantic
|
||||
object_value: str | None = field(default=None) # For semantic
|
||||
trigger: str | None = field(default=None) # For procedural
|
||||
success_rate: float | None = field(default=None) # For procedural
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return MEMORY context type."""
|
||||
return ContextType.MEMORY
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with memory-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"memory_subtype": self.memory_subtype.value,
|
||||
"memory_id": self.memory_id,
|
||||
"relevance_score": self.relevance_score,
|
||||
"importance": self.importance,
|
||||
"search_query": self.search_query,
|
||||
"key": self.key,
|
||||
"task_type": self.task_type,
|
||||
"outcome": self.outcome,
|
||||
"subject": self.subject,
|
||||
"predicate": self.predicate,
|
||||
"object_value": self.object_value,
|
||||
"trigger": self.trigger,
|
||||
"success_rate": self.success_rate,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "MemoryContext":
|
||||
"""Create MemoryContext from dictionary."""
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data["source"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
memory_subtype=MemorySubtype(data.get("memory_subtype", "episodic")),
|
||||
memory_id=data.get("memory_id"),
|
||||
relevance_score=data.get("relevance_score", 0.0),
|
||||
importance=data.get("importance", 0.5),
|
||||
search_query=data.get("search_query", ""),
|
||||
key=data.get("key"),
|
||||
task_type=data.get("task_type"),
|
||||
outcome=data.get("outcome"),
|
||||
subject=data.get("subject"),
|
||||
predicate=data.get("predicate"),
|
||||
object_value=data.get("object_value"),
|
||||
trigger=data.get("trigger"),
|
||||
success_rate=data.get("success_rate"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_working_memory(
|
||||
cls,
|
||||
key: str,
|
||||
value: Any,
|
||||
source: str = "working_memory",
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from working memory entry.
|
||||
|
||||
Args:
|
||||
key: Working memory key
|
||||
value: Value stored at key
|
||||
source: Source identifier
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
return cls(
|
||||
content=str(value),
|
||||
source=source,
|
||||
memory_subtype=MemorySubtype.WORKING,
|
||||
key=key,
|
||||
relevance_score=1.0, # Working memory is always relevant
|
||||
importance=0.8, # Higher importance for current session state
|
||||
search_query=query,
|
||||
priority=ContextPriority.HIGH.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_episodic_memory(
|
||||
cls,
|
||||
episode: Any,
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from episodic memory episode.
|
||||
|
||||
Args:
|
||||
episode: Episode object from episodic memory
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
outcome_val = None
|
||||
if hasattr(episode, "outcome") and episode.outcome:
|
||||
outcome_val = (
|
||||
episode.outcome.value
|
||||
if hasattr(episode.outcome, "value")
|
||||
else str(episode.outcome)
|
||||
)
|
||||
|
||||
return cls(
|
||||
content=episode.task_description,
|
||||
source=f"episodic:{episode.id}",
|
||||
memory_subtype=MemorySubtype.EPISODIC,
|
||||
memory_id=str(episode.id),
|
||||
relevance_score=getattr(episode, "importance_score", 0.5),
|
||||
importance=getattr(episode, "importance_score", 0.5),
|
||||
search_query=query,
|
||||
task_type=getattr(episode, "task_type", None),
|
||||
outcome=outcome_val,
|
||||
metadata={
|
||||
"session_id": getattr(episode, "session_id", None),
|
||||
"occurred_at": episode.occurred_at.isoformat()
|
||||
if hasattr(episode, "occurred_at") and episode.occurred_at
|
||||
else None,
|
||||
"lessons_learned": getattr(episode, "lessons_learned", []),
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_semantic_memory(
|
||||
cls,
|
||||
fact: Any,
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from semantic memory fact.
|
||||
|
||||
Args:
|
||||
fact: Fact object from semantic memory
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
triple = f"{fact.subject} {fact.predicate} {fact.object}"
|
||||
return cls(
|
||||
content=triple,
|
||||
source=f"semantic:{fact.id}",
|
||||
memory_subtype=MemorySubtype.SEMANTIC,
|
||||
memory_id=str(fact.id),
|
||||
relevance_score=getattr(fact, "confidence", 0.5),
|
||||
importance=getattr(fact, "confidence", 0.5),
|
||||
search_query=query,
|
||||
subject=fact.subject,
|
||||
predicate=fact.predicate,
|
||||
object_value=fact.object,
|
||||
priority=ContextPriority.NORMAL.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_procedural_memory(
|
||||
cls,
|
||||
procedure: Any,
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from procedural memory procedure.
|
||||
|
||||
Args:
|
||||
procedure: Procedure object from procedural memory
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
# Format steps as content
|
||||
steps = getattr(procedure, "steps", [])
|
||||
steps_content = "\n".join(
|
||||
f" {i + 1}. {step.get('action', step) if isinstance(step, dict) else step}"
|
||||
for i, step in enumerate(steps)
|
||||
)
|
||||
content = f"Procedure: {procedure.name}\nTrigger: {procedure.trigger_pattern}\nSteps:\n{steps_content}"
|
||||
|
||||
return cls(
|
||||
content=content,
|
||||
source=f"procedural:{procedure.id}",
|
||||
memory_subtype=MemorySubtype.PROCEDURAL,
|
||||
memory_id=str(procedure.id),
|
||||
relevance_score=getattr(procedure, "success_rate", 0.5),
|
||||
importance=0.7, # Procedures are moderately important
|
||||
search_query=query,
|
||||
trigger=procedure.trigger_pattern,
|
||||
success_rate=getattr(procedure, "success_rate", None),
|
||||
metadata={
|
||||
"steps_count": len(steps),
|
||||
"execution_count": getattr(procedure, "success_count", 0)
|
||||
+ getattr(procedure, "failure_count", 0),
|
||||
},
|
||||
)
|
||||
|
||||
def is_working_memory(self) -> bool:
|
||||
"""Check if this is working memory."""
|
||||
return self.memory_subtype == MemorySubtype.WORKING
|
||||
|
||||
def is_episodic_memory(self) -> bool:
|
||||
"""Check if this is episodic memory."""
|
||||
return self.memory_subtype == MemorySubtype.EPISODIC
|
||||
|
||||
def is_semantic_memory(self) -> bool:
|
||||
"""Check if this is semantic memory."""
|
||||
return self.memory_subtype == MemorySubtype.SEMANTIC
|
||||
|
||||
def is_procedural_memory(self) -> bool:
|
||||
"""Check if this is procedural memory."""
|
||||
return self.memory_subtype == MemorySubtype.PROCEDURAL
|
||||
|
||||
def get_formatted_source(self) -> str:
|
||||
"""
|
||||
Get a formatted source string for display.
|
||||
|
||||
Returns:
|
||||
Formatted source string
|
||||
"""
|
||||
parts = [f"[{self.memory_subtype.value}]", self.source]
|
||||
if self.memory_id:
|
||||
parts.append(f"({self.memory_id[:8]}...)")
|
||||
return " ".join(parts)
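
The factory classmethods above all normalize raw memory hits into one MemoryContext shape. A minimal usage sketch, not part of the diff: the import path and the SimpleNamespace stand-in for a semantic Fact are assumptions; only the attributes the factory actually reads are provided.

# Hypothetical usage sketch; the module path for MemoryContext is not shown in this hunk.
from types import SimpleNamespace
from uuid import uuid4

from app.services.context.memory_context import MemoryContext  # assumed path

# Stand-in for a semantic-memory Fact row (id, subject, predicate, object, confidence).
fact = SimpleNamespace(
    id=uuid4(),
    subject="backend",
    predicate="uses",
    object="FastAPI",
    confidence=0.9,
)

ctx = MemoryContext.from_semantic_memory(fact, query="what framework does the backend use?")
assert ctx.is_semantic_memory()
print(ctx.get_formatted_source())  # e.g. "[semantic] semantic:<id> (<id prefix>...)"
print(ctx.content)                 # "backend uses FastAPI"
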
@@ -122,16 +122,24 @@ class MCPClientManager:
        )

    async def _connect_all_servers(self) -> None:
        """Connect to all enabled MCP servers."""
        """Connect to all enabled MCP servers concurrently."""
        import asyncio

        enabled_servers = self._registry.get_enabled_configs()

        for name, config in enabled_servers.items():
        async def connect_server(name: str, config: "MCPServerConfig") -> None:
            try:
                await self._pool.get_connection(name, config)
                logger.info("Connected to MCP server: %s", name)
            except Exception as e:
                logger.error("Failed to connect to MCP server %s: %s", name, e)

        # Connect to all servers concurrently for faster startup
        await asyncio.gather(
            *(connect_server(name, config) for name, config in enabled_servers.items()),
            return_exceptions=True,
        )

    async def shutdown(self) -> None:
        """
        Shutdown the MCP client manager.
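
The change above swaps the sequential connect loop for asyncio.gather so a slow or unreachable server no longer delays startup. A self-contained sketch of the same pattern (illustrative names, not the project's classes):

import asyncio
import logging
from collections.abc import Awaitable, Callable

logger = logging.getLogger(__name__)


async def _connect_one(name: str, connect: Callable[[], Awaitable[None]]) -> None:
    # Failures are logged, not raised, so one unreachable server cannot
    # block the others from connecting.
    try:
        await connect()
        logger.info("Connected to MCP server: %s", name)
    except Exception as exc:
        logger.error("Failed to connect to MCP server %s: %s", name, exc)


async def connect_all(connectors: dict[str, Callable[[], Awaitable[None]]]) -> None:
    # All connections are attempted concurrently; return_exceptions=True keeps
    # gather itself from raising if a wrapper ever leaks an exception.
    await asyncio.gather(
        *(_connect_one(name, fn) for name, fn in connectors.items()),
        return_exceptions=True,
    )


async def _demo() -> None:
    async def ok() -> None:
        pass

    async def bad() -> None:
        raise ConnectionError("connection refused")

    await connect_all({"llm-gateway": ok, "knowledge-base": bad})


asyncio.run(_demo())
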
@@ -179,6 +179,8 @@ def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
    2. MCP_CONFIG_PATH environment variable
    3. Default path (backend/mcp_servers.yaml)
    4. Empty config if no file exists

    In test mode (IS_TEST=True), retry settings are reduced for faster tests.
    """
    if path is None:
        path = os.environ.get("MCP_CONFIG_PATH", str(DEFAULT_CONFIG_PATH))
@@ -189,7 +191,18 @@ def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
        # Return empty config if no file exists (allows runtime registration)
        return MCPConfig()

    return MCPConfig.from_yaml(path)
    config = MCPConfig.from_yaml(path)

    # In test mode, reduce retry settings to speed up tests
    is_test = os.environ.get("IS_TEST", "").lower() in ("true", "1", "yes")
    if is_test:
        for server_config in config.mcp_servers.values():
            server_config.retry_attempts = 1  # Single attempt
            server_config.retry_delay = 0.1  # 100ms instead of 1s
            server_config.retry_max_delay = 0.5  # 500ms max
            server_config.timeout = 2  # 2s timeout instead of 30-120s

    return config


def create_default_config() -> MCPConfig:
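
The new branch above mutates the loaded config in place when IS_TEST is set, trading retry robustness for fast test runs. A standalone sketch of the same override pattern, using hypothetical dataclasses rather than the real MCPConfig:

import os
from dataclasses import dataclass, field


@dataclass
class ServerConfig:
    retry_attempts: int = 3
    retry_delay: float = 1.0
    retry_max_delay: float = 30.0
    timeout: int = 120


@dataclass
class Config:
    mcp_servers: dict[str, ServerConfig] = field(default_factory=dict)


def apply_test_overrides(config: Config) -> Config:
    # Mirrors the IS_TEST branch above: a single, fast retry attempt in test runs.
    if os.environ.get("IS_TEST", "").lower() in ("true", "1", "yes"):
        for server in config.mcp_servers.values():
            server.retry_attempts = 1
            server.retry_delay = 0.1
            server.retry_max_delay = 0.5
            server.timeout = 2
    return config


cfg = apply_test_overrides(Config(mcp_servers={"llm-gateway": ServerConfig()}))
print(cfg.mcp_servers["llm-gateway"].retry_attempts)
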
141 backend/app/services/memory/__init__.py Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Agent Memory System
|
||||
|
||||
Multi-tier cognitive memory for AI agents, providing:
|
||||
- Working Memory: Session-scoped ephemeral state (Redis/In-memory)
|
||||
- Episodic Memory: Experiential records of past tasks (PostgreSQL)
|
||||
- Semantic Memory: Learned facts and knowledge (PostgreSQL + pgvector)
|
||||
- Procedural Memory: Learned skills and procedures (PostgreSQL)
|
||||
|
||||
Usage:
|
||||
from app.services.memory import (
|
||||
MemoryManager,
|
||||
MemorySettings,
|
||||
get_memory_settings,
|
||||
MemoryType,
|
||||
ScopeLevel,
|
||||
)
|
||||
|
||||
# Create a manager for a session
|
||||
manager = MemoryManager.for_session(
|
||||
session_id="sess-123",
|
||||
project_id=uuid,
|
||||
)
|
||||
|
||||
async with manager:
|
||||
# Working memory
|
||||
await manager.set_working("key", {"data": "value"})
|
||||
value = await manager.get_working("key")
|
||||
|
||||
# Episodic memory
|
||||
episode = await manager.record_episode(episode_data)
|
||||
similar = await manager.search_episodes("query")
|
||||
|
||||
# Semantic memory
|
||||
fact = await manager.store_fact(fact_data)
|
||||
facts = await manager.search_facts("query")
|
||||
|
||||
# Procedural memory
|
||||
procedure = await manager.record_procedure(procedure_data)
|
||||
procedures = await manager.find_procedures("context")
|
||||
"""
|
||||
|
||||
# Configuration
|
||||
from .config import (
|
||||
MemorySettings,
|
||||
get_default_settings,
|
||||
get_memory_settings,
|
||||
reset_memory_settings,
|
||||
)
|
||||
|
||||
# Exceptions
|
||||
from .exceptions import (
|
||||
CheckpointError,
|
||||
EmbeddingError,
|
||||
MemoryCapacityError,
|
||||
MemoryConflictError,
|
||||
MemoryConsolidationError,
|
||||
MemoryError,
|
||||
MemoryExpiredError,
|
||||
MemoryNotFoundError,
|
||||
MemoryRetrievalError,
|
||||
MemoryScopeError,
|
||||
MemorySerializationError,
|
||||
MemoryStorageError,
|
||||
)
|
||||
|
||||
# Manager
|
||||
from .manager import MemoryManager
|
||||
|
||||
# Types
|
||||
from .types import (
|
||||
ConsolidationStatus,
|
||||
ConsolidationType,
|
||||
Episode,
|
||||
EpisodeCreate,
|
||||
Fact,
|
||||
FactCreate,
|
||||
MemoryItem,
|
||||
MemoryStats,
|
||||
MemoryStore,
|
||||
MemoryType,
|
||||
Outcome,
|
||||
Procedure,
|
||||
ProcedureCreate,
|
||||
RetrievalResult,
|
||||
ScopeContext,
|
||||
ScopeLevel,
|
||||
Step,
|
||||
TaskState,
|
||||
WorkingMemoryItem,
|
||||
)
|
||||
|
||||
# Reflection (lazy import available)
|
||||
# Import directly: from app.services.memory.reflection import MemoryReflection
|
||||
|
||||
__all__ = [
|
||||
"CheckpointError",
|
||||
"ConsolidationStatus",
|
||||
"ConsolidationType",
|
||||
"EmbeddingError",
|
||||
"Episode",
|
||||
"EpisodeCreate",
|
||||
"Fact",
|
||||
"FactCreate",
|
||||
"MemoryCapacityError",
|
||||
"MemoryConflictError",
|
||||
"MemoryConsolidationError",
|
||||
# Exceptions
|
||||
"MemoryError",
|
||||
"MemoryExpiredError",
|
||||
"MemoryItem",
|
||||
# Manager
|
||||
"MemoryManager",
|
||||
"MemoryNotFoundError",
|
||||
"MemoryRetrievalError",
|
||||
"MemoryScopeError",
|
||||
"MemorySerializationError",
|
||||
# Configuration
|
||||
"MemorySettings",
|
||||
"MemoryStats",
|
||||
"MemoryStorageError",
|
||||
# Types - Abstract
|
||||
"MemoryStore",
|
||||
# Types - Enums
|
||||
"MemoryType",
|
||||
"Outcome",
|
||||
"Procedure",
|
||||
"ProcedureCreate",
|
||||
"RetrievalResult",
|
||||
# Types - Data Classes
|
||||
"ScopeContext",
|
||||
"ScopeLevel",
|
||||
"Step",
|
||||
"TaskState",
|
||||
"WorkingMemoryItem",
|
||||
"get_default_settings",
|
||||
"get_memory_settings",
|
||||
"reset_memory_settings",
|
||||
# MCP Tools - lazy import to avoid circular dependencies
|
||||
# Import directly: from app.services.memory.mcp import MemoryToolService
|
||||
]
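
As the comments note, reflection and the MCP tool service are deliberately left out of the package exports and imported lazily at the call site. A hedged sketch of that pattern (import paths assume the backend package root maps to app):

from app.services.memory import MemoryManager, get_memory_settings  # re-exported above


def build_reflection():
    # Deferred import, as documented in the module comments: keeps
    # app.services.memory importable without pulling in the MCP layer.
    from app.services.memory.reflection import MemoryReflection

    return MemoryReflection
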
21 backend/app/services/memory/cache/__init__.py vendored Normal file
@@ -0,0 +1,21 @@
# app/services/memory/cache/__init__.py
"""
Memory Caching Layer.

Provides caching for memory operations:
- Hot Memory Cache: LRU cache for frequently accessed memories
- Embedding Cache: Cache embeddings by content hash
- Cache Manager: Unified cache management with invalidation
"""

from .cache_manager import CacheManager, CacheStats, get_cache_manager
from .embedding_cache import EmbeddingCache
from .hot_cache import HotMemoryCache

__all__ = [
    "CacheManager",
    "CacheStats",
    "EmbeddingCache",
    "HotMemoryCache",
    "get_cache_manager",
]
505 backend/app/services/memory/cache/cache_manager.py vendored Normal file
@@ -0,0 +1,505 @@
|
||||
# app/services/memory/cache/cache_manager.py
|
||||
"""
|
||||
Cache Manager.
|
||||
|
||||
Unified cache management for memory operations.
|
||||
Coordinates hot cache, embedding cache, and retrieval cache.
|
||||
Provides centralized invalidation and statistics.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.config import get_memory_settings
|
||||
|
||||
from .embedding_cache import EmbeddingCache, create_embedding_cache
|
||||
from .hot_cache import CacheKey, HotMemoryCache, create_hot_cache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.services.memory.indexing.retrieval import RetrievalCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
"""Aggregated cache statistics."""
|
||||
|
||||
hot_cache: dict[str, Any] = field(default_factory=dict)
|
||||
embedding_cache: dict[str, Any] = field(default_factory=dict)
|
||||
retrieval_cache: dict[str, Any] = field(default_factory=dict)
|
||||
overall_hit_rate: float = 0.0
|
||||
last_cleanup: datetime | None = None
|
||||
cleanup_count: int = 0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"hot_cache": self.hot_cache,
|
||||
"embedding_cache": self.embedding_cache,
|
||||
"retrieval_cache": self.retrieval_cache,
|
||||
"overall_hit_rate": self.overall_hit_rate,
|
||||
"last_cleanup": self.last_cleanup.isoformat()
|
||||
if self.last_cleanup
|
||||
else None,
|
||||
"cleanup_count": self.cleanup_count,
|
||||
}
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""
|
||||
Unified cache manager for memory operations.
|
||||
|
||||
Provides:
|
||||
- Centralized cache configuration
|
||||
- Coordinated invalidation across caches
|
||||
- Aggregated statistics
|
||||
- Automatic cleanup scheduling
|
||||
|
||||
Performance targets:
|
||||
- Overall cache hit rate > 80%
|
||||
- Cache operations < 1ms (memory), < 5ms (Redis)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hot_cache: HotMemoryCache[Any] | None = None,
|
||||
embedding_cache: EmbeddingCache | None = None,
|
||||
retrieval_cache: "RetrievalCache | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the cache manager.
|
||||
|
||||
Args:
|
||||
hot_cache: Optional pre-configured hot cache
|
||||
embedding_cache: Optional pre-configured embedding cache
|
||||
retrieval_cache: Optional pre-configured retrieval cache
|
||||
redis: Optional Redis connection for persistence
|
||||
"""
|
||||
self._settings = get_memory_settings()
|
||||
self._redis = redis
|
||||
self._enabled = self._settings.cache_enabled
|
||||
|
||||
# Initialize caches
|
||||
if hot_cache:
|
||||
self._hot_cache = hot_cache
|
||||
else:
|
||||
self._hot_cache = create_hot_cache(
|
||||
max_size=self._settings.cache_max_items,
|
||||
default_ttl_seconds=self._settings.cache_ttl_seconds,
|
||||
)
|
||||
|
||||
if embedding_cache:
|
||||
self._embedding_cache = embedding_cache
|
||||
else:
|
||||
self._embedding_cache = create_embedding_cache(
|
||||
max_size=self._settings.cache_max_items,
|
||||
default_ttl_seconds=self._settings.cache_ttl_seconds
|
||||
* 12, # 1hr for embeddings
|
||||
redis=redis,
|
||||
)
|
||||
|
||||
self._retrieval_cache = retrieval_cache
|
||||
|
||||
# Stats tracking
|
||||
self._last_cleanup: datetime | None = None
|
||||
self._cleanup_count = 0
|
||||
self._lock = threading.RLock()
|
||||
|
||||
logger.info(
|
||||
f"Initialized CacheManager: enabled={self._enabled}, "
|
||||
f"redis={'connected' if redis else 'disabled'}"
|
||||
)
|
||||
|
||||
def set_redis(self, redis: "Redis") -> None:
|
||||
"""Set Redis connection for all caches."""
|
||||
self._redis = redis
|
||||
self._embedding_cache.set_redis(redis)
|
||||
|
||||
def set_retrieval_cache(self, cache: "RetrievalCache") -> None:
|
||||
"""Set retrieval cache instance."""
|
||||
self._retrieval_cache = cache
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if caching is enabled."""
|
||||
return self._enabled
|
||||
|
||||
@property
|
||||
def hot_cache(self) -> HotMemoryCache[Any]:
|
||||
"""Get the hot memory cache."""
|
||||
return self._hot_cache
|
||||
|
||||
@property
|
||||
def embedding_cache(self) -> EmbeddingCache:
|
||||
"""Get the embedding cache."""
|
||||
return self._embedding_cache
|
||||
|
||||
@property
|
||||
def retrieval_cache(self) -> "RetrievalCache | None":
|
||||
"""Get the retrieval cache."""
|
||||
return self._retrieval_cache
|
||||
|
||||
# =========================================================================
|
||||
# Hot Memory Cache Operations
|
||||
# =========================================================================
|
||||
|
||||
def get_memory(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
scope: str | None = None,
|
||||
) -> Any | None:
|
||||
"""
|
||||
Get a memory from hot cache.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
scope: Optional scope
|
||||
|
||||
Returns:
|
||||
Cached memory or None
|
||||
"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
return self._hot_cache.get_by_id(memory_type, memory_id, scope)
|
||||
|
||||
def cache_memory(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
memory: Any,
|
||||
scope: str | None = None,
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache a memory in hot cache.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
memory: Memory object
|
||||
scope: Optional scope
|
||||
ttl_seconds: Optional TTL override
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
self._hot_cache.put_by_id(memory_type, memory_id, memory, scope, ttl_seconds)
|
||||
|
||||
# =========================================================================
|
||||
# Embedding Cache Operations
|
||||
# =========================================================================
|
||||
|
||||
async def get_embedding(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "default",
|
||||
) -> list[float] | None:
|
||||
"""
|
||||
Get a cached embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Cached embedding or None
|
||||
"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
return await self._embedding_cache.get(content, model)
|
||||
|
||||
async def cache_embedding(
|
||||
self,
|
||||
content: str,
|
||||
embedding: list[float],
|
||||
model: str = "default",
|
||||
ttl_seconds: float | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Cache an embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
embedding: Embedding vector
|
||||
model: Model name
|
||||
ttl_seconds: Optional TTL override
|
||||
|
||||
Returns:
|
||||
Content hash
|
||||
"""
|
||||
if not self._enabled:
|
||||
return EmbeddingCache.hash_content(content)
|
||||
return await self._embedding_cache.put(content, embedding, model, ttl_seconds)
|
||||
|
||||
# =========================================================================
|
||||
# Invalidation
|
||||
# =========================================================================
|
||||
|
||||
async def invalidate_memory(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
scope: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Invalidate a memory across all caches.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
scope: Optional scope
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
count = 0
|
||||
|
||||
# Invalidate hot cache
|
||||
if self._hot_cache.invalidate_by_id(memory_type, memory_id, scope):
|
||||
count += 1
|
||||
|
||||
# Invalidate retrieval cache
|
||||
if self._retrieval_cache:
|
||||
uuid_id = (
|
||||
UUID(str(memory_id)) if not isinstance(memory_id, UUID) else memory_id
|
||||
)
|
||||
count += self._retrieval_cache.invalidate_by_memory(uuid_id)
|
||||
|
||||
logger.debug(f"Invalidated {count} cache entries for {memory_type}:{memory_id}")
|
||||
return count
|
||||
|
||||
async def invalidate_by_type(self, memory_type: str) -> int:
|
||||
"""
|
||||
Invalidate all entries of a memory type.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
count = self._hot_cache.invalidate_by_type(memory_type)
|
||||
|
||||
if self._retrieval_cache:
|
||||
count += self._retrieval_cache.clear()
|
||||
|
||||
logger.info(f"Invalidated {count} cache entries for type {memory_type}")
|
||||
return count
|
||||
|
||||
async def invalidate_by_scope(self, scope: str) -> int:
|
||||
"""
|
||||
Invalidate all entries in a scope.
|
||||
|
||||
Args:
|
||||
scope: Scope to invalidate (e.g., project_id)
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
count = self._hot_cache.invalidate_by_scope(scope)
|
||||
|
||||
# Retrieval cache doesn't support scope-based invalidation
|
||||
# so we clear it entirely for safety
|
||||
if self._retrieval_cache:
|
||||
count += self._retrieval_cache.clear()
|
||||
|
||||
logger.info(f"Invalidated {count} cache entries for scope {scope}")
|
||||
return count
|
||||
|
||||
async def invalidate_embedding(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "default",
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate a cached embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
return await self._embedding_cache.invalidate(content, model)
|
||||
|
||||
async def clear_all(self) -> int:
|
||||
"""
|
||||
Clear all caches.
|
||||
|
||||
Returns:
|
||||
Total number of entries cleared
|
||||
"""
|
||||
count = 0
|
||||
|
||||
count += self._hot_cache.clear()
|
||||
count += await self._embedding_cache.clear()
|
||||
|
||||
if self._retrieval_cache:
|
||||
count += self._retrieval_cache.clear()
|
||||
|
||||
logger.info(f"Cleared {count} entries from all caches")
|
||||
return count
|
||||
|
||||
# =========================================================================
|
||||
# Cleanup
|
||||
# =========================================================================
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Clean up expired entries from all caches.
|
||||
|
||||
Returns:
|
||||
Number of entries cleaned up
|
||||
"""
|
||||
with self._lock:
|
||||
count = 0
|
||||
|
||||
count += self._hot_cache.cleanup_expired()
|
||||
count += self._embedding_cache.cleanup_expired()
|
||||
|
||||
# Retrieval cache doesn't have a cleanup method,
|
||||
# but entries expire on access
|
||||
|
||||
self._last_cleanup = _utcnow()
|
||||
self._cleanup_count += 1
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired cache entries")
|
||||
|
||||
return count
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
def get_stats(self) -> CacheStats:
|
||||
"""
|
||||
Get aggregated cache statistics.
|
||||
|
||||
Returns:
|
||||
CacheStats with all cache metrics
|
||||
"""
|
||||
hot_stats = self._hot_cache.get_stats().to_dict()
|
||||
emb_stats = self._embedding_cache.get_stats().to_dict()
|
||||
|
||||
retrieval_stats: dict[str, Any] = {}
|
||||
if self._retrieval_cache:
|
||||
retrieval_stats = self._retrieval_cache.get_stats()
|
||||
|
||||
# Calculate overall hit rate
|
||||
total_hits = hot_stats.get("hits", 0) + emb_stats.get("hits", 0)
|
||||
total_misses = hot_stats.get("misses", 0) + emb_stats.get("misses", 0)
|
||||
|
||||
if retrieval_stats:
|
||||
# Retrieval cache doesn't track hits/misses the same way
|
||||
pass
|
||||
|
||||
total_requests = total_hits + total_misses
|
||||
overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0
|
||||
|
||||
return CacheStats(
|
||||
hot_cache=hot_stats,
|
||||
embedding_cache=emb_stats,
|
||||
retrieval_cache=retrieval_stats,
|
||||
overall_hit_rate=overall_hit_rate,
|
||||
last_cleanup=self._last_cleanup,
|
||||
cleanup_count=self._cleanup_count,
|
||||
)
|
||||
|
||||
def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
|
||||
"""
|
||||
Get the most frequently accessed memories.
|
||||
|
||||
Args:
|
||||
limit: Maximum number to return
|
||||
|
||||
Returns:
|
||||
List of (key, access_count) tuples
|
||||
"""
|
||||
return self._hot_cache.get_hot_memories(limit)
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset all cache statistics."""
|
||||
self._hot_cache.reset_stats()
|
||||
self._embedding_cache.reset_stats()
|
||||
|
||||
# =========================================================================
|
||||
# Warmup
|
||||
# =========================================================================
|
||||
|
||||
async def warmup(
|
||||
self,
|
||||
memories: list[tuple[str, UUID | str, Any]],
|
||||
scope: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Warm up the hot cache with memories.
|
||||
|
||||
Args:
|
||||
memories: List of (memory_type, memory_id, memory) tuples
|
||||
scope: Optional scope for all memories
|
||||
|
||||
Returns:
|
||||
Number of memories cached
|
||||
"""
|
||||
if not self._enabled:
|
||||
return 0
|
||||
|
||||
for memory_type, memory_id, memory in memories:
|
||||
self._hot_cache.put_by_id(memory_type, memory_id, memory, scope)
|
||||
|
||||
logger.info(f"Warmed up cache with {len(memories)} memories")
|
||||
return len(memories)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_cache_manager: CacheManager | None = None
|
||||
_cache_manager_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_cache_manager(
|
||||
redis: "Redis | None" = None,
|
||||
reset: bool = False,
|
||||
) -> CacheManager:
|
||||
"""
|
||||
Get the global CacheManager instance.
|
||||
|
||||
Thread-safe with double-checked locking pattern.
|
||||
|
||||
Args:
|
||||
redis: Optional Redis connection
|
||||
reset: Force create a new instance
|
||||
|
||||
Returns:
|
||||
CacheManager instance
|
||||
"""
|
||||
global _cache_manager
|
||||
|
||||
if reset or _cache_manager is None:
|
||||
with _cache_manager_lock:
|
||||
if reset or _cache_manager is None:
|
||||
_cache_manager = CacheManager(redis=redis)
|
||||
|
||||
return _cache_manager
|
||||
|
||||
|
||||
def reset_cache_manager() -> None:
|
||||
"""Reset the global cache manager instance."""
|
||||
global _cache_manager
|
||||
with _cache_manager_lock:
|
||||
_cache_manager = None
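
Putting the pieces together, get_cache_manager() returns a process-wide CacheManager that fronts the hot, embedding, and retrieval caches. A usage sketch under the assumption that the backend package root maps to app and default MemorySettings apply (no Redis wired in):

import asyncio
from uuid import uuid4

from app.services.memory.cache import get_cache_manager  # assumed import root


async def demo() -> None:
    manager = get_cache_manager()  # process-wide singleton, memory-only here
    episode_id = uuid4()

    # Hot-cache round trip for a single memory object.
    manager.cache_memory("episodic", episode_id, {"task": "index repo"}, scope="proj-1")
    print(manager.get_memory("episodic", episode_id, scope="proj-1"))  # hit

    # Coordinated invalidation plus aggregated stats.
    removed = await manager.invalidate_memory("episodic", episode_id, scope="proj-1")
    print(removed, manager.get_stats().overall_hit_rate)


asyncio.run(demo())
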
623 backend/app/services/memory/cache/embedding_cache.py vendored Normal file
@@ -0,0 +1,623 @@
|
||||
# app/services/memory/cache/embedding_cache.py
|
||||
"""
|
||||
Embedding Cache.
|
||||
|
||||
Caches embeddings by content hash to avoid recomputing.
|
||||
Provides significant performance improvement for repeated content.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingEntry:
|
||||
"""A cached embedding entry."""
|
||||
|
||||
embedding: list[float]
|
||||
content_hash: str
|
||||
model: str
|
||||
created_at: datetime
|
||||
ttl_seconds: float = 3600.0 # 1 hour default
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this entry has expired."""
|
||||
age = (_utcnow() - self.created_at).total_seconds()
|
||||
return age > self.ttl_seconds
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingCacheStats:
|
||||
"""Statistics for the embedding cache."""
|
||||
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
evictions: int = 0
|
||||
expirations: int = 0
|
||||
current_size: int = 0
|
||||
max_size: int = 0
|
||||
bytes_saved: int = 0 # Estimated bytes saved by caching
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate."""
|
||||
total = self.hits + self.misses
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return self.hits / total
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"evictions": self.evictions,
|
||||
"expirations": self.expirations,
|
||||
"current_size": self.current_size,
|
||||
"max_size": self.max_size,
|
||||
"hit_rate": self.hit_rate,
|
||||
"bytes_saved": self.bytes_saved,
|
||||
}
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""
|
||||
Cache for embeddings by content hash.
|
||||
|
||||
Features:
|
||||
- Content-hash based deduplication
|
||||
- LRU eviction
|
||||
- TTL-based expiration
|
||||
- Optional Redis backing for persistence
|
||||
- Thread-safe operations
|
||||
|
||||
Performance targets:
|
||||
- Cache hit rate > 90% for repeated content
|
||||
- Get/put operations < 1ms (memory), < 5ms (Redis)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int = 50000,
|
||||
default_ttl_seconds: float = 3600.0,
|
||||
redis: "Redis | None" = None,
|
||||
redis_prefix: str = "mem:emb",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the embedding cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries in memory cache
|
||||
default_ttl_seconds: Default TTL for entries (1 hour)
|
||||
redis: Optional Redis connection for persistence
|
||||
redis_prefix: Prefix for Redis keys
|
||||
"""
|
||||
self._max_size = max_size
|
||||
self._default_ttl = default_ttl_seconds
|
||||
self._cache: OrderedDict[str, EmbeddingEntry] = OrderedDict()
|
||||
self._lock = threading.RLock()
|
||||
self._stats = EmbeddingCacheStats(max_size=max_size)
|
||||
self._redis = redis
|
||||
self._redis_prefix = redis_prefix
|
||||
|
||||
logger.info(
|
||||
f"Initialized EmbeddingCache with max_size={max_size}, "
|
||||
f"ttl={default_ttl_seconds}s, redis={'enabled' if redis else 'disabled'}"
|
||||
)
|
||||
|
||||
def set_redis(self, redis: "Redis") -> None:
|
||||
"""Set Redis connection for persistence."""
|
||||
self._redis = redis
|
||||
|
||||
@staticmethod
|
||||
def hash_content(content: str) -> str:
|
||||
"""
|
||||
Compute hash of content for cache key.
|
||||
|
||||
Args:
|
||||
content: Content to hash
|
||||
|
||||
Returns:
|
||||
32-character hex hash
|
||||
"""
|
||||
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||
|
||||
def _cache_key(self, content_hash: str, model: str) -> str:
|
||||
"""Build cache key from content hash and model."""
|
||||
return f"{content_hash}:{model}"
|
||||
|
||||
def _redis_key(self, content_hash: str, model: str) -> str:
|
||||
"""Build Redis key from content hash and model."""
|
||||
return f"{self._redis_prefix}:{content_hash}:{model}"
|
||||
|
||||
async def get(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "default",
|
||||
) -> list[float] | None:
|
||||
"""
|
||||
Get a cached embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Cached embedding or None if not found/expired
|
||||
"""
|
||||
content_hash = self.hash_content(content)
|
||||
cache_key = self._cache_key(content_hash, model)
|
||||
|
||||
# Check memory cache first
|
||||
with self._lock:
|
||||
if cache_key in self._cache:
|
||||
entry = self._cache[cache_key]
|
||||
if entry.is_expired():
|
||||
del self._cache[cache_key]
|
||||
self._stats.expirations += 1
|
||||
self._stats.current_size = len(self._cache)
|
||||
else:
|
||||
# Move to end (most recently used)
|
||||
self._cache.move_to_end(cache_key)
|
||||
self._stats.hits += 1
|
||||
return entry.embedding
|
||||
|
||||
# Check Redis if available
|
||||
if self._redis:
|
||||
try:
|
||||
redis_key = self._redis_key(content_hash, model)
|
||||
data = await self._redis.get(redis_key)
|
||||
if data:
|
||||
import json
|
||||
|
||||
embedding = json.loads(data)
|
||||
# Store in memory cache for faster access
|
||||
self._put_memory(content_hash, model, embedding)
|
||||
self._stats.hits += 1
|
||||
return embedding
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis get error: {e}")
|
||||
|
||||
self._stats.misses += 1
|
||||
return None
|
||||
|
||||
async def get_by_hash(
|
||||
self,
|
||||
content_hash: str,
|
||||
model: str = "default",
|
||||
) -> list[float] | None:
|
||||
"""
|
||||
Get a cached embedding by hash.
|
||||
|
||||
Args:
|
||||
content_hash: Content hash
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Cached embedding or None if not found/expired
|
||||
"""
|
||||
cache_key = self._cache_key(content_hash, model)
|
||||
|
||||
with self._lock:
|
||||
if cache_key in self._cache:
|
||||
entry = self._cache[cache_key]
|
||||
if entry.is_expired():
|
||||
del self._cache[cache_key]
|
||||
self._stats.expirations += 1
|
||||
self._stats.current_size = len(self._cache)
|
||||
else:
|
||||
self._cache.move_to_end(cache_key)
|
||||
self._stats.hits += 1
|
||||
return entry.embedding
|
||||
|
||||
# Check Redis
|
||||
if self._redis:
|
||||
try:
|
||||
redis_key = self._redis_key(content_hash, model)
|
||||
data = await self._redis.get(redis_key)
|
||||
if data:
|
||||
import json
|
||||
|
||||
embedding = json.loads(data)
|
||||
self._put_memory(content_hash, model, embedding)
|
||||
self._stats.hits += 1
|
||||
return embedding
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis get error: {e}")
|
||||
|
||||
self._stats.misses += 1
|
||||
return None
|
||||
|
||||
async def put(
|
||||
self,
|
||||
content: str,
|
||||
embedding: list[float],
|
||||
model: str = "default",
|
||||
ttl_seconds: float | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Cache an embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
embedding: Embedding vector
|
||||
model: Model name
|
||||
ttl_seconds: Optional TTL override
|
||||
|
||||
Returns:
|
||||
Content hash
|
||||
"""
|
||||
content_hash = self.hash_content(content)
|
||||
ttl = ttl_seconds or self._default_ttl
|
||||
|
||||
# Store in memory
|
||||
self._put_memory(content_hash, model, embedding, ttl)
|
||||
|
||||
# Store in Redis if available
|
||||
if self._redis:
|
||||
try:
|
||||
import json
|
||||
|
||||
redis_key = self._redis_key(content_hash, model)
|
||||
await self._redis.setex(
|
||||
redis_key,
|
||||
int(ttl),
|
||||
json.dumps(embedding),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis put error: {e}")
|
||||
|
||||
return content_hash
|
||||
|
||||
def _put_memory(
|
||||
self,
|
||||
content_hash: str,
|
||||
model: str,
|
||||
embedding: list[float],
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""Store in memory cache."""
|
||||
with self._lock:
|
||||
# Evict if at capacity
|
||||
self._evict_if_needed()
|
||||
|
||||
cache_key = self._cache_key(content_hash, model)
|
||||
entry = EmbeddingEntry(
|
||||
embedding=embedding,
|
||||
content_hash=content_hash,
|
||||
model=model,
|
||||
created_at=_utcnow(),
|
||||
ttl_seconds=ttl_seconds or self._default_ttl,
|
||||
)
|
||||
|
||||
self._cache[cache_key] = entry
|
||||
self._cache.move_to_end(cache_key)
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict entries if cache is at capacity."""
|
||||
while len(self._cache) >= self._max_size:
|
||||
if self._cache:
|
||||
self._cache.popitem(last=False)
|
||||
self._stats.evictions += 1
|
||||
|
||||
async def put_batch(
|
||||
self,
|
||||
items: list[tuple[str, list[float]]],
|
||||
model: str = "default",
|
||||
ttl_seconds: float | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Cache multiple embeddings.
|
||||
|
||||
Args:
|
||||
items: List of (content, embedding) tuples
|
||||
model: Model name
|
||||
ttl_seconds: Optional TTL override
|
||||
|
||||
Returns:
|
||||
List of content hashes
|
||||
"""
|
||||
hashes = []
|
||||
for content, embedding in items:
|
||||
content_hash = await self.put(content, embedding, model, ttl_seconds)
|
||||
hashes.append(content_hash)
|
||||
return hashes
|
||||
|
||||
async def invalidate(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "default",
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate a cached embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
content_hash = self.hash_content(content)
|
||||
return await self.invalidate_by_hash(content_hash, model)
|
||||
|
||||
async def invalidate_by_hash(
|
||||
self,
|
||||
content_hash: str,
|
||||
model: str = "default",
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate a cached embedding by hash.
|
||||
|
||||
Args:
|
||||
content_hash: Content hash
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
cache_key = self._cache_key(content_hash, model)
|
||||
removed = False
|
||||
|
||||
with self._lock:
|
||||
if cache_key in self._cache:
|
||||
del self._cache[cache_key]
|
||||
self._stats.current_size = len(self._cache)
|
||||
removed = True
|
||||
|
||||
# Remove from Redis
|
||||
if self._redis:
|
||||
try:
|
||||
redis_key = self._redis_key(content_hash, model)
|
||||
await self._redis.delete(redis_key)
|
||||
removed = True
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis delete error: {e}")
|
||||
|
||||
return removed
|
||||
|
||||
async def invalidate_by_model(self, model: str) -> int:
|
||||
"""
|
||||
Invalidate all embeddings for a model.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
count = 0
|
||||
|
||||
with self._lock:
|
||||
keys_to_remove = [k for k, v in self._cache.items() if v.model == model]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
count += 1
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
# Note: Redis pattern deletion would require SCAN which is expensive
|
||||
# For now, we only clear memory cache for model-based invalidation
|
||||
|
||||
return count
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""
|
||||
Clear all cache entries.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
with self._lock:
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._stats.current_size = 0
|
||||
|
||||
# Clear Redis entries
|
||||
if self._redis:
|
||||
try:
|
||||
pattern = f"{self._redis_prefix}:*"
|
||||
deleted = 0
|
||||
async for key in self._redis.scan_iter(match=pattern):
|
||||
await self._redis.delete(key)
|
||||
deleted += 1
|
||||
count = max(count, deleted)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis clear error: {e}")
|
||||
|
||||
logger.info(f"Cleared {count} entries from embedding cache")
|
||||
return count
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Remove all expired entries from memory cache.
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_remove = [k for k, v in self._cache.items() if v.is_expired()]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
self._stats.expirations += 1
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
if keys_to_remove:
|
||||
logger.debug(f"Cleaned up {len(keys_to_remove)} expired embeddings")
|
||||
|
||||
return len(keys_to_remove)
|
||||
|
||||
def get_stats(self) -> EmbeddingCacheStats:
|
||||
"""Get cache statistics."""
|
||||
with self._lock:
|
||||
self._stats.current_size = len(self._cache)
|
||||
return self._stats
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset cache statistics."""
|
||||
with self._lock:
|
||||
self._stats = EmbeddingCacheStats(
|
||||
max_size=self._max_size,
|
||||
current_size=len(self._cache),
|
||||
)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Get current cache size."""
|
||||
return len(self._cache)
|
||||
|
||||
@property
|
||||
def max_size(self) -> int:
|
||||
"""Get maximum cache size."""
|
||||
return self._max_size
|
||||
|
||||
|
||||
class CachedEmbeddingGenerator:
|
||||
"""
|
||||
Wrapper for embedding generators with caching.
|
||||
|
||||
Wraps an embedding generator to cache results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
generator: Any,
|
||||
cache: EmbeddingCache,
|
||||
model: str = "default",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the cached embedding generator.
|
||||
|
||||
Args:
|
||||
generator: Underlying embedding generator
|
||||
cache: Embedding cache
|
||||
model: Model name for cache keys
|
||||
"""
|
||||
self._generator = generator
|
||||
self._cache = cache
|
||||
self._model = model
|
||||
self._call_count = 0
|
||||
self._cache_hit_count = 0
|
||||
|
||||
async def generate(self, text: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding with caching.
|
||||
|
||||
Args:
|
||||
text: Text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector
|
||||
"""
|
||||
self._call_count += 1
|
||||
|
||||
# Check cache first
|
||||
cached = await self._cache.get(text, self._model)
|
||||
if cached is not None:
|
||||
self._cache_hit_count += 1
|
||||
return cached
|
||||
|
||||
# Generate and cache
|
||||
embedding = await self._generator.generate(text)
|
||||
await self._cache.put(text, embedding, self._model)
|
||||
|
||||
return embedding
|
||||
|
||||
async def generate_batch(
|
||||
self,
|
||||
texts: list[str],
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Generate embeddings for multiple texts with caching.
|
||||
|
||||
Args:
|
||||
texts: Texts to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
"""
|
||||
results: list[list[float] | None] = [None] * len(texts)
|
||||
to_generate: list[tuple[int, str]] = []
|
||||
|
||||
# Check cache for each text
|
||||
for i, text in enumerate(texts):
|
||||
cached = await self._cache.get(text, self._model)
|
||||
if cached is not None:
|
||||
results[i] = cached
|
||||
self._cache_hit_count += 1
|
||||
else:
|
||||
to_generate.append((i, text))
|
||||
|
||||
self._call_count += len(texts)
|
||||
|
||||
# Generate missing embeddings
|
||||
if to_generate:
|
||||
if hasattr(self._generator, "generate_batch"):
|
||||
texts_to_gen = [t for _, t in to_generate]
|
||||
embeddings = await self._generator.generate_batch(texts_to_gen)
|
||||
|
||||
for (idx, text), embedding in zip(to_generate, embeddings, strict=True):
|
||||
results[idx] = embedding
|
||||
await self._cache.put(text, embedding, self._model)
|
||||
else:
|
||||
# Fallback to individual generation
|
||||
for idx, text in to_generate:
|
||||
embedding = await self._generator.generate(text)
|
||||
results[idx] = embedding
|
||||
await self._cache.put(text, embedding, self._model)
|
||||
|
||||
return results # type: ignore[return-value]
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get generator statistics."""
|
||||
return {
|
||||
"call_count": self._call_count,
|
||||
"cache_hit_count": self._cache_hit_count,
|
||||
"cache_hit_rate": (
|
||||
self._cache_hit_count / self._call_count
|
||||
if self._call_count > 0
|
||||
else 0.0
|
||||
),
|
||||
"cache_stats": self._cache.get_stats().to_dict(),
|
||||
}
|
||||
|
||||
|
||||
# Factory function
|
||||
def create_embedding_cache(
|
||||
max_size: int = 50000,
|
||||
default_ttl_seconds: float = 3600.0,
|
||||
redis: "Redis | None" = None,
|
||||
) -> EmbeddingCache:
|
||||
"""
|
||||
Create an embedding cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries
|
||||
default_ttl_seconds: Default TTL for entries
|
||||
redis: Optional Redis connection
|
||||
|
||||
Returns:
|
||||
Configured EmbeddingCache instance
|
||||
"""
|
||||
return EmbeddingCache(
|
||||
max_size=max_size,
|
||||
default_ttl_seconds=default_ttl_seconds,
|
||||
redis=redis,
|
||||
)
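
A usage sketch for the embedding cache and the CachedEmbeddingGenerator wrapper defined above. The DummyGenerator and the import path are assumptions; a real generator would call an embedding provider:

import asyncio

from app.services.memory.cache.embedding_cache import (  # assumed import root
    CachedEmbeddingGenerator,
    create_embedding_cache,
)


class DummyGenerator:
    """Stand-in embedder; a real one would call an embedding API."""

    async def generate(self, text: str) -> list[float]:
        return [float(len(text)), 0.0, 1.0]


async def demo() -> None:
    cache = create_embedding_cache(max_size=1000, default_ttl_seconds=600)
    embedder = CachedEmbeddingGenerator(DummyGenerator(), cache, model="dummy")

    first = await embedder.generate("hello world")   # miss -> generator runs
    second = await embedder.generate("hello world")  # hit  -> served from cache
    assert first == second
    print(embedder.get_stats()["cache_hit_rate"])    # 0.5 after one hit in two calls


asyncio.run(demo())
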
461 backend/app/services/memory/cache/hot_cache.py vendored Normal file
@@ -0,0 +1,461 @@
|
||||
# app/services/memory/cache/hot_cache.py
|
||||
"""
|
||||
Hot Memory Cache.
|
||||
|
||||
LRU cache for frequently accessed memories.
|
||||
Provides fast access to recently used memories without database queries.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry[T]:
|
||||
"""A cached memory entry with metadata."""
|
||||
|
||||
value: T
|
||||
created_at: datetime
|
||||
last_accessed_at: datetime
|
||||
access_count: int = 1
|
||||
ttl_seconds: float = 300.0
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this entry has expired."""
|
||||
age = (_utcnow() - self.created_at).total_seconds()
|
||||
return age > self.ttl_seconds
|
||||
|
||||
def touch(self) -> None:
|
||||
"""Update access time and count."""
|
||||
self.last_accessed_at = _utcnow()
|
||||
self.access_count += 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheKey:
|
||||
"""A structured cache key with components."""
|
||||
|
||||
memory_type: str
|
||||
memory_id: str
|
||||
scope: str | None = None
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.memory_type, self.memory_id, self.scope))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, CacheKey):
|
||||
return False
|
||||
return (
|
||||
self.memory_type == other.memory_type
|
||||
and self.memory_id == other.memory_id
|
||||
and self.scope == other.scope
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.scope:
|
||||
return f"{self.memory_type}:{self.scope}:{self.memory_id}"
|
||||
return f"{self.memory_type}:{self.memory_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HotCacheStats:
|
||||
"""Statistics for the hot memory cache."""
|
||||
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
evictions: int = 0
|
||||
expirations: int = 0
|
||||
current_size: int = 0
|
||||
max_size: int = 0
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate."""
|
||||
total = self.hits + self.misses
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return self.hits / total
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"evictions": self.evictions,
|
||||
"expirations": self.expirations,
|
||||
"current_size": self.current_size,
|
||||
"max_size": self.max_size,
|
||||
"hit_rate": self.hit_rate,
|
||||
}
|
||||
|
||||
|
||||
class HotMemoryCache[T]:
|
||||
"""
|
||||
LRU cache for frequently accessed memories.
|
||||
|
||||
Features:
|
||||
- LRU eviction when capacity is reached
|
||||
- TTL-based expiration
|
||||
- Access count tracking for hot memory identification
|
||||
- Thread-safe operations
|
||||
- Scoped invalidation
|
||||
|
||||
Performance targets:
|
||||
- Cache hit rate > 80% for hot memories
|
||||
- Get/put operations < 1ms
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int = 10000,
|
||||
default_ttl_seconds: float = 300.0,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the hot memory cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries
|
||||
default_ttl_seconds: Default TTL for entries (5 minutes)
|
||||
"""
|
||||
self._max_size = max_size
|
||||
self._default_ttl = default_ttl_seconds
|
||||
self._cache: OrderedDict[CacheKey, CacheEntry[T]] = OrderedDict()
|
||||
self._lock = threading.RLock()
|
||||
self._stats = HotCacheStats(max_size=max_size)
|
||||
logger.info(
|
||||
f"Initialized HotMemoryCache with max_size={max_size}, "
|
||||
f"ttl={default_ttl_seconds}s"
|
||||
)
|
||||
|
||||
def get(self, key: CacheKey) -> T | None:
|
||||
"""
|
||||
Get a memory from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found/expired
|
||||
"""
|
||||
with self._lock:
|
||||
if key not in self._cache:
|
||||
self._stats.misses += 1
|
||||
return None
|
||||
|
||||
entry = self._cache[key]
|
||||
|
||||
# Check expiration
|
||||
if entry.is_expired():
|
||||
del self._cache[key]
|
||||
self._stats.expirations += 1
|
||||
self._stats.misses += 1
|
||||
self._stats.current_size = len(self._cache)
|
||||
return None
|
||||
|
||||
# Move to end (most recently used)
|
||||
self._cache.move_to_end(key)
|
||||
entry.touch()
|
||||
|
||||
self._stats.hits += 1
|
||||
return entry.value
|
||||
|
||||
def get_by_id(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
scope: str | None = None,
|
||||
) -> T | None:
|
||||
"""
|
||||
Get a memory by type and ID.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory (episodic, semantic, procedural)
|
||||
memory_id: Memory ID
|
||||
scope: Optional scope (project_id, agent_id)
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found/expired
|
||||
"""
|
||||
key = CacheKey(
|
||||
memory_type=memory_type,
|
||||
memory_id=str(memory_id),
|
||||
scope=scope,
|
||||
)
|
||||
return self.get(key)
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: CacheKey,
|
||||
value: T,
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Put a memory into cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl_seconds: Optional TTL override
|
||||
"""
|
||||
with self._lock:
|
||||
# Evict if at capacity
|
||||
self._evict_if_needed()
|
||||
|
||||
now = _utcnow()
|
||||
entry = CacheEntry(
|
||||
value=value,
|
||||
created_at=now,
|
||||
last_accessed_at=now,
|
||||
access_count=1,
|
||||
ttl_seconds=ttl_seconds or self._default_ttl,
|
||||
)
|
||||
|
||||
self._cache[key] = entry
|
||||
self._cache.move_to_end(key)
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
def put_by_id(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
value: T,
|
||||
scope: str | None = None,
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Put a memory by type and ID.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
value: Value to cache
|
||||
scope: Optional scope
|
||||
ttl_seconds: Optional TTL override
|
||||
"""
|
||||
key = CacheKey(
|
||||
memory_type=memory_type,
|
||||
memory_id=str(memory_id),
|
||||
scope=scope,
|
||||
)
|
||||
self.put(key, value, ttl_seconds)
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict entries if cache is at capacity."""
|
||||
while len(self._cache) >= self._max_size:
|
||||
# Remove least recently used (first item)
|
||||
if self._cache:
|
||||
self._cache.popitem(last=False)
|
||||
self._stats.evictions += 1
|
||||
|
||||
def invalidate(self, key: CacheKey) -> bool:
|
||||
"""
|
||||
Invalidate a specific cache entry.
|
||||
|
||||
Args:
|
||||
key: Cache key to invalidate
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
del self._cache[key]
|
||||
self._stats.current_size = len(self._cache)
|
||||
return True
|
||||
return False
|
||||
|
||||
def invalidate_by_id(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
scope: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate a memory by type and ID.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
scope: Optional scope
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
key = CacheKey(
|
||||
memory_type=memory_type,
|
||||
memory_id=str(memory_id),
|
||||
scope=scope,
|
||||
)
|
||||
return self.invalidate(key)
|
||||
|
||||
def invalidate_by_type(self, memory_type: str) -> int:
|
||||
"""
|
||||
Invalidate all entries of a memory type.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory to invalidate
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_remove = [
|
||||
k for k in self._cache.keys() if k.memory_type == memory_type
|
||||
]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def invalidate_by_scope(self, scope: str) -> int:
|
||||
"""
|
||||
Invalidate all entries in a scope.
|
||||
|
||||
Args:
|
||||
scope: Scope to invalidate (e.g., project_id)
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_remove = [k for k in self._cache.keys() if k.scope == scope]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def invalidate_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
Invalidate entries matching a pattern.
|
||||
|
||||
Pattern can include * as wildcard.
|
||||
|
||||
Args:
|
||||
pattern: Pattern to match (e.g., "episodic:*")
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
import fnmatch
|
||||
|
||||
with self._lock:
|
||||
keys_to_remove = [
|
||||
k for k in self._cache.keys() if fnmatch.fnmatch(str(k), pattern)
|
||||
]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def clear(self) -> int:
|
||||
"""
|
||||
Clear all cache entries.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
with self._lock:
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._stats.current_size = 0
|
||||
logger.info(f"Cleared {count} entries from hot cache")
|
||||
return count
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Remove all expired entries.
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_remove = [k for k, v in self._cache.items() if v.is_expired()]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
self._stats.expirations += 1
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
if keys_to_remove:
|
||||
logger.debug(f"Cleaned up {len(keys_to_remove)} expired entries")
|
||||
|
||||
return len(keys_to_remove)
|
||||
|
||||
def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
|
||||
"""
|
||||
Get the most frequently accessed memories.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of memories to return
|
||||
|
||||
Returns:
|
||||
List of (key, access_count) tuples sorted by access count
|
||||
"""
|
||||
with self._lock:
|
||||
entries = [
|
||||
(k, v.access_count)
|
||||
for k, v in self._cache.items()
|
||||
if not v.is_expired()
|
||||
]
|
||||
entries.sort(key=lambda x: x[1], reverse=True)
|
||||
return entries[:limit]
|
||||
|
||||
def get_stats(self) -> HotCacheStats:
|
||||
"""Get cache statistics."""
|
||||
with self._lock:
|
||||
self._stats.current_size = len(self._cache)
|
||||
return self._stats
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset cache statistics."""
|
||||
with self._lock:
|
||||
self._stats = HotCacheStats(
|
||||
max_size=self._max_size,
|
||||
current_size=len(self._cache),
|
||||
)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Get current cache size."""
|
||||
return len(self._cache)
|
||||
|
||||
@property
|
||||
def max_size(self) -> int:
|
||||
"""Get maximum cache size."""
|
||||
return self._max_size
|
||||
|
||||
|
||||
# Factory function for typed caches
|
||||
def create_hot_cache(
|
||||
max_size: int = 10000,
|
||||
default_ttl_seconds: float = 300.0,
|
||||
) -> HotMemoryCache[Any]:
|
||||
"""
|
||||
Create a hot memory cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries
|
||||
default_ttl_seconds: Default TTL for entries
|
||||
|
||||
Returns:
|
||||
Configured HotMemoryCache instance
|
||||
"""
|
||||
return HotMemoryCache(
|
||||
max_size=max_size,
|
||||
default_ttl_seconds=default_ttl_seconds,
|
||||
)
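
A usage sketch for the hot-cache factory above; the import path assumes the backend package root maps to app:

from uuid import uuid4

from app.services.memory.cache.hot_cache import create_hot_cache  # assumed import root

cache = create_hot_cache(max_size=100, default_ttl_seconds=60.0)
memory_id = uuid4()

# Scoped put/get keyed by (memory_type, memory_id, scope).
cache.put_by_id("semantic", memory_id, {"subject": "backend", "object": "FastAPI"}, scope="proj-1")
print(cache.get_by_id("semantic", memory_id, scope="proj-1"))  # hit
print(cache.get_by_id("semantic", uuid4(), scope="proj-1"))    # miss -> None

cache.invalidate_by_scope("proj-1")  # drop every entry for the project
print(cache.get_stats().hit_rate, cache.size)
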
410 backend/app/services/memory/config.py Normal file
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Memory System Configuration.
|
||||
|
||||
Provides Pydantic settings for the Agent Memory System,
|
||||
including storage backends, capacity limits, and consolidation policies.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class MemorySettings(BaseSettings):
|
||||
"""
|
||||
Configuration for the Agent Memory System.
|
||||
|
||||
All settings can be overridden via environment variables
|
||||
with the MEM_ prefix.
|
||||
"""
|
||||
|
||||
# Working Memory Settings
|
||||
working_memory_backend: str = Field(
|
||||
default="redis",
|
||||
description="Backend for working memory: 'redis' or 'memory'",
|
||||
)
|
||||
working_memory_default_ttl_seconds: int = Field(
|
||||
default=3600,
|
||||
ge=60,
|
||||
le=86400,
|
||||
description="Default TTL for working memory items (1 hour default)",
|
||||
)
|
||||
working_memory_max_items_per_session: int = Field(
|
||||
default=1000,
|
||||
ge=100,
|
||||
le=100000,
|
||||
description="Maximum items per session in working memory",
|
||||
)
|
||||
working_memory_max_value_size_bytes: int = Field(
|
||||
default=1048576, # 1MB
|
||||
ge=1024,
|
||||
le=104857600, # 100MB
|
||||
description="Maximum size of a single value in working memory",
|
||||
)
|
||||
working_memory_checkpoint_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable checkpointing for working memory recovery",
|
||||
)
|
||||
|
||||
# Redis Settings (for working memory)
|
||||
redis_url: str = Field(
|
||||
default="redis://localhost:6379/0",
|
||||
description="Redis connection URL",
|
||||
)
|
||||
redis_prefix: str = Field(
|
||||
default="mem",
|
||||
description="Redis key prefix for memory items",
|
||||
)
|
||||
redis_connection_timeout_seconds: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=60,
|
||||
description="Redis connection timeout",
|
||||
)
|
||||
|
||||
# Episodic Memory Settings
|
||||
episodic_max_episodes_per_project: int = Field(
|
||||
default=10000,
|
||||
ge=100,
|
||||
le=1000000,
|
||||
description="Maximum episodes to retain per project",
|
||||
)
|
||||
episodic_default_importance: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Default importance score for new episodes",
|
||||
)
|
||||
episodic_retention_days: int = Field(
|
||||
default=365,
|
||||
ge=7,
|
||||
le=3650,
|
||||
description="Days to retain episodes before archival",
|
||||
)
|
||||
|
||||
# Semantic Memory Settings
|
||||
semantic_max_facts_per_project: int = Field(
|
||||
default=50000,
|
||||
ge=1000,
|
||||
le=10000000,
|
||||
description="Maximum facts to retain per project",
|
||||
)
|
||||
semantic_confidence_decay_days: int = Field(
|
||||
default=90,
|
||||
ge=7,
|
||||
le=365,
|
||||
description="Days until confidence decays by 50%",
|
||||
)
|
||||
semantic_min_confidence: float = Field(
|
||||
default=0.1,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum confidence before fact is pruned",
|
||||
)
|
||||
|
||||
# Procedural Memory Settings
|
||||
procedural_max_procedures_per_project: int = Field(
|
||||
default=1000,
|
||||
ge=10,
|
||||
le=100000,
|
||||
description="Maximum procedures per project",
|
||||
)
|
||||
procedural_min_success_rate: float = Field(
|
||||
default=0.3,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum success rate before procedure is pruned",
|
||||
)
|
||||
procedural_min_uses_before_suggest: int = Field(
|
||||
default=3,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Minimum uses before procedure is suggested",
|
||||
)
|
||||
|
||||
# Embedding Settings
|
||||
embedding_model: str = Field(
|
||||
default="text-embedding-3-small",
|
||||
description="Model to use for embeddings",
|
||||
)
|
||||
embedding_dimensions: int = Field(
|
||||
default=1536,
|
||||
ge=256,
|
||||
le=4096,
|
||||
description="Embedding vector dimensions",
|
||||
)
|
||||
embedding_batch_size: int = Field(
|
||||
default=100,
|
||||
ge=1,
|
||||
le=1000,
|
||||
description="Batch size for embedding generation",
|
||||
)
|
||||
embedding_cache_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable caching of embeddings",
|
||||
)
|
||||
|
||||
# Retrieval Settings
|
||||
retrieval_default_limit: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Default limit for retrieval queries",
|
||||
)
|
||||
retrieval_max_limit: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=1000,
|
||||
description="Maximum limit for retrieval queries",
|
||||
)
|
||||
retrieval_min_similarity: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum similarity score for retrieval",
|
||||
)
|
||||
|
||||
# Consolidation Settings
|
||||
consolidation_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable automatic memory consolidation",
|
||||
)
|
||||
consolidation_batch_size: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=1000,
|
||||
description="Batch size for consolidation jobs",
|
||||
)
|
||||
consolidation_schedule_cron: str = Field(
|
||||
default="0 3 * * *",
|
||||
description="Cron expression for nightly consolidation (3 AM)",
|
||||
)
|
||||
consolidation_working_to_episodic_delay_minutes: int = Field(
|
||||
default=30,
|
||||
ge=5,
|
||||
le=1440,
|
||||
description="Minutes after session end before consolidating to episodic",
|
||||
)
|
||||
|
||||
# Pruning Settings
|
||||
pruning_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable automatic memory pruning",
|
||||
)
|
||||
pruning_min_age_days: int = Field(
|
||||
default=7,
|
||||
ge=1,
|
||||
le=365,
|
||||
description="Minimum age before memory can be pruned",
|
||||
)
|
||||
pruning_importance_threshold: float = Field(
|
||||
default=0.2,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Importance threshold below which memory can be pruned",
|
||||
)
|
||||
|
||||
# Caching Settings
|
||||
cache_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable caching for memory retrieval",
|
||||
)
|
||||
cache_ttl_seconds: int = Field(
|
||||
default=300,
|
||||
ge=10,
|
||||
le=3600,
|
||||
description="Cache TTL for retrieval results",
|
||||
)
|
||||
cache_max_items: int = Field(
|
||||
default=10000,
|
||||
ge=100,
|
||||
le=1000000,
|
||||
description="Maximum items in memory cache",
|
||||
)
|
||||
|
||||
# Performance Settings
|
||||
max_retrieval_time_ms: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=5000,
|
||||
description="Target maximum retrieval time in milliseconds",
|
||||
)
|
||||
parallel_retrieval: bool = Field(
|
||||
default=True,
|
||||
description="Enable parallel retrieval from multiple memory types",
|
||||
)
|
||||
max_parallel_retrievals: int = Field(
|
||||
default=4,
|
||||
ge=1,
|
||||
le=10,
|
||||
description="Maximum concurrent retrieval operations",
|
||||
)
|
||||
|
||||
@field_validator("working_memory_backend")
|
||||
@classmethod
|
||||
def validate_backend(cls, v: str) -> str:
|
||||
"""Validate working memory backend."""
|
||||
valid_backends = {"redis", "memory"}
|
||||
if v not in valid_backends:
|
||||
raise ValueError(f"backend must be one of: {valid_backends}")
|
||||
return v
|
||||
|
||||
@field_validator("embedding_model")
|
||||
@classmethod
|
||||
def validate_embedding_model(cls, v: str) -> str:
|
||||
"""Validate embedding model name."""
|
||||
valid_models = {
|
||||
"text-embedding-3-small",
|
||||
"text-embedding-3-large",
|
||||
"text-embedding-ada-002",
|
||||
}
|
||||
if v not in valid_models:
|
||||
raise ValueError(f"embedding_model must be one of: {valid_models}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_limits(self) -> "MemorySettings":
|
||||
"""Validate that limits are consistent."""
|
||||
if self.retrieval_default_limit > self.retrieval_max_limit:
|
||||
raise ValueError(
|
||||
f"retrieval_default_limit ({self.retrieval_default_limit}) "
|
||||
f"cannot exceed retrieval_max_limit ({self.retrieval_max_limit})"
|
||||
)
|
||||
return self
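A sketch of how this cross-field check surfaces to callers, assuming pydantic's usual behavior of wrapping validator errors in a ValidationError; the values are chosen to satisfy the per-field bounds so only the model-level check fires.

```python
import pytest
from pydantic import ValidationError

from app.services.memory.config import MemorySettings


def test_rejects_default_limit_above_max() -> None:
    # retrieval_default_limit=100 and retrieval_max_limit=10 both satisfy
    # their individual field bounds, so only the consistency check fails.
    with pytest.raises(ValidationError):
        MemorySettings(retrieval_default_limit=100, retrieval_max_limit=10)
```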
|
||||
|
||||
def get_working_memory_config(self) -> dict[str, Any]:
|
||||
"""Get working memory configuration as a dictionary."""
|
||||
return {
|
||||
"backend": self.working_memory_backend,
|
||||
"default_ttl_seconds": self.working_memory_default_ttl_seconds,
|
||||
"max_items_per_session": self.working_memory_max_items_per_session,
|
||||
"max_value_size_bytes": self.working_memory_max_value_size_bytes,
|
||||
"checkpoint_enabled": self.working_memory_checkpoint_enabled,
|
||||
}
|
||||
|
||||
def get_redis_config(self) -> dict[str, Any]:
|
||||
"""Get Redis configuration as a dictionary."""
|
||||
return {
|
||||
"url": self.redis_url,
|
||||
"prefix": self.redis_prefix,
|
||||
"connection_timeout_seconds": self.redis_connection_timeout_seconds,
|
||||
}
|
||||
|
||||
def get_embedding_config(self) -> dict[str, Any]:
|
||||
"""Get embedding configuration as a dictionary."""
|
||||
return {
|
||||
"model": self.embedding_model,
|
||||
"dimensions": self.embedding_dimensions,
|
||||
"batch_size": self.embedding_batch_size,
|
||||
"cache_enabled": self.embedding_cache_enabled,
|
||||
}
|
||||
|
||||
def get_consolidation_config(self) -> dict[str, Any]:
|
||||
"""Get consolidation configuration as a dictionary."""
|
||||
return {
|
||||
"enabled": self.consolidation_enabled,
|
||||
"batch_size": self.consolidation_batch_size,
|
||||
"schedule_cron": self.consolidation_schedule_cron,
|
||||
"working_to_episodic_delay_minutes": (
|
||||
self.consolidation_working_to_episodic_delay_minutes
|
||||
),
|
||||
}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert settings to dictionary for logging/debugging."""
|
||||
return {
|
||||
"working_memory": self.get_working_memory_config(),
|
||||
"redis": self.get_redis_config(),
|
||||
"episodic": {
|
||||
"max_episodes_per_project": self.episodic_max_episodes_per_project,
|
||||
"default_importance": self.episodic_default_importance,
|
||||
"retention_days": self.episodic_retention_days,
|
||||
},
|
||||
"semantic": {
|
||||
"max_facts_per_project": self.semantic_max_facts_per_project,
|
||||
"confidence_decay_days": self.semantic_confidence_decay_days,
|
||||
"min_confidence": self.semantic_min_confidence,
|
||||
},
|
||||
"procedural": {
|
||||
"max_procedures_per_project": self.procedural_max_procedures_per_project,
|
||||
"min_success_rate": self.procedural_min_success_rate,
|
||||
"min_uses_before_suggest": self.procedural_min_uses_before_suggest,
|
||||
},
|
||||
"embedding": self.get_embedding_config(),
|
||||
"retrieval": {
|
||||
"default_limit": self.retrieval_default_limit,
|
||||
"max_limit": self.retrieval_max_limit,
|
||||
"min_similarity": self.retrieval_min_similarity,
|
||||
},
|
||||
"consolidation": self.get_consolidation_config(),
|
||||
"pruning": {
|
||||
"enabled": self.pruning_enabled,
|
||||
"min_age_days": self.pruning_min_age_days,
|
||||
"importance_threshold": self.pruning_importance_threshold,
|
||||
},
|
||||
"cache": {
|
||||
"enabled": self.cache_enabled,
|
||||
"ttl_seconds": self.cache_ttl_seconds,
|
||||
"max_items": self.cache_max_items,
|
||||
},
|
||||
"performance": {
|
||||
"max_retrieval_time_ms": self.max_retrieval_time_ms,
|
||||
"parallel_retrieval": self.parallel_retrieval,
|
||||
"max_parallel_retrievals": self.max_parallel_retrievals,
|
||||
},
|
||||
}
|
||||
|
||||
model_config = {
|
||||
"env_prefix": "MEM_",
|
||||
"env_file": ".env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": False,
|
||||
"extra": "ignore",
|
||||
}
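Given the MEM_ prefix and case-insensitive env handling configured above, any field can be overridden per environment without code changes. A short sketch; the chosen values are illustrative only.

```python
import os

from app.services.memory.config import MemorySettings

# Illustrative overrides, e.g. for a local test run that avoids Redis.
os.environ["MEM_WORKING_MEMORY_BACKEND"] = "memory"
os.environ["MEM_RETRIEVAL_DEFAULT_LIMIT"] = "20"

settings = MemorySettings()
assert settings.working_memory_backend == "memory"
assert settings.retrieval_default_limit == 20
```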
|
||||
|
||||
|
||||
# Thread-safe singleton pattern
|
||||
_settings: MemorySettings | None = None
|
||||
_settings_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_memory_settings() -> MemorySettings:
|
||||
"""
|
||||
Get the global MemorySettings instance.
|
||||
|
||||
Thread-safe with double-checked locking pattern.
|
||||
|
||||
Returns:
|
||||
MemorySettings instance
|
||||
"""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
with _settings_lock:
|
||||
if _settings is None:
|
||||
_settings = MemorySettings()
|
||||
return _settings
|
||||
|
||||
|
||||
def reset_memory_settings() -> None:
|
||||
"""
|
||||
Reset the global settings instance.
|
||||
|
||||
Primarily used for testing.
|
||||
"""
|
||||
global _settings
|
||||
with _settings_lock:
|
||||
_settings = None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_default_settings() -> MemorySettings:
|
||||
"""
|
||||
Get default settings (cached).
|
||||
|
||||
Use this for read-only access to defaults.
|
||||
For mutable access, use get_memory_settings().
|
||||
"""
|
||||
return MemorySettings()
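A short sketch of the intended access pattern: production code goes through the thread-safe accessor, and tests reset the singleton so changed MEM_* variables are re-read.

```python
from app.services.memory.config import (
    get_memory_settings,
    reset_memory_settings,
)

# The accessor lazily builds one shared instance.
first = get_memory_settings()
second = get_memory_settings()
assert first is second

# In tests, clear the cached instance before changing MEM_* env vars.
reset_memory_settings()
```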
|
||||
29
backend/app/services/memory/consolidation/__init__.py
Normal file
@@ -0,0 +1,29 @@
# app/services/memory/consolidation/__init__.py
"""
Memory Consolidation.

Transfers and extracts knowledge between memory tiers:
- Working -> Episodic (session end)
- Episodic -> Semantic (learn facts)
- Episodic -> Procedural (learn procedures)

Also handles memory pruning and importance-based retention.
"""

from .service import (
    ConsolidationConfig,
    ConsolidationResult,
    MemoryConsolidationService,
    NightlyConsolidationResult,
    SessionConsolidationResult,
    get_consolidation_service,
)

__all__ = [
    "ConsolidationConfig",
    "ConsolidationResult",
    "MemoryConsolidationService",
    "NightlyConsolidationResult",
    "SessionConsolidationResult",
    "get_consolidation_service",
]
913
backend/app/services/memory/consolidation/service.py
Normal file
@@ -0,0 +1,913 @@
|
||||
# app/services/memory/consolidation/service.py
|
||||
"""
|
||||
Memory Consolidation Service.
|
||||
|
||||
Transfers and extracts knowledge between memory tiers:
|
||||
- Working -> Episodic (session end)
|
||||
- Episodic -> Semantic (learn facts)
|
||||
- Episodic -> Procedural (learn procedures)
|
||||
|
||||
Also handles memory pruning and importance-based retention.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.memory.episodic.memory import EpisodicMemory
|
||||
from app.services.memory.procedural.memory import ProceduralMemory
|
||||
from app.services.memory.semantic.extraction import FactExtractor, get_fact_extractor
|
||||
from app.services.memory.semantic.memory import SemanticMemory
|
||||
from app.services.memory.types import (
|
||||
Episode,
|
||||
EpisodeCreate,
|
||||
Outcome,
|
||||
ProcedureCreate,
|
||||
TaskState,
|
||||
)
|
||||
from app.services.memory.working.memory import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsolidationConfig:
|
||||
"""Configuration for memory consolidation."""
|
||||
|
||||
# Working -> Episodic thresholds
|
||||
min_steps_for_episode: int = 2
|
||||
min_duration_seconds: float = 5.0
|
||||
|
||||
# Episodic -> Semantic thresholds
|
||||
min_confidence_for_fact: float = 0.6
|
||||
max_facts_per_episode: int = 10
|
||||
reinforce_existing_facts: bool = True
|
||||
|
||||
# Episodic -> Procedural thresholds
|
||||
min_episodes_for_procedure: int = 3
|
||||
min_success_rate_for_procedure: float = 0.7
|
||||
min_steps_for_procedure: int = 2
|
||||
|
||||
# Pruning thresholds
|
||||
max_episode_age_days: int = 90
|
||||
min_importance_to_keep: float = 0.2
|
||||
keep_all_failures: bool = True
|
||||
keep_all_with_lessons: bool = True
|
||||
|
||||
# Batch sizes
|
||||
batch_size: int = 100
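A sketch of tuning the policy via this dataclass; the field values below are illustrative, not recommendations.

```python
from app.services.memory.consolidation import ConsolidationConfig

# A stricter policy: demand more evidence before learning a procedure,
# and keep episodes around longer before they become pruning candidates.
strict_config = ConsolidationConfig(
    min_episodes_for_procedure=5,
    min_success_rate_for_procedure=0.8,
    max_episode_age_days=180,
)
```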
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsolidationResult:
|
||||
"""Result of a consolidation operation."""
|
||||
|
||||
source_type: str
|
||||
target_type: str
|
||||
items_processed: int = 0
|
||||
items_created: int = 0
|
||||
items_updated: int = 0
|
||||
items_skipped: int = 0
|
||||
items_pruned: int = 0
|
||||
errors: list[str] = field(default_factory=list)
|
||||
duration_seconds: float = 0.0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"source_type": self.source_type,
|
||||
"target_type": self.target_type,
|
||||
"items_processed": self.items_processed,
|
||||
"items_created": self.items_created,
|
||||
"items_updated": self.items_updated,
|
||||
"items_skipped": self.items_skipped,
|
||||
"items_pruned": self.items_pruned,
|
||||
"errors": self.errors,
|
||||
"duration_seconds": self.duration_seconds,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionConsolidationResult:
|
||||
"""Result of consolidating a session's working memory to episodic."""
|
||||
|
||||
session_id: str
|
||||
episode_created: bool = False
|
||||
episode_id: UUID | None = None
|
||||
scratchpad_entries: int = 0
|
||||
variables_captured: int = 0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NightlyConsolidationResult:
|
||||
"""Result of nightly consolidation run."""
|
||||
|
||||
started_at: datetime
|
||||
completed_at: datetime | None = None
|
||||
episodic_to_semantic: ConsolidationResult | None = None
|
||||
episodic_to_procedural: ConsolidationResult | None = None
|
||||
pruning: ConsolidationResult | None = None
|
||||
total_episodes_processed: int = 0
|
||||
total_facts_created: int = 0
|
||||
total_procedures_created: int = 0
|
||||
total_pruned: int = 0
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"started_at": self.started_at.isoformat(),
|
||||
"completed_at": self.completed_at.isoformat()
|
||||
if self.completed_at
|
||||
else None,
|
||||
"episodic_to_semantic": (
|
||||
self.episodic_to_semantic.to_dict()
|
||||
if self.episodic_to_semantic
|
||||
else None
|
||||
),
|
||||
"episodic_to_procedural": (
|
||||
self.episodic_to_procedural.to_dict()
|
||||
if self.episodic_to_procedural
|
||||
else None
|
||||
),
|
||||
"pruning": self.pruning.to_dict() if self.pruning else None,
|
||||
"total_episodes_processed": self.total_episodes_processed,
|
||||
"total_facts_created": self.total_facts_created,
|
||||
"total_procedures_created": self.total_procedures_created,
|
||||
"total_pruned": self.total_pruned,
|
||||
"errors": self.errors,
|
||||
}
|
||||
|
||||
|
||||
class MemoryConsolidationService:
|
||||
"""
|
||||
Service for consolidating memories between tiers.
|
||||
|
||||
Responsibilities:
|
||||
- Transfer working memory to episodic at session end
|
||||
- Extract facts from episodes to semantic memory
|
||||
- Learn procedures from successful episode patterns
|
||||
- Prune old/low-value memories
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
config: ConsolidationConfig | None = None,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize consolidation service.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
config: Consolidation configuration
|
||||
embedding_generator: Optional embedding generator
|
||||
"""
|
||||
self._session = session
|
||||
self._config = config or ConsolidationConfig()
|
||||
self._embedding_generator = embedding_generator
|
||||
self._fact_extractor: FactExtractor = get_fact_extractor()
|
||||
|
||||
# Memory services (lazy initialized)
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
self._semantic: SemanticMemory | None = None
|
||||
self._procedural: ProceduralMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
async def _get_semantic(self) -> SemanticMemory:
|
||||
"""Get or create semantic memory service."""
|
||||
if self._semantic is None:
|
||||
self._semantic = await SemanticMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._semantic
|
||||
|
||||
async def _get_procedural(self) -> ProceduralMemory:
|
||||
"""Get or create procedural memory service."""
|
||||
if self._procedural is None:
|
||||
self._procedural = await ProceduralMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._procedural
|
||||
|
||||
# =========================================================================
|
||||
# Working -> Episodic Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_session(
|
||||
self,
|
||||
working_memory: WorkingMemory,
|
||||
project_id: UUID,
|
||||
session_id: str,
|
||||
task_type: str = "session_task",
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> SessionConsolidationResult:
|
||||
"""
|
||||
Consolidate a session's working memory to episodic memory.
|
||||
|
||||
Called at session end to transfer relevant session data
|
||||
into a persistent episode.
|
||||
|
||||
Args:
|
||||
working_memory: The session's working memory
|
||||
project_id: Project ID
|
||||
session_id: Session ID
|
||||
task_type: Type of task performed
|
||||
agent_instance_id: Optional agent instance
|
||||
agent_type_id: Optional agent type
|
||||
|
||||
Returns:
|
||||
SessionConsolidationResult with outcome details
|
||||
"""
|
||||
result = SessionConsolidationResult(session_id=session_id)
|
||||
|
||||
try:
|
||||
# Get task state
|
||||
task_state = await working_memory.get_task_state()
|
||||
|
||||
# Check if there's enough content to consolidate
|
||||
if not self._should_consolidate_session(task_state):
|
||||
logger.debug(
|
||||
f"Skipping consolidation for session {session_id}: insufficient content"
|
||||
)
|
||||
return result
|
||||
|
||||
# Gather scratchpad entries
|
||||
scratchpad = await working_memory.get_scratchpad()
|
||||
result.scratchpad_entries = len(scratchpad)
|
||||
|
||||
# Gather user variables
|
||||
all_data = await working_memory.get_all()
|
||||
result.variables_captured = len(all_data)
|
||||
|
||||
# Determine outcome
|
||||
outcome = self._determine_session_outcome(task_state)
|
||||
|
||||
# Build actions from scratchpad and variables
|
||||
actions = self._build_actions_from_session(scratchpad, all_data, task_state)
|
||||
|
||||
# Build context summary
|
||||
context_summary = self._build_context_summary(task_state, all_data)
|
||||
|
||||
# Extract lessons learned
|
||||
lessons = self._extract_session_lessons(task_state, outcome)
|
||||
|
||||
# Calculate importance
|
||||
importance = self._calculate_session_importance(
|
||||
task_state, outcome, actions
|
||||
)
|
||||
|
||||
# Create episode
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
task_description=task_state.description
|
||||
if task_state
|
||||
else "Session task",
|
||||
actions=actions,
|
||||
context_summary=context_summary,
|
||||
outcome=outcome,
|
||||
outcome_details=task_state.status if task_state else "",
|
||||
duration_seconds=self._calculate_duration(task_state),
|
||||
tokens_used=0, # Would need to track this in working memory
|
||||
lessons_learned=lessons,
|
||||
importance_score=importance,
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
|
||||
episodic = await self._get_episodic()
|
||||
episode = await episodic.record_episode(episode_data)
|
||||
|
||||
result.episode_created = True
|
||||
result.episode_id = episode.id
|
||||
|
||||
logger.info(
|
||||
f"Consolidated session {session_id} to episode {episode.id} "
|
||||
f"({len(actions)} actions, outcome={outcome.value})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result.error = str(e)
|
||||
logger.exception(f"Failed to consolidate session {session_id}")
|
||||
|
||||
return result
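A sketch of wiring this into a session-end hook. How the database session and WorkingMemory instance are obtained is application-specific and not shown in this diff; the hook assumes they are supplied by the caller.

```python
import logging
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from app.services.memory.consolidation import get_consolidation_service
from app.services.memory.working.memory import WorkingMemory

logger = logging.getLogger(__name__)


async def on_session_end(
    db_session: AsyncSession,
    working_memory: WorkingMemory,
    project_id: UUID,
    session_id: str,
) -> None:
    # Build a fresh service per request to avoid stale-session issues.
    service = await get_consolidation_service(db_session)
    result = await service.consolidate_session(
        working_memory=working_memory,
        project_id=project_id,
        session_id=session_id,
        task_type="session_task",
    )
    if result.error:
        logger.warning("Consolidation of %s failed: %s", session_id, result.error)
    elif result.episode_created:
        logger.info("Session %s stored as episode %s", session_id, result.episode_id)
```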
|
||||
|
||||
def _should_consolidate_session(self, task_state: TaskState | None) -> bool:
|
||||
"""Check if session has enough content to consolidate."""
|
||||
if task_state is None:
|
||||
return False
|
||||
|
||||
# Check minimum steps
|
||||
if task_state.current_step < self._config.min_steps_for_episode:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _determine_session_outcome(self, task_state: TaskState | None) -> Outcome:
|
||||
"""Determine outcome from task state."""
|
||||
if task_state is None:
|
||||
return Outcome.PARTIAL
|
||||
|
||||
status = task_state.status.lower() if task_state.status else ""
|
||||
progress = task_state.progress_percent
|
||||
|
||||
if "success" in status or "complete" in status or progress >= 100:
|
||||
return Outcome.SUCCESS
|
||||
if "fail" in status or "error" in status:
|
||||
return Outcome.FAILURE
|
||||
if progress >= 50:
|
||||
return Outcome.PARTIAL
|
||||
return Outcome.FAILURE
|
||||
|
||||
def _build_actions_from_session(
|
||||
self,
|
||||
scratchpad: list[str],
|
||||
variables: dict[str, Any],
|
||||
task_state: TaskState | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build action list from session data."""
|
||||
actions: list[dict[str, Any]] = []
|
||||
|
||||
# Add scratchpad entries as actions
|
||||
for i, entry in enumerate(scratchpad):
|
||||
actions.append(
|
||||
{
|
||||
"step": i + 1,
|
||||
"type": "reasoning",
|
||||
"content": entry[:500], # Truncate long entries
|
||||
}
|
||||
)
|
||||
|
||||
# Add final state
|
||||
if task_state:
|
||||
actions.append(
|
||||
{
|
||||
"step": len(scratchpad) + 1,
|
||||
"type": "final_state",
|
||||
"current_step": task_state.current_step,
|
||||
"total_steps": task_state.total_steps,
|
||||
"progress": task_state.progress_percent,
|
||||
"status": task_state.status,
|
||||
}
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
def _build_context_summary(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
variables: dict[str, Any],
|
||||
) -> str:
|
||||
"""Build context summary from session data."""
|
||||
parts = []
|
||||
|
||||
if task_state:
|
||||
parts.append(f"Task: {task_state.description}")
|
||||
parts.append(f"Progress: {task_state.progress_percent:.1f}%")
|
||||
parts.append(f"Steps: {task_state.current_step}/{task_state.total_steps}")
|
||||
|
||||
# Include key variables
|
||||
key_vars = {k: v for k, v in variables.items() if len(str(v)) < 100}
|
||||
if key_vars:
|
||||
var_str = ", ".join(f"{k}={v}" for k, v in list(key_vars.items())[:5])
|
||||
parts.append(f"Variables: {var_str}")
|
||||
|
||||
return "; ".join(parts) if parts else "Session completed"
|
||||
|
||||
def _extract_session_lessons(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
outcome: Outcome,
|
||||
) -> list[str]:
|
||||
"""Extract lessons from session."""
|
||||
lessons: list[str] = []
|
||||
|
||||
if task_state and task_state.status:
|
||||
if outcome == Outcome.FAILURE:
|
||||
lessons.append(
|
||||
f"Task failed at step {task_state.current_step}: {task_state.status}"
|
||||
)
|
||||
elif outcome == Outcome.SUCCESS:
|
||||
lessons.append(
|
||||
f"Successfully completed in {task_state.current_step} steps"
|
||||
)
|
||||
|
||||
return lessons
|
||||
|
||||
def _calculate_session_importance(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
outcome: Outcome,
|
||||
actions: list[dict[str, Any]],
|
||||
) -> float:
|
||||
"""Calculate importance score for session."""
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Failures are important to learn from
|
||||
if outcome == Outcome.FAILURE:
|
||||
score += 0.3
|
||||
|
||||
# Many steps means complex task
|
||||
if task_state and task_state.total_steps >= 5:
|
||||
score += 0.1
|
||||
|
||||
# Many actions means detailed reasoning
|
||||
if len(actions) >= 5:
|
||||
score += 0.1
|
||||
|
||||
return min(1.0, score)
|
||||
|
||||
def _calculate_duration(self, task_state: TaskState | None) -> float:
|
||||
"""Calculate session duration."""
|
||||
if task_state is None:
|
||||
return 0.0
|
||||
|
||||
if task_state.started_at and task_state.updated_at:
|
||||
delta = task_state.updated_at - task_state.started_at
|
||||
return delta.total_seconds()
|
||||
|
||||
return 0.0
|
||||
|
||||
# =========================================================================
|
||||
# Episodic -> Semantic Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_episodes_to_facts(
|
||||
self,
|
||||
project_id: UUID,
|
||||
since: datetime | None = None,
|
||||
limit: int | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Extract facts from episodic memories to semantic memory.
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
since: Only process episodes since this time
|
||||
limit: Maximum episodes to process
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with extraction statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
semantic = await self._get_semantic()
|
||||
|
||||
# Get episodes to process
|
||||
since_time = since or datetime.now(UTC) - timedelta(days=1)
|
||||
episodes = await episodic.get_recent(
|
||||
project_id,
|
||||
limit=limit or self._config.batch_size,
|
||||
since=since_time,
|
||||
)
|
||||
|
||||
for episode in episodes:
|
||||
result.items_processed += 1
|
||||
|
||||
try:
|
||||
# Extract facts using the extractor
|
||||
extracted_facts = self._fact_extractor.extract_from_episode(episode)
|
||||
|
||||
for extracted_fact in extracted_facts:
|
||||
if (
|
||||
extracted_fact.confidence
|
||||
< self._config.min_confidence_for_fact
|
||||
):
|
||||
result.items_skipped += 1
|
||||
continue
|
||||
|
||||
# Create fact (store_fact handles deduplication/reinforcement)
|
||||
fact_create = extracted_fact.to_fact_create(
|
||||
project_id=project_id,
|
||||
source_episode_ids=[episode.id],
|
||||
)
|
||||
|
||||
# store_fact automatically reinforces if fact already exists
|
||||
fact = await semantic.store_fact(fact_create)
|
||||
|
||||
# Check if this was a new fact or reinforced existing
|
||||
if fact.reinforcement_count == 1:
|
||||
result.items_created += 1
|
||||
else:
|
||||
result.items_updated += 1
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Episode {episode.id}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to extract facts from episode {episode.id}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Consolidation failed: {e}")
|
||||
logger.exception("Failed episodic -> semantic consolidation")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episodic -> Semantic consolidation: "
|
||||
f"{result.items_processed} processed, "
|
||||
f"{result.items_created} created, "
|
||||
f"{result.items_updated} reinforced"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
# =========================================================================
|
||||
# Episodic -> Procedural Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_episodes_to_procedures(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Learn procedures from patterns in episodic memories.
|
||||
|
||||
Identifies recurring successful patterns and creates/updates
|
||||
procedures to capture them.
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
agent_type_id: Optional filter by agent type
|
||||
since: Only process episodes since this time
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with procedure statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="procedural",
|
||||
)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
procedural = await self._get_procedural()
|
||||
|
||||
# Get successful episodes
|
||||
since_time = since or datetime.now(UTC) - timedelta(days=7)
|
||||
episodes = await episodic.get_by_outcome(
|
||||
project_id,
|
||||
outcome=Outcome.SUCCESS,
|
||||
limit=self._config.batch_size,
|
||||
agent_instance_id=None, # Get all agent instances
|
||||
)
|
||||
|
||||
# Group by task type
|
||||
task_groups: dict[str, list[Episode]] = {}
|
||||
for episode in episodes:
|
||||
if episode.occurred_at >= since_time:
|
||||
if episode.task_type not in task_groups:
|
||||
task_groups[episode.task_type] = []
|
||||
task_groups[episode.task_type].append(episode)
|
||||
|
||||
result.items_processed = len(episodes)
|
||||
|
||||
# Process each task type group
|
||||
for task_type, group in task_groups.items():
|
||||
if len(group) < self._config.min_episodes_for_procedure:
|
||||
result.items_skipped += len(group)
|
||||
continue
|
||||
|
||||
try:
|
||||
procedure_result = await self._learn_procedure_from_episodes(
|
||||
procedural,
|
||||
project_id,
|
||||
agent_type_id,
|
||||
task_type,
|
||||
group,
|
||||
)
|
||||
|
||||
if procedure_result == "created":
|
||||
result.items_created += 1
|
||||
elif procedure_result == "updated":
|
||||
result.items_updated += 1
|
||||
else:
|
||||
result.items_skipped += 1
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Task type '{task_type}': {e}")
|
||||
logger.warning(f"Failed to learn procedure for '{task_type}': {e}")
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Consolidation failed: {e}")
|
||||
logger.exception("Failed episodic -> procedural consolidation")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episodic -> Procedural consolidation: "
|
||||
f"{result.items_processed} processed, "
|
||||
f"{result.items_created} created, "
|
||||
f"{result.items_updated} updated"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _learn_procedure_from_episodes(
|
||||
self,
|
||||
procedural: ProceduralMemory,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None,
|
||||
task_type: str,
|
||||
episodes: list[Episode],
|
||||
) -> str:
|
||||
"""Learn or update a procedure from a set of episodes."""
|
||||
# Calculate success rate for this pattern
|
||||
success_count = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
|
||||
total_count = len(episodes)
|
||||
success_rate = success_count / total_count if total_count > 0 else 0
|
||||
|
||||
if success_rate < self._config.min_success_rate_for_procedure:
|
||||
return "skipped"
|
||||
|
||||
# Extract common steps from episodes
|
||||
steps = self._extract_common_steps(episodes)
|
||||
|
||||
if len(steps) < self._config.min_steps_for_procedure:
|
||||
return "skipped"
|
||||
|
||||
# Check for existing procedure
|
||||
matching = await procedural.find_matching(
|
||||
context=task_type,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=1,
|
||||
)
|
||||
|
||||
if matching:
|
||||
# Update existing procedure with new success
|
||||
await procedural.record_outcome(
|
||||
matching[0].id,
|
||||
success=True,
|
||||
)
|
||||
return "updated"
|
||||
else:
|
||||
# Create new procedure
|
||||
# Note: success_count starts at 1 in record_procedure
|
||||
procedure_data = ProcedureCreate(
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
name=f"Procedure for {task_type}",
|
||||
trigger_pattern=task_type,
|
||||
steps=steps,
|
||||
)
|
||||
await procedural.record_procedure(procedure_data)
|
||||
return "created"
|
||||
|
||||
def _extract_common_steps(self, episodes: list[Episode]) -> list[dict[str, Any]]:
|
||||
"""Extract common action steps from multiple episodes."""
|
||||
# Simple heuristic: take the steps from the most successful episode
|
||||
# with the most detailed actions
|
||||
|
||||
best_episode = max(
|
||||
episodes,
|
||||
key=lambda e: (
|
||||
e.outcome == Outcome.SUCCESS,
|
||||
len(e.actions),
|
||||
e.importance_score,
|
||||
),
|
||||
)
|
||||
|
||||
steps: list[dict[str, Any]] = []
|
||||
for i, action in enumerate(best_episode.actions):
|
||||
step = {
|
||||
"order": i + 1,
|
||||
"action": action.get("type", "action"),
|
||||
"description": action.get("content", str(action))[:500],
|
||||
"parameters": action,
|
||||
}
|
||||
steps.append(step)
|
||||
|
||||
return steps
|
||||
|
||||
# =========================================================================
|
||||
# Memory Pruning
|
||||
# =========================================================================
|
||||
|
||||
async def prune_old_episodes(
|
||||
self,
|
||||
project_id: UUID,
|
||||
max_age_days: int | None = None,
|
||||
min_importance: float | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Prune old, low-value episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to prune
|
||||
max_age_days: Maximum age in days (default from config)
|
||||
min_importance: Minimum importance to keep (default from config)
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with pruning statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="pruned",
|
||||
)
|
||||
|
||||
        # Explicit None checks so a caller can pass 0 / 0.0 deliberately.
        max_age = (
            max_age_days
            if max_age_days is not None
            else self._config.max_episode_age_days
        )
        min_imp = (
            min_importance
            if min_importance is not None
            else self._config.min_importance_to_keep
        )
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=max_age)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
# Get old episodes
|
||||
# Note: In production, this would use a more efficient query
|
||||
all_episodes = await episodic.get_recent(
|
||||
project_id,
|
||||
limit=self._config.batch_size * 10,
|
||||
since=cutoff_date - timedelta(days=365), # Search past year
|
||||
)
|
||||
|
||||
for episode in all_episodes:
|
||||
result.items_processed += 1
|
||||
|
||||
# Check if should be pruned
|
||||
if not self._should_prune_episode(episode, cutoff_date, min_imp):
|
||||
result.items_skipped += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
deleted = await episodic.delete(episode.id)
|
||||
if deleted:
|
||||
result.items_pruned += 1
|
||||
else:
|
||||
result.items_skipped += 1
|
||||
except Exception as e:
|
||||
result.errors.append(f"Episode {episode.id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Pruning failed: {e}")
|
||||
logger.exception("Failed episode pruning")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episode pruning: {result.items_processed} processed, "
|
||||
f"{result.items_pruned} pruned"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _should_prune_episode(
|
||||
self,
|
||||
episode: Episode,
|
||||
cutoff_date: datetime,
|
||||
min_importance: float,
|
||||
) -> bool:
|
||||
"""Determine if an episode should be pruned."""
|
||||
# Keep recent episodes
|
||||
if episode.occurred_at >= cutoff_date:
|
||||
return False
|
||||
|
||||
# Keep failures if configured
|
||||
if self._config.keep_all_failures and episode.outcome == Outcome.FAILURE:
|
||||
return False
|
||||
|
||||
# Keep episodes with lessons if configured
|
||||
if self._config.keep_all_with_lessons and episode.lessons_learned:
|
||||
return False
|
||||
|
||||
# Keep high-importance episodes
|
||||
if episode.importance_score >= min_importance:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# =========================================================================
|
||||
# Nightly Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def run_nightly_consolidation(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> NightlyConsolidationResult:
|
||||
"""
|
||||
Run full nightly consolidation workflow.
|
||||
|
||||
This includes:
|
||||
1. Extract facts from recent episodes
|
||||
2. Learn procedures from successful patterns
|
||||
3. Prune old, low-value memories
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
NightlyConsolidationResult with all outcomes
|
||||
"""
|
||||
result = NightlyConsolidationResult(started_at=datetime.now(UTC))
|
||||
|
||||
logger.info(f"Starting nightly consolidation for project {project_id}")
|
||||
|
||||
try:
|
||||
# Step 1: Episodic -> Semantic (last 24 hours)
|
||||
since_yesterday = datetime.now(UTC) - timedelta(days=1)
|
||||
result.episodic_to_semantic = await self.consolidate_episodes_to_facts(
|
||||
project_id=project_id,
|
||||
since=since_yesterday,
|
||||
)
|
||||
result.total_facts_created = result.episodic_to_semantic.items_created
|
||||
|
||||
# Step 2: Episodic -> Procedural (last 7 days)
|
||||
since_week = datetime.now(UTC) - timedelta(days=7)
|
||||
result.episodic_to_procedural = (
|
||||
await self.consolidate_episodes_to_procedures(
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
since=since_week,
|
||||
)
|
||||
)
|
||||
result.total_procedures_created = (
|
||||
result.episodic_to_procedural.items_created
|
||||
)
|
||||
|
||||
# Step 3: Prune old memories
|
||||
result.pruning = await self.prune_old_episodes(project_id=project_id)
|
||||
result.total_pruned = result.pruning.items_pruned
|
||||
|
||||
# Calculate totals
|
||||
result.total_episodes_processed = (
|
||||
result.episodic_to_semantic.items_processed
|
||||
if result.episodic_to_semantic
|
||||
else 0
|
||||
) + (
|
||||
result.episodic_to_procedural.items_processed
|
||||
if result.episodic_to_procedural
|
||||
else 0
|
||||
)
|
||||
|
||||
# Collect all errors
|
||||
if result.episodic_to_semantic and result.episodic_to_semantic.errors:
|
||||
result.errors.extend(result.episodic_to_semantic.errors)
|
||||
if result.episodic_to_procedural and result.episodic_to_procedural.errors:
|
||||
result.errors.extend(result.episodic_to_procedural.errors)
|
||||
if result.pruning and result.pruning.errors:
|
||||
result.errors.extend(result.pruning.errors)
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Nightly consolidation failed: {e}")
|
||||
logger.exception("Nightly consolidation failed")
|
||||
|
||||
result.completed_at = datetime.now(UTC)
|
||||
|
||||
duration = (result.completed_at - result.started_at).total_seconds()
|
||||
logger.info(
|
||||
f"Nightly consolidation completed in {duration:.1f}s: "
|
||||
f"{result.total_facts_created} facts, "
|
||||
f"{result.total_procedures_created} procedures, "
|
||||
f"{result.total_pruned} pruned"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Factory function - no singleton to avoid stale session issues
|
||||
async def get_consolidation_service(
|
||||
session: AsyncSession,
|
||||
config: ConsolidationConfig | None = None,
|
||||
) -> MemoryConsolidationService:
|
||||
"""
|
||||
Create a memory consolidation service for the given session.
|
||||
|
||||
Note: This creates a new instance each time to avoid stale session issues.
|
||||
The service is lightweight and safe to recreate per-request.
|
||||
|
||||
Args:
|
||||
session: Database session (must be active)
|
||||
config: Optional configuration
|
||||
|
||||
Returns:
|
||||
MemoryConsolidationService instance
|
||||
"""
|
||||
return MemoryConsolidationService(session=session, config=config)
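A sketch of the nightly entry point that the MEM_CONSOLIDATION_SCHEDULE_CRON setting is meant to drive. The session factory and scheduler wiring are assumptions; only the service calls are taken from this diff.

```python
import logging
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.services.memory.consolidation import get_consolidation_service

logger = logging.getLogger(__name__)


async def nightly_memory_job(
    session_factory: async_sessionmaker[AsyncSession],  # assumed app-provided factory
    project_id: UUID,
) -> None:
    async with session_factory() as session:
        service = await get_consolidation_service(session)
        result = await service.run_nightly_consolidation(project_id=project_id)
        # Errors are collected rather than raised, so surface them here.
        for err in result.errors:
            logger.warning("Nightly consolidation issue: %s", err)
        logger.info(
            "Nightly consolidation: %d facts, %d procedures, %d pruned",
            result.total_facts_created,
            result.total_procedures_created,
            result.total_pruned,
        )
```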
|
||||
17
backend/app/services/memory/episodic/__init__.py
Normal file
@@ -0,0 +1,17 @@
# app/services/memory/episodic/__init__.py
"""
Episodic Memory Package.

Provides experiential memory storage and retrieval for agent learning.
"""

from .memory import EpisodicMemory
from .recorder import EpisodeRecorder
from .retrieval import EpisodeRetriever, RetrievalStrategy

__all__ = [
    "EpisodeRecorder",
    "EpisodeRetriever",
    "EpisodicMemory",
    "RetrievalStrategy",
]
490
backend/app/services/memory/episodic/memory.py
Normal file
@@ -0,0 +1,490 @@
|
||||
# app/services/memory/episodic/memory.py
|
||||
"""
|
||||
Episodic Memory Implementation.
|
||||
|
||||
Provides experiential memory storage and retrieval for agent learning.
|
||||
Combines episode recording and retrieval into a unified interface.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.memory.types import Episode, EpisodeCreate, Outcome, RetrievalResult
|
||||
|
||||
from .recorder import EpisodeRecorder
|
||||
from .retrieval import EpisodeRetriever, RetrievalStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodicMemory:
|
||||
"""
|
||||
Episodic Memory Service.
|
||||
|
||||
Provides experiential memory for agent learning:
|
||||
- Record task completions with context
|
||||
- Store failures with error context
|
||||
- Retrieve by semantic similarity
|
||||
- Retrieve by recency, outcome, task type
|
||||
- Track importance scores
|
||||
- Extract lessons learned
|
||||
|
||||
Performance target: <100ms P95 for retrieval
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize episodic memory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic search
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._recorder = EpisodeRecorder(session, embedding_generator)
|
||||
self._retriever = EpisodeRetriever(session, embedding_generator)
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> "EpisodicMemory":
|
||||
"""
|
||||
Factory method to create EpisodicMemory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
|
||||
Returns:
|
||||
Configured EpisodicMemory instance
|
||||
"""
|
||||
return cls(session=session, embedding_generator=embedding_generator)
|
||||
|
||||
# =========================================================================
|
||||
# Recording Operations
|
||||
# =========================================================================
|
||||
|
||||
async def record_episode(self, episode: EpisodeCreate) -> Episode:
|
||||
"""
|
||||
Record a new episode.
|
||||
|
||||
Args:
|
||||
episode: Episode data to record
|
||||
|
||||
Returns:
|
||||
The created episode with assigned ID
|
||||
"""
|
||||
return await self._recorder.record(episode)
|
||||
|
||||
async def record_success(
|
||||
self,
|
||||
project_id: UUID,
|
||||
session_id: str,
|
||||
task_type: str,
|
||||
task_description: str,
|
||||
actions: list[dict[str, Any]],
|
||||
context_summary: str,
|
||||
outcome_details: str = "",
|
||||
duration_seconds: float = 0.0,
|
||||
tokens_used: int = 0,
|
||||
lessons_learned: list[str] | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> Episode:
|
||||
"""
|
||||
Convenience method to record a successful episode.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
session_id: Session ID
|
||||
task_type: Type of task
|
||||
task_description: Task description
|
||||
actions: Actions taken
|
||||
context_summary: Context summary
|
||||
outcome_details: Optional outcome details
|
||||
duration_seconds: Task duration
|
||||
tokens_used: Tokens consumed
|
||||
lessons_learned: Optional lessons
|
||||
agent_instance_id: Optional agent instance
|
||||
agent_type_id: Optional agent type
|
||||
|
||||
Returns:
|
||||
The created episode
|
||||
"""
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
task_description=task_description,
|
||||
actions=actions,
|
||||
context_summary=context_summary,
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details=outcome_details,
|
||||
duration_seconds=duration_seconds,
|
||||
tokens_used=tokens_used,
|
||||
lessons_learned=lessons_learned or [],
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
return await self.record_episode(episode_data)
|
||||
|
||||
async def record_failure(
|
||||
self,
|
||||
project_id: UUID,
|
||||
session_id: str,
|
||||
task_type: str,
|
||||
task_description: str,
|
||||
actions: list[dict[str, Any]],
|
||||
context_summary: str,
|
||||
error_details: str,
|
||||
duration_seconds: float = 0.0,
|
||||
tokens_used: int = 0,
|
||||
lessons_learned: list[str] | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> Episode:
|
||||
"""
|
||||
Convenience method to record a failed episode.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
session_id: Session ID
|
||||
task_type: Type of task
|
||||
task_description: Task description
|
||||
actions: Actions taken before failure
|
||||
context_summary: Context summary
|
||||
error_details: Error details
|
||||
duration_seconds: Task duration
|
||||
tokens_used: Tokens consumed
|
||||
lessons_learned: Optional lessons from failure
|
||||
agent_instance_id: Optional agent instance
|
||||
agent_type_id: Optional agent type
|
||||
|
||||
Returns:
|
||||
The created episode
|
||||
"""
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
task_description=task_description,
|
||||
actions=actions,
|
||||
context_summary=context_summary,
|
||||
outcome=Outcome.FAILURE,
|
||||
outcome_details=error_details,
|
||||
duration_seconds=duration_seconds,
|
||||
tokens_used=tokens_used,
|
||||
lessons_learned=lessons_learned or [],
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
return await self.record_episode(episode_data)
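A sketch of recording outcomes from an agent run with the convenience methods above; the memory instance, project ID, and session IDs are assumed to come from the calling code.

```python
from uuid import UUID

from app.services.memory.episodic import EpisodicMemory


async def record_review_outcomes(memory: EpisodicMemory, project_id: UUID) -> None:
    await memory.record_success(
        project_id=project_id,
        session_id="sess-001",  # illustrative session ID
        task_type="code_review",
        task_description="Review a small pull request",
        actions=[{"step": 1, "type": "tool_call", "content": "ran linter"}],
        context_summary="Two files changed, no test updates needed",
        duration_seconds=42.0,
    )

    await memory.record_failure(
        project_id=project_id,
        session_id="sess-002",
        task_type="code_review",
        task_description="Review a pull request that fails to build",
        actions=[{"step": 1, "type": "tool_call", "content": "build failed"}],
        context_summary="CI build error",
        error_details="Missing dependency in the lockfile",
        lessons_learned=["Verify the lockfile before reviewing code changes"],
    )
```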
|
||||
|
||||
# =========================================================================
|
||||
# Retrieval Operations
|
||||
# =========================================================================
|
||||
|
||||
async def search_similar(
|
||||
self,
|
||||
project_id: UUID,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Search for semantically similar episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
query: Search query
|
||||
limit: Maximum results
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of similar episodes
|
||||
"""
|
||||
result = await self._retriever.search_similar(
|
||||
project_id, query, limit, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def get_recent(
|
||||
self,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
since: datetime | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get recent episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
limit: Maximum results
|
||||
since: Optional time filter
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of recent episodes
|
||||
"""
|
||||
result = await self._retriever.get_recent(
|
||||
project_id, limit, since, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def get_by_outcome(
|
||||
self,
|
||||
project_id: UUID,
|
||||
outcome: Outcome,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get episodes by outcome.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
outcome: Outcome to filter by
|
||||
limit: Maximum results
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of episodes with specified outcome
|
||||
"""
|
||||
result = await self._retriever.get_by_outcome(
|
||||
project_id, outcome, limit, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def get_by_task_type(
|
||||
self,
|
||||
project_id: UUID,
|
||||
task_type: str,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get episodes by task type.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
task_type: Task type to filter by
|
||||
limit: Maximum results
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of episodes with specified task type
|
||||
"""
|
||||
result = await self._retriever.get_by_task_type(
|
||||
project_id, task_type, limit, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def get_important(
|
||||
self,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
min_importance: float = 0.7,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get high-importance episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
limit: Maximum results
|
||||
min_importance: Minimum importance score
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of important episodes
|
||||
"""
|
||||
result = await self._retriever.get_important(
|
||||
project_id, limit, min_importance, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
project_id: UUID,
|
||||
strategy: RetrievalStrategy = RetrievalStrategy.RECENCY,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""
|
||||
Retrieve episodes with full result metadata.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
strategy: Retrieval strategy
|
||||
limit: Maximum results
|
||||
**kwargs: Strategy-specific parameters
|
||||
|
||||
Returns:
|
||||
RetrievalResult with episodes and metadata
|
||||
"""
|
||||
return await self._retriever.retrieve(project_id, strategy, limit, **kwargs)
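A sketch combining the retrieval entry points; `memory` and `project_id` come from the caller, and the semantic search path assumes an embedding generator was provided at construction.

```python
from uuid import UUID

from app.services.memory.episodic import EpisodicMemory, RetrievalStrategy
from app.services.memory.types import Outcome


async def inspect_history(memory: EpisodicMemory, project_id: UUID) -> None:
    # Semantic search (only meaningful when an embedding generator is configured).
    similar = await memory.search_similar(project_id, "failed database migration", limit=5)

    # Recent failures are a natural source of lessons learned.
    failures = await memory.get_by_outcome(project_id, Outcome.FAILURE, limit=5)

    # Strategy-based retrieval returns full metadata, not just the items.
    recent = await memory.retrieve(project_id, strategy=RetrievalStrategy.RECENCY, limit=5)

    print(len(similar), len(failures), len(recent.items))
```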
|
||||
|
||||
# =========================================================================
|
||||
# Modification Operations
|
||||
# =========================================================================
|
||||
|
||||
async def get_by_id(self, episode_id: UUID) -> Episode | None:
|
||||
"""Get an episode by ID."""
|
||||
return await self._recorder.get_by_id(episode_id)
|
||||
|
||||
async def update_importance(
|
||||
self,
|
||||
episode_id: UUID,
|
||||
importance_score: float,
|
||||
) -> Episode | None:
|
||||
"""
|
||||
Update an episode's importance score.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to update
|
||||
importance_score: New importance score (0.0 to 1.0)
|
||||
|
||||
Returns:
|
||||
Updated episode or None if not found
|
||||
"""
|
||||
return await self._recorder.update_importance(episode_id, importance_score)
|
||||
|
||||
async def add_lessons(
|
||||
self,
|
||||
episode_id: UUID,
|
||||
lessons: list[str],
|
||||
) -> Episode | None:
|
||||
"""
|
||||
Add lessons learned to an episode.
|
||||
|
||||
Args:
|
||||
            episode_id: Episode to update
            lessons: Lessons to add

        Returns:
            Updated episode or None if not found
        """
        return await self._recorder.add_lessons(episode_id, lessons)

    async def delete(self, episode_id: UUID) -> bool:
        """
        Delete an episode.

        Args:
            episode_id: Episode to delete

        Returns:
            True if deleted
        """
        return await self._recorder.delete(episode_id)

    # =========================================================================
    # Summarization
    # =========================================================================

    async def summarize_episodes(
        self,
        episode_ids: list[UUID],
    ) -> str:
        """
        Summarize multiple episodes into a consolidated view.

        Args:
            episode_ids: Episodes to summarize

        Returns:
            Summary text
        """
        if not episode_ids:
            return "No episodes to summarize."

        episodes: list[Episode] = []
        for episode_id in episode_ids:
            episode = await self.get_by_id(episode_id)
            if episode:
                episodes.append(episode)

        if not episodes:
            return "No episodes found."

        # Build summary
        lines = [f"Summary of {len(episodes)} episodes:", ""]

        # Outcome breakdown
        success = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
        failure = sum(1 for e in episodes if e.outcome == Outcome.FAILURE)
        partial = sum(1 for e in episodes if e.outcome == Outcome.PARTIAL)
        lines.append(
            f"Outcomes: {success} success, {failure} failure, {partial} partial"
        )

        # Task types
        task_types = {e.task_type for e in episodes}
        lines.append(f"Task types: {', '.join(sorted(task_types))}")

        # Aggregate lessons
        all_lessons: list[str] = []
        for e in episodes:
            all_lessons.extend(e.lessons_learned)

        if all_lessons:
            lines.append("")
            lines.append("Key lessons learned:")
            # Deduplicate lessons while preserving order
            unique_lessons = list(dict.fromkeys(all_lessons))
            for lesson in unique_lessons[:10]:  # Top 10
                lines.append(f" - {lesson}")

        # Duration and tokens
        total_duration = sum(e.duration_seconds for e in episodes)
        total_tokens = sum(e.tokens_used for e in episodes)
        lines.append("")
        lines.append(f"Total duration: {total_duration:.1f}s")
        lines.append(f"Total tokens: {total_tokens:,}")

        return "\n".join(lines)

    # =========================================================================
    # Statistics
    # =========================================================================

    async def get_stats(self, project_id: UUID) -> dict[str, Any]:
        """
        Get episode statistics for a project.

        Args:
            project_id: Project to get stats for

        Returns:
            Dictionary with episode statistics
        """
        return await self._recorder.get_stats(project_id)

    async def count(
        self,
        project_id: UUID,
        since: datetime | None = None,
    ) -> int:
        """
        Count episodes for a project.

        Args:
            project_id: Project to count for
            since: Optional time filter

        Returns:
            Number of episodes
        """
        return await self._recorder.count_by_project(project_id, since)
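A minimal usage sketch for the service methods above (editorial note, not part of the diff). It assumes an already-constructed instance of the surrounding episodic service, here called `service`; the report helper and variable names are illustrative only.

    # Hypothetical usage of the episodic service methods shown above.
    from uuid import UUID

    async def weekly_report(service, project_id: UUID, episode_ids: list[UUID]) -> str:
        # Consolidate a handful of episodes into a readable summary.
        summary = await service.summarize_episodes(episode_ids)
        # Pull aggregate numbers for the whole project.
        stats = await service.get_stats(project_id)
        total = await service.count(project_id)
        return f"{summary}\n\nProject totals: {total} episodes, {stats['total_tokens']} tokens"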
backend/app/services/memory/episodic/recorder.py (new file, 357 lines)
@@ -0,0 +1,357 @@
|
||||
# app/services/memory/episodic/recorder.py
|
||||
"""
|
||||
Episode Recording.
|
||||
|
||||
Handles the creation and storage of episodic memories
|
||||
during agent task execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.enums import EpisodeOutcome
|
||||
from app.models.memory.episode import Episode as EpisodeModel
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.types import Episode, EpisodeCreate, Outcome
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _outcome_to_db(outcome: Outcome) -> EpisodeOutcome:
|
||||
"""Convert service Outcome to database EpisodeOutcome."""
|
||||
return EpisodeOutcome(outcome.value)
|
||||
|
||||
|
||||
def _db_to_outcome(db_outcome: EpisodeOutcome) -> Outcome:
|
||||
"""Convert database EpisodeOutcome to service Outcome."""
|
||||
return Outcome(db_outcome.value)
|
||||
|
||||
|
||||
def _model_to_episode(model: EpisodeModel) -> Episode:
|
||||
"""Convert SQLAlchemy model to Episode dataclass."""
|
||||
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
|
||||
# they return actual values. We use type: ignore to handle this mismatch.
|
||||
return Episode(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
agent_instance_id=model.agent_instance_id, # type: ignore[arg-type]
|
||||
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
|
||||
session_id=model.session_id, # type: ignore[arg-type]
|
||||
task_type=model.task_type, # type: ignore[arg-type]
|
||||
task_description=model.task_description, # type: ignore[arg-type]
|
||||
actions=model.actions or [], # type: ignore[arg-type]
|
||||
context_summary=model.context_summary, # type: ignore[arg-type]
|
||||
outcome=_db_to_outcome(model.outcome), # type: ignore[arg-type]
|
||||
outcome_details=model.outcome_details or "", # type: ignore[arg-type]
|
||||
duration_seconds=model.duration_seconds, # type: ignore[arg-type]
|
||||
tokens_used=model.tokens_used, # type: ignore[arg-type]
|
||||
lessons_learned=model.lessons_learned or [], # type: ignore[arg-type]
|
||||
importance_score=model.importance_score, # type: ignore[arg-type]
|
||||
embedding=None, # Don't expose raw embedding
|
||||
occurred_at=model.occurred_at, # type: ignore[arg-type]
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class EpisodeRecorder:
|
||||
"""
|
||||
Records episodes to the database.
|
||||
|
||||
Handles episode creation, importance scoring,
|
||||
and lesson extraction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize recorder.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic indexing
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._settings = get_memory_settings()
|
||||
|
||||
async def record(self, episode_data: EpisodeCreate) -> Episode:
|
||||
"""
|
||||
Record a new episode.
|
||||
|
||||
Args:
|
||||
episode_data: Episode data to record
|
||||
|
||||
Returns:
|
||||
The created episode
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Calculate importance score if not provided
|
||||
importance = episode_data.importance_score
|
||||
if importance == 0.5: # Default value, calculate
|
||||
importance = self._calculate_importance(episode_data)
|
||||
|
||||
# Create the model
|
||||
model = EpisodeModel(
|
||||
id=uuid4(),
|
||||
project_id=episode_data.project_id,
|
||||
agent_instance_id=episode_data.agent_instance_id,
|
||||
agent_type_id=episode_data.agent_type_id,
|
||||
session_id=episode_data.session_id,
|
||||
task_type=episode_data.task_type,
|
||||
task_description=episode_data.task_description,
|
||||
actions=episode_data.actions,
|
||||
context_summary=episode_data.context_summary,
|
||||
outcome=_outcome_to_db(episode_data.outcome),
|
||||
outcome_details=episode_data.outcome_details,
|
||||
duration_seconds=episode_data.duration_seconds,
|
||||
tokens_used=episode_data.tokens_used,
|
||||
lessons_learned=episode_data.lessons_learned,
|
||||
importance_score=importance,
|
||||
occurred_at=now,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
# Generate embedding if generator available
|
||||
if self._embedding_generator is not None:
|
||||
try:
|
||||
text_for_embedding = self._create_embedding_text(episode_data)
|
||||
embedding = await self._embedding_generator.generate(text_for_embedding)
|
||||
model.embedding = embedding
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate embedding: {e}")
|
||||
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
|
||||
logger.debug(f"Recorded episode {model.id} for task {model.task_type}")
|
||||
return _model_to_episode(model)
|
||||
|
||||
def _calculate_importance(self, episode_data: EpisodeCreate) -> float:
|
||||
"""
|
||||
Calculate importance score for an episode.
|
||||
|
||||
Factors:
|
||||
- Outcome: Failures are more important to learn from
|
||||
- Duration: Longer tasks may be more significant
|
||||
- Token usage: Higher usage may indicate complexity
|
||||
- Lessons learned: Episodes with lessons are more valuable
|
||||
"""
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Outcome factor
|
||||
if episode_data.outcome == Outcome.FAILURE:
|
||||
score += 0.2 # Failures are important for learning
|
||||
elif episode_data.outcome == Outcome.PARTIAL:
|
||||
score += 0.1
|
||||
# Success is default, no adjustment
|
||||
|
||||
# Lessons learned factor
|
||||
if episode_data.lessons_learned:
|
||||
score += min(0.15, len(episode_data.lessons_learned) * 0.05)
|
||||
|
||||
# Duration factor (longer tasks may be more significant)
|
||||
if episode_data.duration_seconds > 60:
|
||||
score += 0.05
|
||||
if episode_data.duration_seconds > 300:
|
||||
score += 0.05
|
||||
|
||||
# Token usage factor (complex tasks)
|
||||
if episode_data.tokens_used > 1000:
|
||||
score += 0.05
|
||||
|
||||
# Clamp to valid range
|
||||
return min(1.0, max(0.0, score))
|
||||
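To make the scoring above concrete, here is the arithmetic `_calculate_importance` performs for one hypothetical episode: a failed task that ran 400 seconds, used 1,500 tokens, and produced two lessons.

    # Worked example (hypothetical inputs) of the factors above:
    # base                    0.5
    # failure outcome        +0.2
    # lessons: min(0.15, 2*0.05) = +0.10
    # duration > 60 s        +0.05
    # duration > 300 s       +0.05
    # tokens > 1000          +0.05
    # total = 0.95, which already sits inside the [0.0, 1.0] clamp.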
|
||||
def _create_embedding_text(self, episode_data: EpisodeCreate) -> str:
|
||||
"""Create text representation for embedding generation."""
|
||||
parts = [
|
||||
f"Task: {episode_data.task_type}",
|
||||
f"Description: {episode_data.task_description}",
|
||||
f"Context: {episode_data.context_summary}",
|
||||
f"Outcome: {episode_data.outcome.value}",
|
||||
]
|
||||
|
||||
if episode_data.outcome_details:
|
||||
parts.append(f"Details: {episode_data.outcome_details}")
|
||||
|
||||
if episode_data.lessons_learned:
|
||||
parts.append(f"Lessons: {', '.join(episode_data.lessons_learned)}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
async def get_by_id(self, episode_id: UUID) -> Episode | None:
|
||||
"""Get an episode by ID."""
|
||||
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
return _model_to_episode(model)
|
||||
|
||||
async def update_importance(
|
||||
self,
|
||||
episode_id: UUID,
|
||||
importance_score: float,
|
||||
) -> Episode | None:
|
||||
"""
|
||||
Update the importance score of an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to update
|
||||
importance_score: New importance score (0.0 to 1.0)
|
||||
|
||||
Returns:
|
||||
Updated episode or None if not found
|
||||
"""
|
||||
# Validate score
|
||||
importance_score = min(1.0, max(0.0, importance_score))
|
||||
|
||||
stmt = (
|
||||
update(EpisodeModel)
|
||||
.where(EpisodeModel.id == episode_id)
|
||||
.values(
|
||||
importance_score=importance_score,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
.returning(EpisodeModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
await self._session.flush()
|
||||
return _model_to_episode(model)
|
||||
|
||||
async def add_lessons(
|
||||
self,
|
||||
episode_id: UUID,
|
||||
lessons: list[str],
|
||||
) -> Episode | None:
|
||||
"""
|
||||
Add lessons learned to an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to update
|
||||
lessons: New lessons to add
|
||||
|
||||
Returns:
|
||||
Updated episode or None if not found
|
||||
"""
|
||||
# Get current episode
|
||||
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
# Append lessons
|
||||
current_lessons: list[str] = model.lessons_learned or [] # type: ignore[assignment]
|
||||
updated_lessons = current_lessons + lessons
|
||||
|
||||
stmt = (
|
||||
update(EpisodeModel)
|
||||
.where(EpisodeModel.id == episode_id)
|
||||
.values(
|
||||
lessons_learned=updated_lessons,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
.returning(EpisodeModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
model = result.scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
|
||||
return _model_to_episode(model) if model else None
|
||||
|
||||
async def delete(self, episode_id: UUID) -> bool:
|
||||
"""
|
||||
Delete an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to delete
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return False
|
||||
|
||||
await self._session.delete(model)
|
||||
await self._session.flush()
|
||||
return True
|
||||
|
||||
async def count_by_project(
|
||||
self,
|
||||
project_id: UUID,
|
||||
since: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count episodes for a project."""
|
||||
query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
if since is not None:
|
||||
query = query.where(EpisodeModel.occurred_at >= since)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return len(list(result.scalars().all()))
|
||||
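The implementation above counts by materializing every matching row in Python. As a sketch of a possible alternative (a suggestion, not part of this change set), the aggregation can be pushed into the database with SQLAlchemy's `func.count`, which avoids loading the episodes at all:

    from sqlalchemy import func, select

    async def count_by_project_sql(session, project_id, since=None) -> int:
        # Let the database count matching rows instead of returning them.
        stmt = select(func.count()).select_from(EpisodeModel).where(
            EpisodeModel.project_id == project_id
        )
        if since is not None:
            stmt = stmt.where(EpisodeModel.occurred_at >= since)
        return (await session.execute(stmt)).scalar_one()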
|
||||
async def get_stats(self, project_id: UUID) -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics for a project's episodes.
|
||||
|
||||
Returns:
|
||||
Dictionary with episode statistics
|
||||
"""
|
||||
query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
result = await self._session.execute(query)
|
||||
episodes = list(result.scalars().all())
|
||||
|
||||
if not episodes:
|
||||
return {
|
||||
"total_count": 0,
|
||||
"success_count": 0,
|
||||
"failure_count": 0,
|
||||
"partial_count": 0,
|
||||
"avg_importance": 0.0,
|
||||
"avg_duration": 0.0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
|
||||
success_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.SUCCESS)
|
||||
failure_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.FAILURE)
|
||||
partial_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.PARTIAL)
|
||||
|
||||
avg_importance = sum(e.importance_score for e in episodes) / len(episodes)
|
||||
avg_duration = sum(e.duration_seconds for e in episodes) / len(episodes)
|
||||
total_tokens = sum(e.tokens_used for e in episodes)
|
||||
|
||||
return {
|
||||
"total_count": len(episodes),
|
||||
"success_count": success_count,
|
||||
"failure_count": failure_count,
|
||||
"partial_count": partial_count,
|
||||
"avg_importance": avg_importance,
|
||||
"avg_duration": avg_duration,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
backend/app/services/memory/episodic/retrieval.py (new file, 503 lines)
@@ -0,0 +1,503 @@
|
||||
# app/services/memory/episodic/retrieval.py
|
||||
"""
|
||||
Episode Retrieval Strategies.
|
||||
|
||||
Provides different retrieval strategies for finding relevant episodes:
|
||||
- Semantic similarity (vector search)
|
||||
- Recency-based
|
||||
- Outcome-based filtering
|
||||
- Importance-based ranking
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.enums import EpisodeOutcome
|
||||
from app.models.memory.episode import Episode as EpisodeModel
|
||||
from app.services.memory.types import Episode, Outcome, RetrievalResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetrievalStrategy(str, Enum):
|
||||
"""Retrieval strategy types."""
|
||||
|
||||
SEMANTIC = "semantic"
|
||||
RECENCY = "recency"
|
||||
OUTCOME = "outcome"
|
||||
IMPORTANCE = "importance"
|
||||
HYBRID = "hybrid"
|
||||
|
||||
|
||||
def _model_to_episode(model: EpisodeModel) -> Episode:
|
||||
"""Convert SQLAlchemy model to Episode dataclass."""
|
||||
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
|
||||
# they return actual values. We use type: ignore to handle this mismatch.
|
||||
return Episode(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
agent_instance_id=model.agent_instance_id, # type: ignore[arg-type]
|
||||
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
|
||||
session_id=model.session_id, # type: ignore[arg-type]
|
||||
task_type=model.task_type, # type: ignore[arg-type]
|
||||
task_description=model.task_description, # type: ignore[arg-type]
|
||||
actions=model.actions or [], # type: ignore[arg-type]
|
||||
context_summary=model.context_summary, # type: ignore[arg-type]
|
||||
outcome=Outcome(model.outcome.value),
|
||||
outcome_details=model.outcome_details or "", # type: ignore[arg-type]
|
||||
duration_seconds=model.duration_seconds, # type: ignore[arg-type]
|
||||
tokens_used=model.tokens_used, # type: ignore[arg-type]
|
||||
lessons_learned=model.lessons_learned or [], # type: ignore[arg-type]
|
||||
importance_score=model.importance_score, # type: ignore[arg-type]
|
||||
embedding=None, # Don't expose raw embedding
|
||||
occurred_at=model.occurred_at, # type: ignore[arg-type]
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
"""Abstract base class for episode retrieval strategies."""
|
||||
|
||||
@abstractmethod
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes based on the strategy."""
|
||||
...
|
||||
|
||||
|
||||
class RecencyRetriever(BaseRetriever):
|
||||
"""Retrieves episodes by recency (most recent first)."""
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
since: datetime | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve most recent episodes."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
query = (
|
||||
select(EpisodeModel)
|
||||
.where(EpisodeModel.project_id == project_id)
|
||||
.order_by(desc(EpisodeModel.occurred_at))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if since is not None:
|
||||
query = query.where(EpisodeModel.occurred_at >= since)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Get total count
|
||||
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
if since is not None:
|
||||
count_query = count_query.where(EpisodeModel.occurred_at >= since)
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = len(list(count_result.scalars().all()))
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_episode(m) for m in models],
|
||||
total_count=total_count,
|
||||
query="recency",
|
||||
retrieval_type=RetrievalStrategy.RECENCY.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"since": since.isoformat() if since else None},
|
||||
)
|
||||
|
||||
|
||||
class OutcomeRetriever(BaseRetriever):
|
||||
"""Retrieves episodes filtered by outcome."""
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
outcome: Outcome | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes by outcome."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
query = (
|
||||
select(EpisodeModel)
|
||||
.where(EpisodeModel.project_id == project_id)
|
||||
.order_by(desc(EpisodeModel.occurred_at))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if outcome is not None:
|
||||
db_outcome = EpisodeOutcome(outcome.value)
|
||||
query = query.where(EpisodeModel.outcome == db_outcome)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Get total count
|
||||
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
if outcome is not None:
|
||||
count_query = count_query.where(
|
||||
EpisodeModel.outcome == EpisodeOutcome(outcome.value)
|
||||
)
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = len(list(count_result.scalars().all()))
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_episode(m) for m in models],
|
||||
total_count=total_count,
|
||||
query=f"outcome:{outcome.value if outcome else 'all'}",
|
||||
retrieval_type=RetrievalStrategy.OUTCOME.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"outcome": outcome.value if outcome else None},
|
||||
)
|
||||
|
||||
|
||||
class TaskTypeRetriever(BaseRetriever):
|
||||
"""Retrieves episodes filtered by task type."""
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
task_type: str | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes by task type."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
query = (
|
||||
select(EpisodeModel)
|
||||
.where(EpisodeModel.project_id == project_id)
|
||||
.order_by(desc(EpisodeModel.occurred_at))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if task_type is not None:
|
||||
query = query.where(EpisodeModel.task_type == task_type)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Get total count
|
||||
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
if task_type is not None:
|
||||
count_query = count_query.where(EpisodeModel.task_type == task_type)
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = len(list(count_result.scalars().all()))
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_episode(m) for m in models],
|
||||
total_count=total_count,
|
||||
query=f"task_type:{task_type or 'all'}",
|
||||
retrieval_type="task_type",
|
||||
latency_ms=latency_ms,
|
||||
metadata={"task_type": task_type},
|
||||
)
|
||||
|
||||
|
||||
class ImportanceRetriever(BaseRetriever):
|
||||
"""Retrieves episodes ranked by importance score."""
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
min_importance: float = 0.0,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes by importance."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
query = (
|
||||
select(EpisodeModel)
|
||||
.where(
|
||||
and_(
|
||||
EpisodeModel.project_id == project_id,
|
||||
EpisodeModel.importance_score >= min_importance,
|
||||
)
|
||||
)
|
||||
.order_by(desc(EpisodeModel.importance_score))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Get total count
|
||||
count_query = select(EpisodeModel).where(
|
||||
and_(
|
||||
EpisodeModel.project_id == project_id,
|
||||
EpisodeModel.importance_score >= min_importance,
|
||||
)
|
||||
)
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = len(list(count_result.scalars().all()))
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_episode(m) for m in models],
|
||||
total_count=total_count,
|
||||
query=f"importance>={min_importance}",
|
||||
retrieval_type=RetrievalStrategy.IMPORTANCE.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"min_importance": min_importance},
|
||||
)
|
||||
|
||||
|
||||
class SemanticRetriever(BaseRetriever):
|
||||
"""Retrieves episodes by semantic similarity using vector search."""
|
||||
|
||||
def __init__(self, embedding_generator: Any | None = None) -> None:
|
||||
"""Initialize with optional embedding generator."""
|
||||
self._embedding_generator = embedding_generator
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
query_text: str | None = None,
|
||||
query_embedding: list[float] | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes by semantic similarity."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# If no embedding provided, fall back to recency
|
||||
if query_embedding is None and query_text is None:
|
||||
logger.warning(
|
||||
"No query provided for semantic search, falling back to recency"
|
||||
)
|
||||
recency = RecencyRetriever()
|
||||
fallback_result = await recency.retrieve(
|
||||
session, project_id, limit, agent_instance_id=agent_instance_id
|
||||
)
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
return RetrievalResult(
|
||||
items=fallback_result.items,
|
||||
total_count=fallback_result.total_count,
|
||||
query="no_query",
|
||||
retrieval_type=RetrievalStrategy.SEMANTIC.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"fallback": "recency", "reason": "no_query"},
|
||||
)
|
||||
|
||||
# Generate embedding if needed
|
||||
embedding = query_embedding
|
||||
if embedding is None and query_text is not None:
|
||||
if self._embedding_generator is not None:
|
||||
embedding = await self._embedding_generator.generate(query_text)
|
||||
else:
|
||||
logger.warning("No embedding generator, falling back to recency")
|
||||
recency = RecencyRetriever()
|
||||
fallback_result = await recency.retrieve(
|
||||
session, project_id, limit, agent_instance_id=agent_instance_id
|
||||
)
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
return RetrievalResult(
|
||||
items=fallback_result.items,
|
||||
total_count=fallback_result.total_count,
|
||||
query=query_text,
|
||||
retrieval_type=RetrievalStrategy.SEMANTIC.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={
|
||||
"fallback": "recency",
|
||||
"reason": "no_embedding_generator",
|
||||
},
|
||||
)
|
||||
|
||||
# For now, use recency if vector search not available
|
||||
# TODO: Implement proper pgvector similarity search when integrated
|
||||
logger.debug("Vector search not yet implemented, using recency fallback")
|
||||
recency = RecencyRetriever()
|
||||
result = await recency.retrieve(
|
||||
session, project_id, limit, agent_instance_id=agent_instance_id
|
||||
)
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=result.items,
|
||||
total_count=result.total_count,
|
||||
query=query_text or "embedding",
|
||||
retrieval_type=RetrievalStrategy.SEMANTIC.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"fallback": "recency"},
|
||||
)
|
||||
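The TODO above leaves the actual vector search as a recency fallback. As a hedged sketch only: if `EpisodeModel.embedding` is a pgvector `Vector` column, the pgvector SQLAlchemy integration exposes distance operators that could replace the fallback roughly as below. The column name and ordering choice are assumptions, not part of this PR.

    from sqlalchemy import select

    # Hypothetical pgvector ordering; assumes `embedding` is the query vector
    # computed earlier in retrieve() and that EpisodeModel.embedding is a Vector.
    stmt = (
        select(EpisodeModel)
        .where(EpisodeModel.project_id == project_id)
        .order_by(EpisodeModel.embedding.cosine_distance(embedding))
        .limit(limit)
    )
    models = list((await session.execute(stmt)).scalars().all())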
|
||||
|
||||
class EpisodeRetriever:
|
||||
"""
|
||||
Unified episode retrieval service.
|
||||
|
||||
Provides a single interface for all retrieval strategies.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""Initialize retriever with database session."""
|
||||
self._session = session
|
||||
self._retrievers: dict[RetrievalStrategy, BaseRetriever] = {
|
||||
RetrievalStrategy.RECENCY: RecencyRetriever(),
|
||||
RetrievalStrategy.OUTCOME: OutcomeRetriever(),
|
||||
RetrievalStrategy.IMPORTANCE: ImportanceRetriever(),
|
||||
RetrievalStrategy.SEMANTIC: SemanticRetriever(embedding_generator),
|
||||
}
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
project_id: UUID,
|
||||
strategy: RetrievalStrategy = RetrievalStrategy.RECENCY,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""
|
||||
Retrieve episodes using the specified strategy.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
strategy: Retrieval strategy to use
|
||||
limit: Maximum number of episodes to return
|
||||
**kwargs: Strategy-specific parameters
|
||||
|
||||
Returns:
|
||||
RetrievalResult containing matching episodes
|
||||
"""
|
||||
retriever = self._retrievers.get(strategy)
|
||||
if retriever is None:
|
||||
raise ValueError(f"Unknown retrieval strategy: {strategy}")
|
||||
|
||||
return await retriever.retrieve(self._session, project_id, limit, **kwargs)
|
||||
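A short usage sketch of the dispatcher above (illustrative, not part of the diff): the same `retrieve` call serves every registered strategy, with keyword arguments passed through to the concrete retriever.

    # Hypothetical calls against an EpisodeRetriever instance named `retriever`.
    recent = await retriever.retrieve(project_id, RetrievalStrategy.RECENCY, limit=5)
    failures = await retriever.retrieve(
        project_id, RetrievalStrategy.OUTCOME, limit=5, outcome=Outcome.FAILURE
    )
    important = await retriever.retrieve(
        project_id, RetrievalStrategy.IMPORTANCE, limit=5, min_importance=0.8
    )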
|
||||
async def get_recent(
|
||||
self,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
since: datetime | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Get recent episodes."""
|
||||
return await self.retrieve(
|
||||
project_id,
|
||||
RetrievalStrategy.RECENCY,
|
||||
limit,
|
||||
since=since,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
async def get_by_outcome(
|
||||
self,
|
||||
project_id: UUID,
|
||||
outcome: Outcome,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Get episodes by outcome."""
|
||||
return await self.retrieve(
|
||||
project_id,
|
||||
RetrievalStrategy.OUTCOME,
|
||||
limit,
|
||||
outcome=outcome,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
async def get_by_task_type(
|
||||
self,
|
||||
project_id: UUID,
|
||||
task_type: str,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Get episodes by task type."""
|
||||
retriever = TaskTypeRetriever()
|
||||
return await retriever.retrieve(
|
||||
self._session,
|
||||
project_id,
|
||||
limit,
|
||||
task_type=task_type,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
async def get_important(
|
||||
self,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
min_importance: float = 0.7,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Get high-importance episodes."""
|
||||
return await self.retrieve(
|
||||
project_id,
|
||||
RetrievalStrategy.IMPORTANCE,
|
||||
limit,
|
||||
min_importance=min_importance,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
async def search_similar(
|
||||
self,
|
||||
project_id: UUID,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Search for semantically similar episodes."""
|
||||
return await self.retrieve(
|
||||
project_id,
|
||||
RetrievalStrategy.SEMANTIC,
|
||||
limit,
|
||||
query_text=query,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
backend/app/services/memory/exceptions.py (new file, 222 lines)
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
Memory System Exceptions
|
||||
|
||||
Custom exception classes for the Agent Memory System.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class MemoryError(Exception):
|
||||
"""Base exception for all memory-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
memory_type: str | None = None,
|
||||
scope_type: str | None = None,
|
||||
scope_id: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.memory_type = memory_type
|
||||
self.scope_type = scope_type
|
||||
self.scope_id = scope_id
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class MemoryNotFoundError(MemoryError):
|
||||
"""Raised when a memory item is not found."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory not found",
|
||||
*,
|
||||
memory_id: UUID | str | None = None,
|
||||
key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.memory_id = memory_id
|
||||
self.key = key
|
||||
|
||||
|
||||
class MemoryCapacityError(MemoryError):
|
||||
"""Raised when memory capacity limits are exceeded."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory capacity exceeded",
|
||||
*,
|
||||
current_size: int = 0,
|
||||
max_size: int = 0,
|
||||
item_count: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.current_size = current_size
|
||||
self.max_size = max_size
|
||||
self.item_count = item_count
|
||||
|
||||
|
||||
class MemoryExpiredError(MemoryError):
|
||||
"""Raised when attempting to access expired memory."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory has expired",
|
||||
*,
|
||||
key: str | None = None,
|
||||
expired_at: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.key = key
|
||||
self.expired_at = expired_at
|
||||
|
||||
|
||||
class MemoryStorageError(MemoryError):
|
||||
"""Raised when memory storage operations fail."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory storage operation failed",
|
||||
*,
|
||||
operation: str | None = None,
|
||||
backend: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.operation = operation
|
||||
self.backend = backend
|
||||
|
||||
|
||||
class MemoryConnectionError(MemoryError):
|
||||
"""Raised when memory storage connection fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory connection failed",
|
||||
*,
|
||||
backend: str | None = None,
|
||||
host: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.backend = backend
|
||||
self.host = host
|
||||
|
||||
|
||||
class MemorySerializationError(MemoryError):
|
||||
"""Raised when memory serialization/deserialization fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory serialization failed",
|
||||
*,
|
||||
content_type: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.content_type = content_type
|
||||
|
||||
|
||||
class MemoryScopeError(MemoryError):
|
||||
"""Raised when memory scope operations fail."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory scope error",
|
||||
*,
|
||||
requested_scope: str | None = None,
|
||||
allowed_scopes: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.requested_scope = requested_scope
|
||||
self.allowed_scopes = allowed_scopes or []
|
||||
|
||||
|
||||
class MemoryConsolidationError(MemoryError):
|
||||
"""Raised when memory consolidation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory consolidation failed",
|
||||
*,
|
||||
source_type: str | None = None,
|
||||
target_type: str | None = None,
|
||||
items_processed: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.source_type = source_type
|
||||
self.target_type = target_type
|
||||
self.items_processed = items_processed
|
||||
|
||||
|
||||
class MemoryRetrievalError(MemoryError):
|
||||
"""Raised when memory retrieval fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory retrieval failed",
|
||||
*,
|
||||
query: str | None = None,
|
||||
retrieval_type: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.query = query
|
||||
self.retrieval_type = retrieval_type
|
||||
|
||||
|
||||
class EmbeddingError(MemoryError):
|
||||
"""Raised when embedding generation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Embedding generation failed",
|
||||
*,
|
||||
content_length: int = 0,
|
||||
model: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.content_length = content_length
|
||||
self.model = model
|
||||
|
||||
|
||||
class CheckpointError(MemoryError):
|
||||
"""Raised when checkpoint operations fail."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Checkpoint operation failed",
|
||||
*,
|
||||
checkpoint_id: str | None = None,
|
||||
operation: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.checkpoint_id = checkpoint_id
|
||||
self.operation = operation
|
||||
|
||||
|
||||
class MemoryConflictError(MemoryError):
|
||||
"""Raised when there's a conflict in memory (e.g., contradictory facts)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory conflict detected",
|
||||
*,
|
||||
conflicting_ids: list[str | UUID] | None = None,
|
||||
conflict_type: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.conflicting_ids = conflicting_ids or []
|
||||
self.conflict_type = conflict_type
|
||||
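To illustrate how the hierarchy above is meant to be consumed (a sketch, not code from this change): every subclass carries structured context on top of the shared `MemoryError` fields, so callers can catch broadly and still log specifics. The function and key names below are made up for the example.

    def load_fact(key: str):
        raise MemoryNotFoundError(
            f"No fact stored under {key!r}",
            key=key,
            memory_type="semantic",
            scope_type="project",
        )

    try:
        load_fact("deploy-target")
    except MemoryError as exc:  # catches every memory-specific error
        print(exc.message, exc.memory_type, getattr(exc, "key", None))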
backend/app/services/memory/indexing/__init__.py (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
# app/services/memory/indexing/__init__.py
|
||||
"""
|
||||
Memory Indexing & Retrieval.
|
||||
|
||||
Provides vector embeddings and multiple index types for efficient memory search:
|
||||
- Vector index for semantic similarity
|
||||
- Temporal index for time-based queries
|
||||
- Entity index for entity lookups
|
||||
- Outcome index for success/failure filtering
|
||||
"""
|
||||
|
||||
from .index import (
|
||||
EntityIndex,
|
||||
EntityIndexEntry,
|
||||
IndexEntry,
|
||||
MemoryIndex,
|
||||
MemoryIndexer,
|
||||
OutcomeIndex,
|
||||
OutcomeIndexEntry,
|
||||
TemporalIndex,
|
||||
TemporalIndexEntry,
|
||||
VectorIndex,
|
||||
VectorIndexEntry,
|
||||
get_memory_indexer,
|
||||
)
|
||||
from .retrieval import (
|
||||
CacheEntry,
|
||||
RelevanceScorer,
|
||||
RetrievalCache,
|
||||
RetrievalEngine,
|
||||
RetrievalQuery,
|
||||
ScoredResult,
|
||||
get_retrieval_engine,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CacheEntry",
|
||||
"EntityIndex",
|
||||
"EntityIndexEntry",
|
||||
"IndexEntry",
|
||||
"MemoryIndex",
|
||||
"MemoryIndexer",
|
||||
"OutcomeIndex",
|
||||
"OutcomeIndexEntry",
|
||||
"RelevanceScorer",
|
||||
"RetrievalCache",
|
||||
"RetrievalEngine",
|
||||
"RetrievalQuery",
|
||||
"ScoredResult",
|
||||
"TemporalIndex",
|
||||
"TemporalIndexEntry",
|
||||
"VectorIndex",
|
||||
"VectorIndexEntry",
|
||||
"get_memory_indexer",
|
||||
"get_retrieval_engine",
|
||||
]
|
||||
backend/app/services/memory/indexing/index.py (new file, 858 lines)
@@ -0,0 +1,858 @@
|
||||
# app/services/memory/indexing/index.py
|
||||
"""
|
||||
Memory Indexing.
|
||||
|
||||
Provides multiple indexing strategies for efficient memory retrieval:
|
||||
- Vector embeddings for semantic search
|
||||
- Temporal index for time-based queries
|
||||
- Entity index for entity-based lookups
|
||||
- Outcome index for success/failure filtering
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", Episode, Fact, Procedure)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexEntry:
|
||||
"""A single entry in an index."""
|
||||
|
||||
memory_id: UUID
|
||||
memory_type: MemoryType
|
||||
indexed_at: datetime = field(default_factory=_utcnow)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorIndexEntry(IndexEntry):
|
||||
"""An entry with vector embedding."""
|
||||
|
||||
embedding: list[float] = field(default_factory=list)
|
||||
dimension: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set dimension from embedding."""
|
||||
if self.embedding:
|
||||
self.dimension = len(self.embedding)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TemporalIndexEntry(IndexEntry):
|
||||
"""An entry indexed by time."""
|
||||
|
||||
timestamp: datetime = field(default_factory=_utcnow)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityIndexEntry(IndexEntry):
|
||||
"""An entry indexed by entity."""
|
||||
|
||||
entity_type: str = ""
|
||||
entity_value: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutcomeIndexEntry(IndexEntry):
|
||||
"""An entry indexed by outcome."""
|
||||
|
||||
outcome: Outcome = Outcome.SUCCESS
|
||||
|
||||
|
||||
class MemoryIndex[T](ABC):
|
||||
"""Abstract base class for memory indices."""
|
||||
|
||||
@abstractmethod
|
||||
async def add(self, item: T) -> IndexEntry:
|
||||
"""Add an item to the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> list[IndexEntry]:
|
||||
"""Search the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
...
|
||||
|
||||
|
||||
class VectorIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Vector-based index using embeddings for semantic similarity search.
|
||||
|
||||
Uses cosine similarity for matching.
|
||||
"""
|
||||
|
||||
def __init__(self, dimension: int = 1536) -> None:
|
||||
"""
|
||||
Initialize the vector index.
|
||||
|
||||
Args:
|
||||
dimension: Embedding dimension (default 1536 for OpenAI)
|
||||
"""
|
||||
self._dimension = dimension
|
||||
self._entries: dict[UUID, VectorIndexEntry] = {}
|
||||
logger.info(f"Initialized VectorIndex with dimension={dimension}")
|
||||
|
||||
async def add(self, item: T) -> VectorIndexEntry:
|
||||
"""
|
||||
Add an item to the vector index.
|
||||
|
||||
Args:
|
||||
item: Memory item with embedding
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
embedding = getattr(item, "embedding", None) or []
|
||||
|
||||
entry = VectorIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
embedding=embedding,
|
||||
dimension=len(embedding),
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
logger.debug(f"Added {item.id} to vector index")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the vector index."""
|
||||
if memory_id in self._entries:
|
||||
del self._entries[memory_id]
|
||||
logger.debug(f"Removed {memory_id} from vector index")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
min_similarity: float = 0.0,
|
||||
**kwargs: Any,
|
||||
) -> list[VectorIndexEntry]:
|
||||
"""
|
||||
Search for similar items using vector similarity.
|
||||
|
||||
Args:
|
||||
query: Query embedding vector
|
||||
limit: Maximum results to return
|
||||
min_similarity: Minimum similarity threshold (0-1)
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries sorted by similarity
|
||||
"""
|
||||
if not isinstance(query, list) or not query:
|
||||
return []
|
||||
|
||||
results: list[tuple[float, VectorIndexEntry]] = []
|
||||
|
||||
for entry in self._entries.values():
|
||||
if not entry.embedding:
|
||||
continue
|
||||
|
||||
similarity = self._cosine_similarity(query, entry.embedding)
|
||||
if similarity >= min_similarity:
|
||||
results.append((similarity, entry))
|
||||
|
||||
# Sort by similarity descending
|
||||
results.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
if memory_type:
|
||||
results = [(s, e) for s, e in results if e.memory_type == memory_type]
|
||||
|
||||
# Store similarity in metadata for the returned entries
|
||||
# Use a copy of metadata to avoid mutating cached entries
|
||||
output = []
|
||||
for similarity, entry in results[:limit]:
|
||||
# Create a shallow copy of the entry with updated metadata
|
||||
entry_with_score = VectorIndexEntry(
|
||||
memory_id=entry.memory_id,
|
||||
memory_type=entry.memory_type,
|
||||
embedding=entry.embedding,
|
||||
metadata={**entry.metadata, "similarity": similarity},
|
||||
)
|
||||
output.append(entry_with_score)
|
||||
|
||||
logger.debug(f"Vector search returned {len(output)} results")
|
||||
return output
|
||||
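A small usage sketch of the search above (hypothetical names and vectors): results come back as score-annotated copies, so the similarity can be read from `metadata` without mutating the stored entries.

    # Assumes `index` is a populated VectorIndex and `query_vec` has the right dimension.
    matches = await index.search(query_vec, limit=3, min_similarity=0.25)
    for entry in matches:
        print(entry.memory_id, round(entry.metadata["similarity"], 3))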
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
logger.info(f"Cleared {count} entries from vector index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
|
||||
"""Calculate cosine similarity between two vectors."""
|
||||
if len(a) != len(b) or len(a) == 0:
|
||||
return 0.0
|
||||
|
||||
dot_product = sum(x * y for x, y in zip(a, b, strict=True))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm_a * norm_b)
|
||||
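A quick sanity check of the formula above: for a = [1, 0, 1] and b = [1, 1, 0] the dot product is 1 and both norms are the square root of 2, so the similarity is 1 / (sqrt(2) * sqrt(2)) = 0.5.

    # Worked example of the cosine formula above.
    a, b = [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]
    dot = sum(x * y for x, y in zip(a, b, strict=True))   # 1.0
    norm_a = sum(x * x for x in a) ** 0.5                 # sqrt(2)
    norm_b = sum(x * x for x in b) ** 0.5                 # sqrt(2)
    assert abs(dot / (norm_a * norm_b) - 0.5) < 1e-9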
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
class TemporalIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Time-based index for efficient temporal queries.
|
||||
|
||||
Supports:
|
||||
- Range queries (between timestamps)
|
||||
- Recent items (within last N seconds/hours/days)
|
||||
- Oldest/newest sorting
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the temporal index."""
|
||||
self._entries: dict[UUID, TemporalIndexEntry] = {}
|
||||
# Sorted list for efficient range queries
|
||||
self._sorted_entries: list[tuple[datetime, UUID]] = []
|
||||
logger.info("Initialized TemporalIndex")
|
||||
|
||||
async def add(self, item: T) -> TemporalIndexEntry:
|
||||
"""
|
||||
Add an item to the temporal index.
|
||||
|
||||
Args:
|
||||
item: Memory item with timestamp
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
# Get timestamp from various possible fields
|
||||
timestamp = self._get_timestamp(item)
|
||||
|
||||
entry = TemporalIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
self._insert_sorted(timestamp, item.id)
|
||||
|
||||
logger.debug(f"Added {item.id} to temporal index at {timestamp}")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the temporal index."""
|
||||
if memory_id not in self._entries:
|
||||
return False
|
||||
|
||||
self._entries.pop(memory_id)
|
||||
self._sorted_entries = [
|
||||
(ts, mid) for ts, mid in self._sorted_entries if mid != memory_id
|
||||
]
|
||||
|
||||
logger.debug(f"Removed {memory_id} from temporal index")
|
||||
return True
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
recent_seconds: float | None = None,
|
||||
order: str = "desc",
|
||||
**kwargs: Any,
|
||||
) -> list[TemporalIndexEntry]:
|
||||
"""
|
||||
Search for items by time.
|
||||
|
||||
Args:
|
||||
query: Ignored for temporal search
|
||||
limit: Maximum results to return
|
||||
start_time: Start of time range
|
||||
end_time: End of time range
|
||||
recent_seconds: Get items from last N seconds
|
||||
order: Sort order ("asc" or "desc")
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries sorted by time
|
||||
"""
|
||||
if recent_seconds is not None:
|
||||
start_time = _utcnow() - timedelta(seconds=recent_seconds)
|
||||
end_time = _utcnow()
|
||||
|
||||
# Filter by time range
|
||||
results: list[TemporalIndexEntry] = []
|
||||
for entry in self._entries.values():
|
||||
if start_time and entry.timestamp < start_time:
|
||||
continue
|
||||
if end_time and entry.timestamp > end_time:
|
||||
continue
|
||||
results.append(entry)
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
if memory_type:
|
||||
results = [e for e in results if e.memory_type == memory_type]
|
||||
|
||||
# Sort by timestamp
|
||||
results.sort(key=lambda e: e.timestamp, reverse=(order == "desc"))
|
||||
|
||||
logger.debug(f"Temporal search returned {min(len(results), limit)} results")
|
||||
return results[:limit]
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
self._sorted_entries.clear()
|
||||
logger.info(f"Cleared {count} entries from temporal index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
def _insert_sorted(self, timestamp: datetime, memory_id: UUID) -> None:
|
||||
"""Insert entry maintaining sorted order."""
|
||||
# Binary search insert for efficiency
|
||||
low, high = 0, len(self._sorted_entries)
|
||||
while low < high:
|
||||
mid = (low + high) // 2
|
||||
if self._sorted_entries[mid][0] < timestamp:
|
||||
low = mid + 1
|
||||
else:
|
||||
high = mid
|
||||
self._sorted_entries.insert(low, (timestamp, memory_id))
|
||||
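The hand-rolled binary search above matches what the standard library's `bisect` module provides. The equivalent insert is sketched below (Python 3.10+ for the `key=` argument) as a possible simplification, not as part of this change.

    import bisect

    def insert_sorted(entries: list[tuple], timestamp, memory_id) -> None:
        # Same ordering as the manual loop: first slot whose timestamp is >= the new one.
        index = bisect.bisect_left(entries, timestamp, key=lambda pair: pair[0])
        entries.insert(index, (timestamp, memory_id))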
|
||||
def _get_timestamp(self, item: T) -> datetime:
|
||||
"""Get the relevant timestamp for an item."""
|
||||
if hasattr(item, "occurred_at"):
|
||||
return item.occurred_at
|
||||
if hasattr(item, "first_learned"):
|
||||
return item.first_learned
|
||||
if hasattr(item, "last_used") and item.last_used:
|
||||
return item.last_used
|
||||
if hasattr(item, "created_at"):
|
||||
return item.created_at
|
||||
return _utcnow()
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
class EntityIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Entity-based index for lookups by entities mentioned in memories.
|
||||
|
||||
Supports:
|
||||
- Single entity lookup
|
||||
- Multi-entity intersection
|
||||
- Entity type filtering
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the entity index."""
|
||||
# Main storage
|
||||
self._entries: dict[UUID, EntityIndexEntry] = {}
|
||||
# Inverted index: entity -> set of memory IDs
|
||||
self._entity_to_memories: dict[str, set[UUID]] = {}
|
||||
# Memory to entities mapping
|
||||
self._memory_to_entities: dict[UUID, set[str]] = {}
|
||||
logger.info("Initialized EntityIndex")
|
||||
|
||||
async def add(self, item: T) -> EntityIndexEntry:
|
||||
"""
|
||||
Add an item to the entity index.
|
||||
|
||||
Args:
|
||||
item: Memory item with entity information
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
entities = self._extract_entities(item)
|
||||
|
||||
# Create entry for the primary entity (or first one)
|
||||
primary_entity = entities[0] if entities else ("unknown", "unknown")
|
||||
|
||||
entry = EntityIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
entity_type=primary_entity[0],
|
||||
entity_value=primary_entity[1],
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
|
||||
# Update inverted indices
|
||||
entity_keys = {f"{etype}:{evalue}" for etype, evalue in entities}
|
||||
self._memory_to_entities[item.id] = entity_keys
|
||||
|
||||
for entity_key in entity_keys:
|
||||
if entity_key not in self._entity_to_memories:
|
||||
self._entity_to_memories[entity_key] = set()
|
||||
self._entity_to_memories[entity_key].add(item.id)
|
||||
|
||||
logger.debug(f"Added {item.id} to entity index with {len(entities)} entities")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the entity index."""
|
||||
if memory_id not in self._entries:
|
||||
return False
|
||||
|
||||
# Remove from inverted index
|
||||
if memory_id in self._memory_to_entities:
|
||||
for entity_key in self._memory_to_entities[memory_id]:
|
||||
if entity_key in self._entity_to_memories:
|
||||
self._entity_to_memories[entity_key].discard(memory_id)
|
||||
if not self._entity_to_memories[entity_key]:
|
||||
del self._entity_to_memories[entity_key]
|
||||
del self._memory_to_entities[memory_id]
|
||||
|
||||
del self._entries[memory_id]
|
||||
logger.debug(f"Removed {memory_id} from entity index")
|
||||
return True
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
entity_type: str | None = None,
|
||||
entity_value: str | None = None,
|
||||
entities: list[tuple[str, str]] | None = None,
|
||||
match_all: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> list[EntityIndexEntry]:
|
||||
"""
|
||||
Search for items by entity.
|
||||
|
||||
Args:
|
||||
query: Entity value to search (if entity_type not specified)
|
||||
limit: Maximum results to return
|
||||
entity_type: Type of entity to filter
|
||||
entity_value: Specific entity value
|
||||
entities: List of (type, value) tuples to match
|
||||
match_all: If True, require all entities to match
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries
|
||||
"""
|
||||
matching_ids: set[UUID] | None = None
|
||||
|
||||
# Handle single entity query
|
||||
if entity_type and entity_value:
|
||||
entities = [(entity_type, entity_value)]
|
||||
elif entity_value is None and isinstance(query, str):
|
||||
# Search across all entity types
|
||||
entity_value = query
|
||||
|
||||
if entities:
|
||||
for etype, evalue in entities:
|
||||
entity_key = f"{etype}:{evalue}"
|
||||
if entity_key in self._entity_to_memories:
|
||||
ids = self._entity_to_memories[entity_key]
|
||||
if matching_ids is None:
|
||||
matching_ids = ids.copy()
|
||||
elif match_all:
|
||||
matching_ids &= ids
|
||||
else:
|
||||
matching_ids |= ids
|
||||
elif match_all:
|
||||
# Required entity not found
|
||||
matching_ids = set()
|
||||
break
|
||||
elif entity_value:
|
||||
# Search for value across all types
|
||||
matching_ids = set()
|
||||
for entity_key, ids in self._entity_to_memories.items():
|
||||
if entity_value.lower() in entity_key.lower():
|
||||
matching_ids |= ids
|
||||
|
||||
if matching_ids is None:
|
||||
matching_ids = set(self._entries.keys())
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
results = []
|
||||
for mid in matching_ids:
|
||||
if mid in self._entries:
|
||||
entry = self._entries[mid]
|
||||
if memory_type and entry.memory_type != memory_type:
|
||||
continue
|
||||
results.append(entry)
|
||||
|
||||
logger.debug(f"Entity search returned {min(len(results), limit)} results")
|
||||
return results[:limit]
|
||||
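A usage sketch for the inverted index above (identifiers are illustrative): indexing an episode registers several `type:value` keys, and `match_all=True` intersects the posting sets instead of unioning them.

    # Hypothetical query against a populated EntityIndex named `entity_index`.
    hits = await entity_index.search(
        query=None,
        limit=20,
        entities=[("task_type", "code_review"), ("project", str(project_id))],
        match_all=True,  # require memories tagged with BOTH entities
    )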
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
self._entity_to_memories.clear()
|
||||
self._memory_to_entities.clear()
|
||||
logger.info(f"Cleared {count} entries from entity index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
async def get_entities(self, memory_id: UUID) -> list[tuple[str, str]]:
|
||||
"""Get all entities for a memory item."""
|
||||
if memory_id not in self._memory_to_entities:
|
||||
return []
|
||||
|
||||
entities = []
|
||||
for entity_key in self._memory_to_entities[memory_id]:
|
||||
if ":" in entity_key:
|
||||
etype, evalue = entity_key.split(":", 1)
|
||||
entities.append((etype, evalue))
|
||||
return entities
|
||||
|
||||
def _extract_entities(self, item: T) -> list[tuple[str, str]]:
|
||||
"""Extract entities from a memory item."""
|
||||
entities: list[tuple[str, str]] = []
|
||||
|
||||
if isinstance(item, Episode):
|
||||
# Extract from task type and context
|
||||
entities.append(("task_type", item.task_type))
|
||||
if item.project_id:
|
||||
entities.append(("project", str(item.project_id)))
|
||||
if item.agent_instance_id:
|
||||
entities.append(("agent_instance", str(item.agent_instance_id)))
|
||||
if item.agent_type_id:
|
||||
entities.append(("agent_type", str(item.agent_type_id)))
|
||||
|
||||
elif isinstance(item, Fact):
|
||||
# Subject and object are entities
|
||||
entities.append(("subject", item.subject))
|
||||
entities.append(("object", item.object))
|
||||
if item.project_id:
|
||||
entities.append(("project", str(item.project_id)))
|
||||
|
||||
elif isinstance(item, Procedure):
|
||||
entities.append(("procedure", item.name))
|
||||
if item.project_id:
|
||||
entities.append(("project", str(item.project_id)))
|
||||
if item.agent_type_id:
|
||||
entities.append(("agent_type", str(item.agent_type_id)))
|
||||
|
||||
return entities
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
class OutcomeIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Outcome-based index for filtering by success/failure.
|
||||
|
||||
Primarily used for episodes and procedures.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the outcome index."""
|
||||
self._entries: dict[UUID, OutcomeIndexEntry] = {}
|
||||
# Inverted index by outcome
|
||||
self._outcome_to_memories: dict[Outcome, set[UUID]] = {
|
||||
Outcome.SUCCESS: set(),
|
||||
Outcome.FAILURE: set(),
|
||||
Outcome.PARTIAL: set(),
|
||||
}
|
||||
logger.info("Initialized OutcomeIndex")
|
||||
|
||||
async def add(self, item: T) -> OutcomeIndexEntry:
|
||||
"""
|
||||
Add an item to the outcome index.
|
||||
|
||||
Args:
|
||||
item: Memory item with outcome information
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
outcome = self._get_outcome(item)
|
||||
|
||||
entry = OutcomeIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
outcome=outcome,
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
self._outcome_to_memories[outcome].add(item.id)
|
||||
|
||||
logger.debug(f"Added {item.id} to outcome index with {outcome.value}")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the outcome index."""
|
||||
if memory_id not in self._entries:
|
||||
return False
|
||||
|
||||
entry = self._entries.pop(memory_id)
|
||||
self._outcome_to_memories[entry.outcome].discard(memory_id)
|
||||
|
||||
logger.debug(f"Removed {memory_id} from outcome index")
|
||||
return True
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
outcome: Outcome | None = None,
|
||||
outcomes: list[Outcome] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[OutcomeIndexEntry]:
|
||||
"""
|
||||
Search for items by outcome.
|
||||
|
||||
Args:
|
||||
query: Ignored for outcome search
|
||||
limit: Maximum results to return
|
||||
outcome: Single outcome to filter
|
||||
outcomes: Multiple outcomes to filter (OR)
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries
|
||||
"""
|
||||
if outcome:
|
||||
outcomes = [outcome]
|
||||
|
||||
if outcomes:
|
||||
matching_ids: set[UUID] = set()
|
||||
for o in outcomes:
|
||||
matching_ids |= self._outcome_to_memories.get(o, set())
|
||||
else:
|
||||
matching_ids = set(self._entries.keys())
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
results = []
|
||||
for mid in matching_ids:
|
||||
if mid in self._entries:
|
||||
entry = self._entries[mid]
|
||||
if memory_type and entry.memory_type != memory_type:
|
||||
continue
|
||||
results.append(entry)
|
||||
|
||||
logger.debug(f"Outcome search returned {min(len(results), limit)} results")
|
||||
return results[:limit]
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
for outcome in self._outcome_to_memories:
|
||||
self._outcome_to_memories[outcome].clear()
|
||||
logger.info(f"Cleared {count} entries from outcome index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
async def get_outcome_stats(self) -> dict[Outcome, int]:
|
||||
"""Get statistics on outcomes."""
|
||||
return {outcome: len(ids) for outcome, ids in self._outcome_to_memories.items()}
|
||||
|
||||
def _get_outcome(self, item: T) -> Outcome:
|
||||
"""Get the outcome for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return item.outcome
|
||||
elif isinstance(item, Procedure):
|
||||
# Derive from success rate
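            # (illustrative: success_rate 0.9 -> SUCCESS, 0.5 -> PARTIAL, 0.1 -> FAILURE)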
|
||||
if item.success_rate >= 0.8:
|
||||
return Outcome.SUCCESS
|
||||
elif item.success_rate <= 0.2:
|
||||
return Outcome.FAILURE
|
||||
return Outcome.PARTIAL
|
||||
return Outcome.SUCCESS
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryIndexer:
|
||||
"""
|
||||
Unified indexer that manages all index types.
|
||||
|
||||
Provides a single interface for indexing and searching across
|
||||
multiple index types.
|
||||
"""
|
||||
|
||||
vector_index: VectorIndex[Any] = field(default_factory=VectorIndex)
|
||||
temporal_index: TemporalIndex[Any] = field(default_factory=TemporalIndex)
|
||||
entity_index: EntityIndex[Any] = field(default_factory=EntityIndex)
|
||||
outcome_index: OutcomeIndex[Any] = field(default_factory=OutcomeIndex)
|
||||
|
||||
async def index(self, item: Episode | Fact | Procedure) -> dict[str, IndexEntry]:
|
||||
"""
|
||||
Index an item across all applicable indices.
|
||||
|
||||
Args:
|
||||
item: Memory item to index
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to entry
|
||||
"""
|
||||
results: dict[str, IndexEntry] = {}
|
||||
|
||||
# Vector index (if embedding present)
|
||||
if getattr(item, "embedding", None):
|
||||
results["vector"] = await self.vector_index.add(item)
|
||||
|
||||
# Temporal index
|
||||
results["temporal"] = await self.temporal_index.add(item)
|
||||
|
||||
# Entity index
|
||||
results["entity"] = await self.entity_index.add(item)
|
||||
|
||||
# Outcome index (for episodes and procedures)
|
||||
if isinstance(item, (Episode, Procedure)):
|
||||
results["outcome"] = await self.outcome_index.add(item)
|
||||
|
||||
logger.info(
|
||||
f"Indexed {item.id} across {len(results)} indices: {list(results.keys())}"
|
||||
)
|
||||
return results
|
||||
|
||||
async def remove(self, memory_id: UUID) -> dict[str, bool]:
|
||||
"""
|
||||
Remove an item from all indices.
|
||||
|
||||
Args:
|
||||
memory_id: ID of the memory to remove
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to removal success
|
||||
"""
|
||||
results = {
|
||||
"vector": await self.vector_index.remove(memory_id),
|
||||
"temporal": await self.temporal_index.remove(memory_id),
|
||||
"entity": await self.entity_index.remove(memory_id),
|
||||
"outcome": await self.outcome_index.remove(memory_id),
|
||||
}
|
||||
|
||||
removed_from = [k for k, v in results.items() if v]
|
||||
if removed_from:
|
||||
logger.info(f"Removed {memory_id} from indices: {removed_from}")
|
||||
|
||||
return results
|
||||
|
||||
async def clear_all(self) -> dict[str, int]:
|
||||
"""
|
||||
Clear all indices.
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to count cleared
|
||||
"""
|
||||
return {
|
||||
"vector": await self.vector_index.clear(),
|
||||
"temporal": await self.temporal_index.clear(),
|
||||
"entity": await self.entity_index.clear(),
|
||||
"outcome": await self.outcome_index.clear(),
|
||||
}
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get statistics for all indices.
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to entry count
|
||||
"""
|
||||
return {
|
||||
"vector": await self.vector_index.count(),
|
||||
"temporal": await self.temporal_index.count(),
|
||||
"entity": await self.entity_index.count(),
|
||||
"outcome": await self.outcome_index.count(),
|
||||
}
|
||||
|
||||
|
||||
# Singleton indexer instance
|
||||
_indexer: MemoryIndexer | None = None
|
||||
|
||||
|
||||
def get_memory_indexer() -> MemoryIndexer:
|
||||
"""Get the singleton memory indexer instance."""
|
||||
global _indexer
|
||||
if _indexer is None:
|
||||
_indexer = MemoryIndexer()
|
||||
return _indexer
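# Illustrative usage sketch (assumed calling context; not part of the original file):
async def _example_indexing(episode: Episode) -> None:
    indexer = get_memory_indexer()
    # Adds to the vector index only if episode.embedding is set, plus the
    # temporal, entity, and outcome indices.
    entries = await indexer.index(episode)
    logger.debug(f"Indexed into: {list(entries.keys())}")
    stats = await indexer.get_stats()
    logger.debug(f"Index sizes: {stats}")
    await indexer.remove(episode.id)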
backend/app/services/memory/indexing/retrieval.py (new file, 742 lines)
@@ -0,0 +1,742 @@
# app/services/memory/indexing/retrieval.py
|
||||
"""
|
||||
Memory Retrieval Engine.
|
||||
|
||||
Provides hybrid retrieval capabilities combining:
|
||||
- Vector similarity search
|
||||
- Temporal filtering
|
||||
- Entity filtering
|
||||
- Outcome filtering
|
||||
- Relevance scoring
|
||||
- Result caching
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.types import (
|
||||
Episode,
|
||||
Fact,
|
||||
MemoryType,
|
||||
Outcome,
|
||||
Procedure,
|
||||
RetrievalResult,
|
||||
)
|
||||
|
||||
from .index import (
|
||||
MemoryIndexer,
|
||||
get_memory_indexer,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", Episode, Fact, Procedure)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalQuery:
|
||||
"""Query parameters for memory retrieval."""
|
||||
|
||||
# Text/semantic query
|
||||
query_text: str | None = None
|
||||
query_embedding: list[float] | None = None
|
||||
|
||||
# Temporal filters
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
recent_seconds: float | None = None
|
||||
|
||||
# Entity filters
|
||||
entities: list[tuple[str, str]] | None = None
|
||||
entity_match_all: bool = False
|
||||
|
||||
# Outcome filters
|
||||
outcomes: list[Outcome] | None = None
|
||||
|
||||
# Memory type filter
|
||||
memory_types: list[MemoryType] | None = None
|
||||
|
||||
# Result options
|
||||
limit: int = 10
|
||||
min_relevance: float = 0.0
|
||||
|
||||
# Retrieval mode
|
||||
use_vector: bool = True
|
||||
use_temporal: bool = True
|
||||
use_entity: bool = True
|
||||
use_outcome: bool = True
|
||||
|
||||
def to_cache_key(self) -> str:
|
||||
"""Generate a cache key for this query."""
|
||||
key_parts = [
|
||||
self.query_text or "",
|
||||
str(self.start_time),
|
||||
str(self.end_time),
|
||||
str(self.recent_seconds),
|
||||
str(self.entities),
|
||||
str(self.outcomes),
|
||||
str(self.memory_types),
|
||||
str(self.limit),
|
||||
str(self.min_relevance),
|
||||
]
|
||||
key_string = "|".join(key_parts)
|
||||
return hashlib.sha256(key_string.encode()).hexdigest()[:32]
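# Illustrative property of to_cache_key (assumed example values):
#   RetrievalQuery(query_text="deploy", limit=5).to_cache_key()
#   == RetrievalQuery(query_text="deploy", limit=5).to_cache_key()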
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoredResult:
|
||||
"""A retrieval result with relevance score."""
|
||||
|
||||
memory_id: UUID
|
||||
memory_type: MemoryType
|
||||
relevance_score: float
|
||||
score_breakdown: dict[str, float] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""A cached retrieval result."""
|
||||
|
||||
results: list[ScoredResult]
|
||||
created_at: datetime
|
||||
ttl_seconds: float
|
||||
query_key: str
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this cache entry has expired."""
|
||||
age = (_utcnow() - self.created_at).total_seconds()
|
||||
return age > self.ttl_seconds
|
||||
|
||||
|
||||
class RelevanceScorer:
|
||||
"""
|
||||
Calculates relevance scores for retrieved memories.
|
||||
|
||||
Combines multiple signals:
|
||||
- Vector similarity (if available)
|
||||
- Temporal recency
|
||||
- Entity match count
|
||||
- Outcome preference
|
||||
- Importance/confidence
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_weight: float = 0.4,
|
||||
recency_weight: float = 0.2,
|
||||
entity_weight: float = 0.2,
|
||||
outcome_weight: float = 0.1,
|
||||
importance_weight: float = 0.1,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the relevance scorer.
|
||||
|
||||
Args:
|
||||
vector_weight: Weight for vector similarity (0-1)
|
||||
recency_weight: Weight for temporal recency (0-1)
|
||||
entity_weight: Weight for entity matches (0-1)
|
||||
outcome_weight: Weight for outcome preference (0-1)
|
||||
importance_weight: Weight for importance score (0-1)
|
||||
"""
|
||||
total = (
|
||||
vector_weight
|
||||
+ recency_weight
|
||||
+ entity_weight
|
||||
+ outcome_weight
|
||||
+ importance_weight
|
||||
)
|
||||
# Normalize weights
|
||||
self.vector_weight = vector_weight / total
|
||||
self.recency_weight = recency_weight / total
|
||||
self.entity_weight = entity_weight / total
|
||||
self.outcome_weight = outcome_weight / total
|
||||
self.importance_weight = importance_weight / total
|
||||
|
||||
def score(
|
||||
self,
|
||||
memory_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
vector_similarity: float | None = None,
|
||||
timestamp: datetime | None = None,
|
||||
entity_match_count: int = 0,
|
||||
entity_total: int = 1,
|
||||
outcome: Outcome | None = None,
|
||||
importance: float = 0.5,
|
||||
preferred_outcomes: list[Outcome] | None = None,
|
||||
) -> ScoredResult:
|
||||
"""
|
||||
Calculate a relevance score for a memory.
|
||||
|
||||
Args:
|
||||
memory_id: ID of the memory
|
||||
memory_type: Type of memory
|
||||
vector_similarity: Similarity score from vector search (0-1)
|
||||
timestamp: Timestamp of the memory
|
||||
entity_match_count: Number of matching entities
|
||||
entity_total: Total entities in query
|
||||
outcome: Outcome of the memory
|
||||
importance: Importance score of the memory (0-1)
|
||||
preferred_outcomes: Outcomes to prefer
|
||||
|
||||
Returns:
|
||||
Scored result with breakdown
|
||||
"""
|
||||
breakdown: dict[str, float] = {}
|
||||
|
||||
# Vector similarity score
|
||||
if vector_similarity is not None:
|
||||
breakdown["vector"] = vector_similarity
|
||||
else:
|
||||
breakdown["vector"] = 0.5 # Neutral if no vector
|
||||
|
||||
# Recency score (exponential decay)
|
||||
if timestamp:
|
||||
age_hours = (_utcnow() - timestamp).total_seconds() / 3600
|
||||
# Decay with half-life of 24 hours
|
||||
breakdown["recency"] = 2 ** (-age_hours / 24)
|
||||
else:
|
||||
breakdown["recency"] = 0.5
|
||||
|
||||
# Entity match score
|
||||
if entity_total > 0:
|
||||
breakdown["entity"] = entity_match_count / entity_total
|
||||
else:
|
||||
breakdown["entity"] = 1.0 # No entity filter = full score
|
||||
|
||||
# Outcome score
|
||||
if preferred_outcomes and outcome:
|
||||
breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0
|
||||
else:
|
||||
breakdown["outcome"] = 0.5 # Neutral if no preference
|
||||
|
||||
# Importance score
|
||||
breakdown["importance"] = importance
|
||||
|
||||
# Calculate weighted sum
|
||||
total_score = (
|
||||
breakdown["vector"] * self.vector_weight
|
||||
+ breakdown["recency"] * self.recency_weight
|
||||
+ breakdown["entity"] * self.entity_weight
|
||||
+ breakdown["outcome"] * self.outcome_weight
|
||||
+ breakdown["importance"] * self.importance_weight
|
||||
)
|
||||
|
||||
return ScoredResult(
|
||||
memory_id=memory_id,
|
||||
memory_type=memory_type,
|
||||
relevance_score=total_score,
|
||||
score_breakdown=breakdown,
|
||||
)
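# Illustrative scoring sketch (assumed example values; not part of the original file):
def _example_scoring() -> None:
    scorer = RelevanceScorer()  # weights are normalized to sum to 1.0
    result = scorer.score(
        memory_id=UUID(int=1),
        memory_type=MemoryType.EPISODIC,
        vector_similarity=0.9,
        entity_match_count=1,
        entity_total=2,
        outcome=Outcome.SUCCESS,
        preferred_outcomes=[Outcome.SUCCESS],
    )
    # relevance_score is the weighted sum of the per-signal breakdown
    logger.debug(f"{result.relevance_score:.3f} {result.score_breakdown}")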
|
||||
|
||||
|
||||
class RetrievalCache:
|
||||
"""
|
||||
In-memory cache for retrieval results.
|
||||
|
||||
Supports TTL-based expiration and LRU eviction with O(1) operations.
|
||||
Uses OrderedDict for efficient LRU tracking.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_entries: int = 1000,
|
||||
default_ttl_seconds: float = 300,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the cache.
|
||||
|
||||
Args:
|
||||
max_entries: Maximum cache entries
|
||||
default_ttl_seconds: Default TTL for entries
|
||||
"""
|
||||
# OrderedDict maintains insertion order; we use move_to_end for O(1) LRU
|
||||
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
|
||||
self._max_entries = max_entries
|
||||
self._default_ttl = default_ttl_seconds
|
||||
logger.info(
|
||||
f"Initialized RetrievalCache with max_entries={max_entries}, "
|
||||
f"ttl={default_ttl_seconds}s"
|
||||
)
|
||||
|
||||
def get(self, query_key: str) -> list[ScoredResult] | None:
|
||||
"""
|
||||
Get cached results for a query.
|
||||
|
||||
Args:
|
||||
query_key: Cache key for the query
|
||||
|
||||
Returns:
|
||||
Cached results or None if not found/expired
|
||||
"""
|
||||
if query_key not in self._cache:
|
||||
return None
|
||||
|
||||
entry = self._cache[query_key]
|
||||
if entry.is_expired():
|
||||
del self._cache[query_key]
|
||||
return None
|
||||
|
||||
# Update access order (LRU) - O(1) with OrderedDict
|
||||
self._cache.move_to_end(query_key)
|
||||
|
||||
logger.debug(f"Cache hit for {query_key}")
|
||||
return entry.results
|
||||
|
||||
def put(
|
||||
self,
|
||||
query_key: str,
|
||||
results: list[ScoredResult],
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache results for a query.
|
||||
|
||||
Args:
|
||||
query_key: Cache key for the query
|
||||
results: Results to cache
|
||||
ttl_seconds: TTL for this entry (or default)
|
||||
"""
|
||||
# Evict oldest entries if at capacity - O(1) with popitem(last=False)
|
||||
while len(self._cache) >= self._max_entries:
|
||||
self._cache.popitem(last=False)
|
||||
|
||||
entry = CacheEntry(
|
||||
results=results,
|
||||
created_at=_utcnow(),
|
||||
ttl_seconds=ttl_seconds or self._default_ttl,
|
||||
query_key=query_key,
|
||||
)
|
||||
|
||||
self._cache[query_key] = entry
|
||||
logger.debug(f"Cached {len(results)} results for {query_key}")
|
||||
|
||||
def invalidate(self, query_key: str) -> bool:
|
||||
"""
|
||||
Invalidate a specific cache entry.
|
||||
|
||||
Args:
|
||||
query_key: Cache key to invalidate
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
if query_key in self._cache:
|
||||
del self._cache[query_key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def invalidate_by_memory(self, memory_id: UUID) -> int:
|
||||
"""
|
||||
Invalidate all cache entries containing a specific memory.
|
||||
|
||||
Args:
|
||||
memory_id: Memory ID to invalidate
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
keys_to_remove = []
|
||||
for key, entry in self._cache.items():
|
||||
if any(r.memory_id == memory_id for r in entry.results):
|
||||
keys_to_remove.append(key)
|
||||
|
||||
for key in keys_to_remove:
|
||||
self.invalidate(key)
|
||||
|
||||
if keys_to_remove:
|
||||
logger.debug(
|
||||
f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}"
|
||||
)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def clear(self) -> int:
|
||||
"""
|
||||
Clear all cache entries.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
logger.info(f"Cleared {count} cache entries")
|
||||
return count
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
expired_count = sum(1 for e in self._cache.values() if e.is_expired())
|
||||
return {
|
||||
"total_entries": len(self._cache),
|
||||
"expired_entries": expired_count,
|
||||
"max_entries": self._max_entries,
|
||||
"default_ttl_seconds": self._default_ttl,
|
||||
}
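# Illustrative cache usage (assumed example values; not part of the original file):
def _example_cache() -> None:
    cache = RetrievalCache(max_entries=100, default_ttl_seconds=60)
    key = RetrievalQuery(query_text="deploy", limit=5).to_cache_key()
    cache.put(key, [], ttl_seconds=5)   # entries expire after their TTL
    assert cache.get(key) == []         # hit (and moved to the LRU tail)
    cache.invalidate(key)
    assert cache.get(key) is None       # miss after invalidation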
|
||||
|
||||
|
||||
class RetrievalEngine:
|
||||
"""
|
||||
Hybrid retrieval engine for memory search.
|
||||
|
||||
Combines multiple index types for comprehensive retrieval:
|
||||
- Vector search for semantic similarity
|
||||
- Temporal index for time-based filtering
|
||||
- Entity index for entity-based lookups
|
||||
- Outcome index for success/failure filtering
|
||||
|
||||
Results are scored and ranked using relevance scoring.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
indexer: MemoryIndexer | None = None,
|
||||
scorer: RelevanceScorer | None = None,
|
||||
cache: RetrievalCache | None = None,
|
||||
enable_cache: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the retrieval engine.
|
||||
|
||||
Args:
|
||||
indexer: Memory indexer (defaults to singleton)
|
||||
scorer: Relevance scorer (defaults to new instance)
|
||||
cache: Retrieval cache (defaults to new instance)
|
||||
enable_cache: Whether to enable result caching
|
||||
"""
|
||||
self._indexer = indexer or get_memory_indexer()
|
||||
self._scorer = scorer or RelevanceScorer()
|
||||
        self._cache = (cache or RetrievalCache()) if enable_cache else None
|
||||
self._enable_cache = enable_cache
|
||||
logger.info(f"Initialized RetrievalEngine with cache={enable_cache}")
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: RetrievalQuery,
|
||||
use_cache: bool = True,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve relevant memories using hybrid search.
|
||||
|
||||
Args:
|
||||
query: Retrieval query parameters
|
||||
use_cache: Whether to use cached results
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
start_time = _utcnow()
|
||||
|
||||
# Check cache
|
||||
cache_key = query.to_cache_key()
|
||||
if use_cache and self._cache:
|
||||
cached = self._cache.get(cache_key)
|
||||
if cached:
|
||||
latency = (_utcnow() - start_time).total_seconds() * 1000
|
||||
return RetrievalResult(
|
||||
items=cached,
|
||||
total_count=len(cached),
|
||||
query=query.query_text or "",
|
||||
retrieval_type="cached",
|
||||
latency_ms=latency,
|
||||
metadata={"cache_hit": True},
|
||||
)
|
||||
|
||||
# Collect candidates from each index
|
||||
candidates: dict[UUID, dict[str, Any]] = {}
|
||||
|
||||
# Vector search
|
||||
if query.use_vector and query.query_embedding:
|
||||
vector_results = await self._indexer.vector_index.search(
|
||||
query=query.query_embedding,
|
||||
limit=query.limit * 3, # Get more for filtering
|
||||
min_similarity=query.min_relevance,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for entry in vector_results:
|
||||
if entry.memory_id not in candidates:
|
||||
candidates[entry.memory_id] = {
|
||||
"memory_type": entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
candidates[entry.memory_id]["vector_similarity"] = entry.metadata.get(
|
||||
"similarity", 0.5
|
||||
)
|
||||
candidates[entry.memory_id]["sources"].append("vector")
|
||||
|
||||
# Temporal search
|
||||
if query.use_temporal and (
|
||||
query.start_time or query.end_time or query.recent_seconds
|
||||
):
|
||||
temporal_results = await self._indexer.temporal_index.search(
|
||||
query=None,
|
||||
limit=query.limit * 3,
|
||||
start_time=query.start_time,
|
||||
end_time=query.end_time,
|
||||
recent_seconds=query.recent_seconds,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for temporal_entry in temporal_results:
|
||||
if temporal_entry.memory_id not in candidates:
|
||||
candidates[temporal_entry.memory_id] = {
|
||||
"memory_type": temporal_entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
candidates[temporal_entry.memory_id]["timestamp"] = (
|
||||
temporal_entry.timestamp
|
||||
)
|
||||
candidates[temporal_entry.memory_id]["sources"].append("temporal")
|
||||
|
||||
# Entity search
|
||||
if query.use_entity and query.entities:
|
||||
entity_results = await self._indexer.entity_index.search(
|
||||
query=None,
|
||||
limit=query.limit * 3,
|
||||
entities=query.entities,
|
||||
match_all=query.entity_match_all,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for entity_entry in entity_results:
|
||||
if entity_entry.memory_id not in candidates:
|
||||
candidates[entity_entry.memory_id] = {
|
||||
"memory_type": entity_entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
# Count entity matches
|
||||
entity_count = candidates[entity_entry.memory_id].get(
|
||||
"entity_match_count", 0
|
||||
)
|
||||
candidates[entity_entry.memory_id]["entity_match_count"] = (
|
||||
entity_count + 1
|
||||
)
|
||||
candidates[entity_entry.memory_id]["sources"].append("entity")
|
||||
|
||||
# Outcome search
|
||||
if query.use_outcome and query.outcomes:
|
||||
outcome_results = await self._indexer.outcome_index.search(
|
||||
query=None,
|
||||
limit=query.limit * 3,
|
||||
outcomes=query.outcomes,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for outcome_entry in outcome_results:
|
||||
if outcome_entry.memory_id not in candidates:
|
||||
candidates[outcome_entry.memory_id] = {
|
||||
"memory_type": outcome_entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
candidates[outcome_entry.memory_id]["outcome"] = outcome_entry.outcome
|
||||
candidates[outcome_entry.memory_id]["sources"].append("outcome")
|
||||
|
||||
# Score and rank candidates
|
||||
scored_results: list[ScoredResult] = []
|
||||
entity_total = len(query.entities) if query.entities else 1
|
||||
|
||||
for memory_id, data in candidates.items():
|
||||
scored = self._scorer.score(
|
||||
memory_id=memory_id,
|
||||
memory_type=data["memory_type"],
|
||||
vector_similarity=data.get("vector_similarity"),
|
||||
timestamp=data.get("timestamp"),
|
||||
entity_match_count=data.get("entity_match_count", 0),
|
||||
entity_total=entity_total,
|
||||
outcome=data.get("outcome"),
|
||||
preferred_outcomes=query.outcomes,
|
||||
)
|
||||
scored.metadata["sources"] = data.get("sources", [])
|
||||
|
||||
# Filter by minimum relevance
|
||||
if scored.relevance_score >= query.min_relevance:
|
||||
scored_results.append(scored)
|
||||
|
||||
# Sort by relevance score
|
||||
scored_results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
|
||||
# Apply limit
|
||||
final_results = scored_results[: query.limit]
|
||||
|
||||
# Cache results
|
||||
if use_cache and self._cache and final_results:
|
||||
self._cache.put(cache_key, final_results)
|
||||
|
||||
latency = (_utcnow() - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Retrieved {len(final_results)} results from {len(candidates)} candidates "
|
||||
f"in {latency:.2f}ms"
|
||||
)
|
||||
|
||||
return RetrievalResult(
|
||||
items=final_results,
|
||||
total_count=len(candidates),
|
||||
query=query.query_text or "",
|
||||
retrieval_type="hybrid",
|
||||
latency_ms=latency,
|
||||
metadata={
|
||||
"cache_hit": False,
|
||||
"candidates_count": len(candidates),
|
||||
"filtered_count": len(scored_results),
|
||||
},
|
||||
)
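    # Illustrative call (assumed query values and embedding; not part of the original file):
    #
    #     engine = get_retrieval_engine()
    #     result = await engine.retrieve(
    #         RetrievalQuery(
    #             query_text="failed deployments",
    #             query_embedding=embedding,          # hypothetical embedding vector
    #             recent_seconds=7 * 24 * 3600,
    #             outcomes=[Outcome.FAILURE],
    #             limit=5,
    #         )
    #     )
    #     # result.items are ScoredResult objects, ranked by relevance_score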
|
||||
|
||||
async def retrieve_similar(
|
||||
self,
|
||||
embedding: list[float],
|
||||
limit: int = 10,
|
||||
min_similarity: float = 0.5,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve memories similar to a given embedding.
|
||||
|
||||
Args:
|
||||
embedding: Query embedding
|
||||
limit: Maximum results
|
||||
min_similarity: Minimum similarity threshold
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
query_embedding=embedding,
|
||||
limit=limit,
|
||||
min_relevance=min_similarity,
|
||||
memory_types=memory_types,
|
||||
use_temporal=False,
|
||||
use_entity=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
async def retrieve_recent(
|
||||
self,
|
||||
hours: float = 24,
|
||||
limit: int = 10,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve recent memories.
|
||||
|
||||
Args:
|
||||
hours: Number of hours to look back
|
||||
limit: Maximum results
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
recent_seconds=hours * 3600,
|
||||
limit=limit,
|
||||
memory_types=memory_types,
|
||||
use_vector=False,
|
||||
use_entity=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
async def retrieve_by_entity(
|
||||
self,
|
||||
entity_type: str,
|
||||
entity_value: str,
|
||||
limit: int = 10,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve memories by entity.
|
||||
|
||||
Args:
|
||||
entity_type: Type of entity
|
||||
entity_value: Entity value
|
||||
limit: Maximum results
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
entities=[(entity_type, entity_value)],
|
||||
limit=limit,
|
||||
memory_types=memory_types,
|
||||
use_vector=False,
|
||||
use_temporal=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
async def retrieve_successful(
|
||||
self,
|
||||
limit: int = 10,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve successful memories.
|
||||
|
||||
Args:
|
||||
limit: Maximum results
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
outcomes=[Outcome.SUCCESS],
|
||||
limit=limit,
|
||||
memory_types=memory_types,
|
||||
use_vector=False,
|
||||
use_temporal=False,
|
||||
use_entity=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
def invalidate_cache(self) -> int:
|
||||
"""
|
||||
Invalidate all cached results.
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if self._cache:
|
||||
return self._cache.clear()
|
||||
return 0
|
||||
|
||||
def invalidate_cache_for_memory(self, memory_id: UUID) -> int:
|
||||
"""
|
||||
Invalidate cache entries containing a specific memory.
|
||||
|
||||
Args:
|
||||
memory_id: Memory ID to invalidate
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if self._cache:
|
||||
return self._cache.invalidate_by_memory(memory_id)
|
||||
return 0
|
||||
|
||||
def get_cache_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
if self._cache:
|
||||
return self._cache.get_stats()
|
||||
return {"enabled": False}
|
||||
|
||||
|
||||
# Singleton retrieval engine instance
|
||||
_engine: RetrievalEngine | None = None
|
||||
|
||||
|
||||
def get_retrieval_engine() -> RetrievalEngine:
|
||||
"""Get the singleton retrieval engine instance."""
|
||||
global _engine
|
||||
if _engine is None:
|
||||
_engine = RetrievalEngine()
|
||||
return _engine
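# Illustrative use of the convenience helpers (assumed entity value; not part of the original file):
async def _example_convenience_queries() -> None:
    engine = get_retrieval_engine()
    recent = await engine.retrieve_recent(hours=6, limit=5)
    by_project = await engine.retrieve_by_entity("project", "example-project-id", limit=5)
    successful = await engine.retrieve_successful(limit=5)
    logger.debug(
        f"recent={recent.total_count}, "
        f"by_project={by_project.total_count}, "
        f"successful={successful.total_count}"
    )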
backend/app/services/memory/integration/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# app/services/memory/integration/__init__.py
|
||||
"""
|
||||
Memory Integration Module.
|
||||
|
||||
Provides integration between the agent memory system and other Syndarix components:
|
||||
- Context Engine: Memory as context source
|
||||
- Agent Lifecycle: Spawn, pause, resume, terminate hooks
|
||||
"""
|
||||
|
||||
from .context_source import MemoryContextSource, get_memory_context_source
|
||||
from .lifecycle import AgentLifecycleManager, LifecycleHooks, get_lifecycle_manager
|
||||
|
||||
__all__ = [
|
||||
"AgentLifecycleManager",
|
||||
"LifecycleHooks",
|
||||
"MemoryContextSource",
|
||||
"get_lifecycle_manager",
|
||||
"get_memory_context_source",
|
||||
]
backend/app/services/memory/integration/context_source.py (new file, 399 lines)
@@ -0,0 +1,399 @@
# app/services/memory/integration/context_source.py
|
||||
"""
|
||||
Memory Context Source.
|
||||
|
||||
Provides agent memory as a context source for the Context Engine.
|
||||
Retrieves relevant memories based on query and converts them to MemoryContext objects.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.context.types.memory import MemoryContext
|
||||
from app.services.memory.episodic import EpisodicMemory
|
||||
from app.services.memory.procedural import ProceduralMemory
|
||||
from app.services.memory.semantic import SemanticMemory
|
||||
from app.services.memory.working import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryFetchConfig:
|
||||
"""Configuration for memory fetching."""
|
||||
|
||||
# Limits per memory type
|
||||
working_limit: int = 10
|
||||
episodic_limit: int = 10
|
||||
semantic_limit: int = 15
|
||||
procedural_limit: int = 5
|
||||
|
||||
# Time ranges
|
||||
episodic_days_back: int = 30
|
||||
min_relevance: float = 0.3
|
||||
|
||||
# Which memory types to include
|
||||
include_working: bool = True
|
||||
include_episodic: bool = True
|
||||
include_semantic: bool = True
|
||||
include_procedural: bool = True
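# Illustrative config (assumed values): fetch only episodic + semantic memories:
#   MemoryFetchConfig(include_working=False, include_procedural=False,
#                     episodic_limit=5, semantic_limit=5)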
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryFetchResult:
|
||||
"""Result of memory fetch operation."""
|
||||
|
||||
contexts: list[MemoryContext]
|
||||
by_type: dict[str, int]
|
||||
fetch_time_ms: float
|
||||
query: str
|
||||
|
||||
|
||||
class MemoryContextSource:
|
||||
"""
|
||||
Source for memory context in the Context Engine.
|
||||
|
||||
This service retrieves relevant memories based on a query and
|
||||
converts them to MemoryContext objects for context assembly.
|
||||
It coordinates between all memory types (working, episodic,
|
||||
semantic, procedural) to provide a comprehensive memory context.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the memory context source.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic search
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
|
||||
# Lazy-initialized memory services
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
self._semantic: SemanticMemory | None = None
|
||||
self._procedural: ProceduralMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
async def _get_semantic(self) -> SemanticMemory:
|
||||
"""Get or create semantic memory service."""
|
||||
if self._semantic is None:
|
||||
self._semantic = await SemanticMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._semantic
|
||||
|
||||
async def _get_procedural(self) -> ProceduralMemory:
|
||||
"""Get or create procedural memory service."""
|
||||
if self._procedural is None:
|
||||
self._procedural = await ProceduralMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._procedural
|
||||
|
||||
async def fetch_context(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
session_id: str | None = None,
|
||||
config: MemoryFetchConfig | None = None,
|
||||
) -> MemoryFetchResult:
|
||||
"""
|
||||
Fetch relevant memories as context.
|
||||
|
||||
This is the main entry point for the Context Engine integration.
|
||||
It searches across all memory types and returns relevant memories
|
||||
as MemoryContext objects.
|
||||
|
||||
Args:
|
||||
query: Search query for finding relevant memories
|
||||
project_id: Project scope
|
||||
agent_instance_id: Optional agent instance scope
|
||||
agent_type_id: Optional agent type scope (for procedural)
|
||||
session_id: Optional session ID (for working memory)
|
||||
config: Optional fetch configuration
|
||||
|
||||
Returns:
|
||||
MemoryFetchResult with contexts and metadata
|
||||
"""
|
||||
config = config or MemoryFetchConfig()
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
contexts: list[MemoryContext] = []
|
||||
by_type: dict[str, int] = {
|
||||
"working": 0,
|
||||
"episodic": 0,
|
||||
"semantic": 0,
|
||||
"procedural": 0,
|
||||
}
|
||||
|
||||
# Fetch from working memory (session-scoped)
|
||||
if config.include_working and session_id:
|
||||
try:
|
||||
working_contexts = await self._fetch_working(
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
limit=config.working_limit,
|
||||
)
|
||||
contexts.extend(working_contexts)
|
||||
by_type["working"] = len(working_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch working memory: {e}")
|
||||
|
||||
# Fetch from episodic memory
|
||||
if config.include_episodic:
|
||||
try:
|
||||
episodic_contexts = await self._fetch_episodic(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
limit=config.episodic_limit,
|
||||
days_back=config.episodic_days_back,
|
||||
)
|
||||
contexts.extend(episodic_contexts)
|
||||
by_type["episodic"] = len(episodic_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch episodic memory: {e}")
|
||||
|
||||
# Fetch from semantic memory
|
||||
if config.include_semantic:
|
||||
try:
|
||||
semantic_contexts = await self._fetch_semantic(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
limit=config.semantic_limit,
|
||||
min_relevance=config.min_relevance,
|
||||
)
|
||||
contexts.extend(semantic_contexts)
|
||||
by_type["semantic"] = len(semantic_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch semantic memory: {e}")
|
||||
|
||||
# Fetch from procedural memory
|
||||
if config.include_procedural:
|
||||
try:
|
||||
procedural_contexts = await self._fetch_procedural(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=config.procedural_limit,
|
||||
)
|
||||
contexts.extend(procedural_contexts)
|
||||
by_type["procedural"] = len(procedural_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch procedural memory: {e}")
|
||||
|
||||
# Sort by relevance
|
||||
contexts.sort(key=lambda c: c.relevance_score, reverse=True)
|
||||
|
||||
fetch_time = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.debug(
|
||||
f"Fetched {len(contexts)} memory contexts for query '{query[:50]}...' "
|
||||
f"in {fetch_time:.1f}ms"
|
||||
)
|
||||
|
||||
return MemoryFetchResult(
|
||||
contexts=contexts,
|
||||
by_type=by_type,
|
||||
fetch_time_ms=fetch_time,
|
||||
query=query,
|
||||
)
|
||||
|
||||
async def _fetch_working(
|
||||
self,
|
||||
query: str,
|
||||
session_id: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None,
|
||||
limit: int,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from working memory."""
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
|
||||
)
|
||||
|
||||
contexts: list[MemoryContext] = []
|
||||
all_keys = await working.list_keys()
|
||||
|
||||
# Filter keys by query (simple substring match)
|
||||
query_lower = query.lower()
|
||||
matched_keys = [k for k in all_keys if query_lower in k.lower()]
|
||||
|
||||
# If no query match, include all keys (working memory is always relevant)
|
||||
if not matched_keys and query:
|
||||
matched_keys = all_keys
|
||||
|
||||
for key in matched_keys[:limit]:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
contexts.append(
|
||||
MemoryContext.from_working_memory(
|
||||
key=key,
|
||||
value=value,
|
||||
source=f"working:{session_id}",
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
return contexts
|
||||
|
||||
async def _fetch_episodic(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None,
|
||||
limit: int,
|
||||
days_back: int,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from episodic memory."""
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
# Search for similar episodes
|
||||
episodes = await episodic.search_similar(
|
||||
project_id=project_id,
|
||||
query=query,
|
||||
limit=limit,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
# Also get recent episodes if we didn't find enough
|
||||
if len(episodes) < limit // 2:
|
||||
since = datetime.now(UTC) - timedelta(days=days_back)
|
||||
recent = await episodic.get_recent(
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
since=since,
|
||||
)
|
||||
# Deduplicate by ID
|
||||
existing_ids = {e.id for e in episodes}
|
||||
for ep in recent:
|
||||
if ep.id not in existing_ids:
|
||||
episodes.append(ep)
|
||||
if len(episodes) >= limit:
|
||||
break
|
||||
|
||||
return [
|
||||
MemoryContext.from_episodic_memory(ep, query=query)
|
||||
for ep in episodes[:limit]
|
||||
]
|
||||
|
||||
async def _fetch_semantic(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
limit: int,
|
||||
min_relevance: float,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from semantic memory."""
|
||||
semantic = await self._get_semantic()
|
||||
|
||||
facts = await semantic.search_facts(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
min_confidence=min_relevance,
|
||||
)
|
||||
|
||||
return [MemoryContext.from_semantic_memory(fact, query=query) for fact in facts]
|
||||
|
||||
async def _fetch_procedural(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None,
|
||||
limit: int,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from procedural memory."""
|
||||
procedural = await self._get_procedural()
|
||||
|
||||
procedures = await procedural.find_matching(
|
||||
context=query,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
MemoryContext.from_procedural_memory(proc, query=query)
|
||||
for proc in procedures
|
||||
]
|
||||
|
||||
async def fetch_all_working(
|
||||
self,
|
||||
session_id: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[MemoryContext]:
|
||||
"""
|
||||
Fetch all working memory for a session.
|
||||
|
||||
Useful for including entire session state in context.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
project_id: Project scope
|
||||
agent_instance_id: Optional agent instance scope
|
||||
|
||||
Returns:
|
||||
List of MemoryContext for all working memory items
|
||||
"""
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
|
||||
)
|
||||
|
||||
contexts: list[MemoryContext] = []
|
||||
all_keys = await working.list_keys()
|
||||
|
||||
for key in all_keys:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
contexts.append(
|
||||
MemoryContext.from_working_memory(
|
||||
key=key,
|
||||
value=value,
|
||||
source=f"working:{session_id}",
|
||||
)
|
||||
)
|
||||
|
||||
return contexts
|
||||
|
||||
|
||||
# Factory function
|
||||
async def get_memory_context_source(
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> MemoryContextSource:
|
||||
"""Create a memory context source instance."""
|
||||
return MemoryContextSource(
|
||||
session=session,
|
||||
embedding_generator=embedding_generator,
|
||||
)
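# Illustrative wiring sketch (assumed session and project_id; not part of the original file):
async def _example_fetch_memory_context(session: AsyncSession, project_id: UUID) -> None:
    source = await get_memory_context_source(session)
    result = await source.fetch_context(
        query="recent deployment failures",
        project_id=project_id,
        config=MemoryFetchConfig(episodic_limit=5, semantic_limit=5),
    )
    logger.debug(
        f"{len(result.contexts)} contexts in {result.fetch_time_ms:.1f}ms "
        f"(by type: {result.by_type})"
    )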
backend/app/services/memory/integration/lifecycle.py (new file, 635 lines)
@@ -0,0 +1,635 @@
# app/services/memory/integration/lifecycle.py
|
||||
"""
|
||||
Agent Lifecycle Hooks for Memory System.
|
||||
|
||||
Provides memory management hooks for agent lifecycle events:
|
||||
- spawn: Initialize working memory for new agent instance
|
||||
- pause: Checkpoint working memory state
|
||||
- resume: Restore working memory from checkpoint
|
||||
- terminate: Consolidate session to episodic memory
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.memory.episodic import EpisodicMemory
|
||||
from app.services.memory.types import EpisodeCreate, Outcome
|
||||
from app.services.memory.working import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LifecycleEvent:
|
||||
"""Event data for lifecycle hooks."""
|
||||
|
||||
event_type: str # spawn, pause, resume, terminate
|
||||
project_id: UUID
|
||||
agent_instance_id: UUID
|
||||
agent_type_id: UUID | None = None
|
||||
session_id: str | None = None
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LifecycleResult:
|
||||
"""Result of a lifecycle operation."""
|
||||
|
||||
success: bool
|
||||
event_type: str
|
||||
message: str | None = None
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
duration_ms: float = 0.0
|
||||
|
||||
|
||||
# Type alias for lifecycle hooks
|
||||
LifecycleHook = Callable[[LifecycleEvent], Coroutine[Any, Any, None]]
|
||||
|
||||
|
||||
class LifecycleHooks:
|
||||
"""
|
||||
Collection of lifecycle hooks.
|
||||
|
||||
Allows registration of custom hooks for lifecycle events.
|
||||
Hooks are called after the core memory operations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize lifecycle hooks."""
|
||||
self._spawn_hooks: list[LifecycleHook] = []
|
||||
self._pause_hooks: list[LifecycleHook] = []
|
||||
self._resume_hooks: list[LifecycleHook] = []
|
||||
self._terminate_hooks: list[LifecycleHook] = []
|
||||
|
||||
def on_spawn(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a spawn hook."""
|
||||
self._spawn_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
def on_pause(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a pause hook."""
|
||||
self._pause_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
def on_resume(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a resume hook."""
|
||||
self._resume_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
def on_terminate(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a terminate hook."""
|
||||
self._terminate_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
async def run_spawn_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all spawn hooks."""
|
||||
for hook in self._spawn_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Spawn hook failed: {e}")
|
||||
|
||||
async def run_pause_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all pause hooks."""
|
||||
for hook in self._pause_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Pause hook failed: {e}")
|
||||
|
||||
async def run_resume_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all resume hooks."""
|
||||
for hook in self._resume_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Resume hook failed: {e}")
|
||||
|
||||
async def run_terminate_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all terminate hooks."""
|
||||
for hook in self._terminate_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Terminate hook failed: {e}")
|
||||
|
||||
|
||||
class AgentLifecycleManager:
|
||||
"""
|
||||
Manager for agent lifecycle and memory integration.
|
||||
|
||||
Handles memory operations during agent lifecycle events:
|
||||
- spawn: Creates new working memory for the session
|
||||
- pause: Saves working memory state to checkpoint
|
||||
- resume: Restores working memory from checkpoint
|
||||
- terminate: Consolidates working memory to episodic memory
|
||||
"""
|
||||
|
||||
# Key prefix for checkpoint storage
|
||||
CHECKPOINT_PREFIX = "__checkpoint__"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
hooks: LifecycleHooks | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the lifecycle manager.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
hooks: Optional lifecycle hooks
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._hooks = hooks or LifecycleHooks()
|
||||
|
||||
# Lazy-initialized services
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
@property
|
||||
def hooks(self) -> LifecycleHooks:
|
||||
"""Get the lifecycle hooks."""
|
||||
return self._hooks
|
||||
|
||||
async def spawn(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
agent_type_id: UUID | None = None,
|
||||
initial_state: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent spawn - initialize working memory.
|
||||
|
||||
Creates a new working memory instance for the agent session
|
||||
and optionally populates it with initial state.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID for working memory
|
||||
agent_type_id: Optional agent type ID
|
||||
initial_state: Optional initial state to populate
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with spawn outcome
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
# Create working memory for the session
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Populate initial state if provided
|
||||
items_set = 0
|
||||
if initial_state:
|
||||
for key, value in initial_state.items():
|
||||
await working.set(key, value)
|
||||
items_set += 1
|
||||
|
||||
# Create and run event hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="spawn",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
await self._hooks.run_spawn_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} spawned with session {session_id}, "
|
||||
f"initial state: {items_set} items"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="spawn",
|
||||
message="Agent spawned successfully",
|
||||
data={
|
||||
"session_id": session_id,
|
||||
"initial_items": items_set,
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Spawn failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="spawn",
|
||||
message=f"Spawn failed: {e}",
|
||||
)
|
||||
|
||||
async def pause(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
checkpoint_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent pause - checkpoint working memory.
|
||||
|
||||
Saves the current working memory state to a checkpoint
|
||||
that can be restored later with resume().
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
checkpoint_id: Optional checkpoint identifier
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with checkpoint data
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
checkpoint_id = checkpoint_id or f"checkpoint_{int(start_time.timestamp())}"
|
||||
|
||||
try:
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Get all current state
|
||||
all_keys = await working.list_keys()
|
||||
# Filter out checkpoint keys
|
||||
state_keys = [
|
||||
k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)
|
||||
]
|
||||
|
||||
state: dict[str, Any] = {}
|
||||
for key in state_keys:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
state[key] = value
|
||||
|
||||
# Store checkpoint
|
||||
checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
await working.set(
|
||||
checkpoint_key,
|
||||
{
|
||||
"state": state,
|
||||
"timestamp": start_time.isoformat(),
|
||||
"keys_count": len(state),
|
||||
},
|
||||
ttl_seconds=86400 * 7, # Keep checkpoint for 7 days
|
||||
)
|
||||
|
||||
# Run hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="pause",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
|
||||
)
|
||||
await self._hooks.run_pause_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} paused, checkpoint {checkpoint_id} "
|
||||
f"saved with {len(state)} items"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="pause",
|
||||
message="Agent paused successfully",
|
||||
data={
|
||||
"checkpoint_id": checkpoint_id,
|
||||
"items_saved": len(state),
|
||||
"timestamp": start_time.isoformat(),
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pause failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="pause",
|
||||
message=f"Pause failed: {e}",
|
||||
)
|
||||
|
||||
async def resume(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
checkpoint_id: str,
|
||||
clear_current: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent resume - restore from checkpoint.
|
||||
|
||||
Restores working memory state from a previously saved checkpoint.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
checkpoint_id: Checkpoint to restore from
|
||||
clear_current: Whether to clear current state before restoring
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with restore outcome
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Get checkpoint
|
||||
checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
checkpoint = await working.get(checkpoint_key)
|
||||
|
||||
if checkpoint is None:
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="resume",
|
||||
message=f"Checkpoint '{checkpoint_id}' not found",
|
||||
)
|
||||
|
||||
# Clear current state if requested
|
||||
if clear_current:
|
||||
all_keys = await working.list_keys()
|
||||
for key in all_keys:
|
||||
if not key.startswith(self.CHECKPOINT_PREFIX):
|
||||
await working.delete(key)
|
||||
|
||||
# Restore state from checkpoint
|
||||
state = checkpoint.get("state", {})
|
||||
items_restored = 0
|
||||
for key, value in state.items():
|
||||
await working.set(key, value)
|
||||
items_restored += 1
|
||||
|
||||
# Run hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="resume",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
|
||||
)
|
||||
await self._hooks.run_resume_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} resumed from checkpoint {checkpoint_id}, "
|
||||
f"restored {items_restored} items"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="resume",
|
||||
message="Agent resumed successfully",
|
||||
data={
|
||||
"checkpoint_id": checkpoint_id,
|
||||
"items_restored": items_restored,
|
||||
"checkpoint_timestamp": checkpoint.get("timestamp"),
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Resume failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="resume",
|
||||
message=f"Resume failed: {e}",
|
||||
)
|
||||
|
||||
async def terminate(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
task_description: str | None = None,
|
||||
outcome: Outcome = Outcome.SUCCESS,
|
||||
lessons_learned: list[str] | None = None,
|
||||
consolidate_to_episodic: bool = True,
|
||||
cleanup_working: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent termination - consolidate to episodic memory.
|
||||
|
||||
Consolidates the session's working memory into an episodic memory
|
||||
entry, then optionally cleans up the working memory.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
task_description: Description of what was accomplished
|
||||
outcome: Task outcome (SUCCESS, FAILURE, PARTIAL)
|
||||
lessons_learned: Optional list of lessons learned
|
||||
consolidate_to_episodic: Whether to create episodic entry
|
||||
cleanup_working: Whether to clear working memory
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with termination outcome
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Gather session state for consolidation
|
||||
all_keys = await working.list_keys()
|
||||
state_keys = [
|
||||
k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)
|
||||
]
|
||||
|
||||
session_state: dict[str, Any] = {}
|
||||
for key in state_keys:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
session_state[key] = value
|
||||
|
||||
episode_id: str | None = None
|
||||
|
||||
# Consolidate to episodic memory
|
||||
if consolidate_to_episodic:
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
description = task_description or f"Session {session_id} completed"
|
||||
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
task_type="session_completion",
|
||||
task_description=description[:500],
|
||||
outcome=outcome,
|
||||
outcome_details=f"Session terminated with {len(session_state)} state items",
|
||||
actions=[
|
||||
{
|
||||
"type": "session_terminate",
|
||||
"state_keys": list(session_state.keys()),
|
||||
"outcome": outcome.value,
|
||||
}
|
||||
],
|
||||
context_summary=str(session_state)[:1000] if session_state else "",
|
||||
lessons_learned=lessons_learned or [],
|
||||
duration_seconds=0.0, # Unknown at this point
|
||||
tokens_used=0,
|
||||
importance_score=0.6, # Moderate importance for session ends
|
||||
)
|
||||
|
||||
episode = await episodic.record_episode(episode_data)
|
||||
episode_id = str(episode.id)
|
||||
|
||||
# Clean up working memory
|
||||
items_cleared = 0
|
||||
if cleanup_working:
|
||||
for key in all_keys:
|
||||
await working.delete(key)
|
||||
items_cleared += 1
|
||||
|
||||
# Run hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="terminate",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
metadata={**(metadata or {}), "episode_id": episode_id},
|
||||
)
|
||||
await self._hooks.run_terminate_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} terminated, session {session_id} "
|
||||
f"consolidated to episode {episode_id}"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="terminate",
|
||||
message="Agent terminated successfully",
|
||||
data={
|
||||
"episode_id": episode_id,
|
||||
"state_items_consolidated": len(session_state),
|
||||
"items_cleared": items_cleared,
|
||||
"outcome": outcome.value,
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Terminate failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="terminate",
|
||||
message=f"Terminate failed: {e}",
|
||||
)
|
||||
|
||||
async def list_checkpoints(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
List available checkpoints for a session.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
List of checkpoint metadata dicts
|
||||
"""
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
all_keys = await working.list_keys()
|
||||
checkpoints: list[dict[str, Any]] = []
|
||||
|
||||
for key in all_keys:
|
||||
if key.startswith(self.CHECKPOINT_PREFIX):
|
||||
checkpoint_id = key[len(self.CHECKPOINT_PREFIX) :]
|
||||
checkpoint = await working.get(key)
|
||||
if checkpoint:
|
||||
checkpoints.append(
|
||||
{
|
||||
"checkpoint_id": checkpoint_id,
|
||||
"timestamp": checkpoint.get("timestamp"),
|
||||
"keys_count": checkpoint.get("keys_count", 0),
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
checkpoints.sort(
|
||||
key=lambda c: c.get("timestamp", ""),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
return checkpoints
|
||||
|
||||
|
||||
# Factory function
|
||||
async def get_lifecycle_manager(
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
hooks: LifecycleHooks | None = None,
|
||||
) -> AgentLifecycleManager:
|
||||
"""Create a lifecycle manager instance."""
|
||||
return AgentLifecycleManager(
|
||||
session=session,
|
||||
embedding_generator=embedding_generator,
|
||||
hooks=hooks,
|
||||
)
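A minimal caller-side sketch of the pause/resume/terminate flow above. It assumes an already-open SQLAlchemy AsyncSession (db_session) and real project/agent UUIDs supplied by the surrounding application, and that pause() takes the same scope arguments plus an optional checkpoint_id, as its body suggests; none of this is part of the module itself.

from uuid import uuid4

from app.services.memory.types import Outcome


async def demo_lifecycle(db_session) -> None:
    manager = await get_lifecycle_manager(session=db_session)

    project_id, agent_id = uuid4(), uuid4()
    session_id = "demo-session"

    # Pause snapshots all non-checkpoint working-memory keys under a checkpoint key.
    paused = await manager.pause(
        project_id=project_id,
        agent_instance_id=agent_id,
        session_id=session_id,
    )

    # Checkpoints come back newest-first, keyed by their stored timestamp.
    available = await manager.list_checkpoints(project_id, agent_id, session_id)

    if paused.success and available:
        # Resume clears current (non-checkpoint) state by default, then re-applies the snapshot.
        await manager.resume(
            project_id=project_id,
            agent_instance_id=agent_id,
            session_id=session_id,
            checkpoint_id=paused.data["checkpoint_id"],
        )

    # Terminate consolidates the session into episodic memory and clears working memory.
    await manager.terminate(
        project_id=project_id,
        agent_instance_id=agent_id,
        session_id=session_id,
        task_description="Demo session",
        outcome=Outcome.SUCCESS,
        lessons_learned=["checkpoint before risky steps"],
    )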
backend/app/services/memory/manager.py (new file, 606 lines)
@@ -0,0 +1,606 @@
|
||||
"""
|
||||
Memory Manager
|
||||
|
||||
Facade for the Agent Memory System providing unified access
|
||||
to all memory types and operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from .config import MemorySettings, get_memory_settings
|
||||
from .types import (
|
||||
Episode,
|
||||
EpisodeCreate,
|
||||
Fact,
|
||||
FactCreate,
|
||||
MemoryStats,
|
||||
MemoryType,
|
||||
Outcome,
|
||||
Procedure,
|
||||
ProcedureCreate,
|
||||
RetrievalResult,
|
||||
ScopeContext,
|
||||
ScopeLevel,
|
||||
TaskState,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""
|
||||
Unified facade for the Agent Memory System.
|
||||
|
||||
Provides a single entry point for all memory operations across
|
||||
working, episodic, semantic, and procedural memory types.
|
||||
|
||||
Usage:
|
||||
manager = MemoryManager.create()
|
||||
|
||||
# Working memory
|
||||
await manager.set_working("key", {"data": "value"})
|
||||
value = await manager.get_working("key")
|
||||
|
||||
# Episodic memory
|
||||
episode = await manager.record_episode(episode_data)
|
||||
similar = await manager.search_episodes("query")
|
||||
|
||||
# Semantic memory
|
||||
fact = await manager.store_fact(fact_data)
|
||||
facts = await manager.search_facts("query")
|
||||
|
||||
# Procedural memory
|
||||
procedure = await manager.record_procedure(procedure_data)
|
||||
procedures = await manager.find_procedures("context")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: MemorySettings,
|
||||
scope: ScopeContext,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the MemoryManager.
|
||||
|
||||
Args:
|
||||
settings: Memory configuration settings
|
||||
scope: The scope context for this manager instance
|
||||
"""
|
||||
self._settings = settings
|
||||
self._scope = scope
|
||||
self._initialized = False
|
||||
|
||||
# These will be initialized when the respective sub-modules are implemented
|
||||
self._working_memory: Any | None = None
|
||||
self._episodic_memory: Any | None = None
|
||||
self._semantic_memory: Any | None = None
|
||||
self._procedural_memory: Any | None = None
|
||||
|
||||
logger.debug(
|
||||
"MemoryManager created for scope %s:%s",
|
||||
scope.scope_type.value,
|
||||
scope.scope_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
scope_type: ScopeLevel = ScopeLevel.SESSION,
|
||||
scope_id: str = "default",
|
||||
parent_scope: ScopeContext | None = None,
|
||||
settings: MemorySettings | None = None,
|
||||
) -> "MemoryManager":
|
||||
"""
|
||||
Create a new MemoryManager instance.
|
||||
|
||||
Args:
|
||||
scope_type: The scope level for this manager
|
||||
scope_id: The scope identifier
|
||||
parent_scope: Optional parent scope for inheritance
|
||||
settings: Optional custom settings (uses global if not provided)
|
||||
|
||||
Returns:
|
||||
A new MemoryManager instance
|
||||
"""
|
||||
if settings is None:
|
||||
settings = get_memory_settings()
|
||||
|
||||
scope = ScopeContext(
|
||||
scope_type=scope_type,
|
||||
scope_id=scope_id,
|
||||
parent=parent_scope,
|
||||
)
|
||||
|
||||
return cls(settings=settings, scope=scope)
|
||||
|
||||
@classmethod
|
||||
def for_session(
|
||||
cls,
|
||||
session_id: str,
|
||||
agent_instance_id: UUID | None = None,
|
||||
project_id: UUID | None = None,
|
||||
) -> "MemoryManager":
|
||||
"""
|
||||
Create a MemoryManager for a specific session.
|
||||
|
||||
Builds the appropriate scope hierarchy based on provided IDs.
|
||||
|
||||
Args:
|
||||
session_id: The session identifier
|
||||
agent_instance_id: Optional agent instance ID
|
||||
project_id: Optional project ID
|
||||
|
||||
Returns:
|
||||
A MemoryManager configured for the session scope
|
||||
"""
|
||||
settings = get_memory_settings()
|
||||
|
||||
# Build scope hierarchy
|
||||
parent: ScopeContext | None = None
|
||||
|
||||
if project_id:
|
||||
parent = ScopeContext(
|
||||
scope_type=ScopeLevel.PROJECT,
|
||||
scope_id=str(project_id),
|
||||
parent=ScopeContext(
|
||||
scope_type=ScopeLevel.GLOBAL,
|
||||
scope_id="global",
|
||||
),
|
||||
)
|
||||
|
||||
if agent_instance_id:
|
||||
parent = ScopeContext(
|
||||
scope_type=ScopeLevel.AGENT_INSTANCE,
|
||||
scope_id=str(agent_instance_id),
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
scope = ScopeContext(
|
||||
scope_type=ScopeLevel.SESSION,
|
||||
scope_id=session_id,
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
return cls(settings=settings, scope=scope)
|
||||
|
||||
@property
|
||||
def scope(self) -> ScopeContext:
|
||||
"""Get the current scope context."""
|
||||
return self._scope
|
||||
|
||||
@property
|
||||
def settings(self) -> MemorySettings:
|
||||
"""Get the memory settings."""
|
||||
return self._settings
|
||||
|
||||
# =========================================================================
|
||||
# Working Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def set_working(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Set a value in working memory.
|
||||
|
||||
Args:
|
||||
key: The key to store the value under
|
||||
value: The value to store (must be JSON serializable)
|
||||
ttl_seconds: Optional TTL (uses default if not provided)
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("set_working called for key=%s (not yet implemented)", key)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def get_working(
|
||||
self,
|
||||
key: str,
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Get a value from working memory.
|
||||
|
||||
Args:
|
||||
key: The key to retrieve
|
||||
default: Default value if key not found
|
||||
|
||||
Returns:
|
||||
The stored value or default
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("get_working called for key=%s (not yet implemented)", key)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def delete_working(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a value from working memory.
|
||||
|
||||
Args:
|
||||
key: The key to delete
|
||||
|
||||
Returns:
|
||||
True if the key was deleted, False if not found
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("delete_working called for key=%s (not yet implemented)", key)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def set_task_state(self, state: TaskState) -> None:
|
||||
"""
|
||||
Set the current task state in working memory.
|
||||
|
||||
Args:
|
||||
state: The task state to store
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug(
|
||||
"set_task_state called for task=%s (not yet implemented)",
|
||||
state.task_id,
|
||||
)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def get_task_state(self) -> TaskState | None:
|
||||
"""
|
||||
Get the current task state from working memory.
|
||||
|
||||
Returns:
|
||||
The current task state or None
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("get_task_state called (not yet implemented)")
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def create_checkpoint(self) -> str:
|
||||
"""
|
||||
Create a checkpoint of the current working memory state.
|
||||
|
||||
Returns:
|
||||
The checkpoint ID
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("create_checkpoint called (not yet implemented)")
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def restore_checkpoint(self, checkpoint_id: str) -> None:
|
||||
"""
|
||||
Restore working memory from a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: The checkpoint to restore from
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug(
|
||||
"restore_checkpoint called for id=%s (not yet implemented)",
|
||||
checkpoint_id,
|
||||
)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Episodic Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def record_episode(self, episode: EpisodeCreate) -> Episode:
|
||||
"""
|
||||
Record a new episode in episodic memory.
|
||||
|
||||
Args:
|
||||
episode: The episode data to record
|
||||
|
||||
Returns:
|
||||
The created episode with ID
|
||||
"""
|
||||
# Placeholder - will be implemented in #90
|
||||
logger.debug(
|
||||
"record_episode called for task=%s (not yet implemented)",
|
||||
episode.task_type,
|
||||
)
|
||||
raise NotImplementedError("Episodic memory not yet implemented")
|
||||
|
||||
async def search_episodes(
|
||||
self,
|
||||
query: str,
|
||||
limit: int | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""
|
||||
Search for similar episodes.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
limit: Maximum results to return
|
||||
|
||||
Returns:
|
||||
Retrieval result with matching episodes
|
||||
"""
|
||||
# Placeholder - will be implemented in #90
|
||||
logger.debug(
|
||||
"search_episodes called for query=%s (not yet implemented)",
|
||||
query[:50],
|
||||
)
|
||||
raise NotImplementedError("Episodic memory not yet implemented")
|
||||
|
||||
async def get_recent_episodes(
|
||||
self,
|
||||
limit: int = 10,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get the most recent episodes.
|
||||
|
||||
Args:
|
||||
limit: Maximum episodes to return
|
||||
|
||||
Returns:
|
||||
List of recent episodes
|
||||
"""
|
||||
# Placeholder - will be implemented in #90
|
||||
logger.debug("get_recent_episodes called (not yet implemented)")
|
||||
raise NotImplementedError("Episodic memory not yet implemented")
|
||||
|
||||
async def get_episodes_by_outcome(
|
||||
self,
|
||||
outcome: Outcome,
|
||||
limit: int = 10,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get episodes by outcome.
|
||||
|
||||
Args:
|
||||
outcome: The outcome to filter by
|
||||
limit: Maximum episodes to return
|
||||
|
||||
Returns:
|
||||
List of episodes with the specified outcome
|
||||
"""
|
||||
# Placeholder - will be implemented in #90
|
||||
logger.debug(
|
||||
"get_episodes_by_outcome called for outcome=%s (not yet implemented)",
|
||||
outcome.value,
|
||||
)
|
||||
raise NotImplementedError("Episodic memory not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Semantic Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def store_fact(self, fact: FactCreate) -> Fact:
|
||||
"""
|
||||
Store a new fact in semantic memory.
|
||||
|
||||
Args:
|
||||
fact: The fact data to store
|
||||
|
||||
Returns:
|
||||
The created fact with ID
|
||||
"""
|
||||
# Placeholder - will be implemented in #91
|
||||
logger.debug(
|
||||
"store_fact called for %s %s %s (not yet implemented)",
|
||||
fact.subject,
|
||||
fact.predicate,
|
||||
fact.object,
|
||||
)
|
||||
raise NotImplementedError("Semantic memory not yet implemented")
|
||||
|
||||
async def search_facts(
|
||||
self,
|
||||
query: str,
|
||||
limit: int | None = None,
|
||||
) -> RetrievalResult[Fact]:
|
||||
"""
|
||||
Search for facts matching a query.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
limit: Maximum results to return
|
||||
|
||||
Returns:
|
||||
Retrieval result with matching facts
|
||||
"""
|
||||
# Placeholder - will be implemented in #91
|
||||
logger.debug(
|
||||
"search_facts called for query=%s (not yet implemented)",
|
||||
query[:50],
|
||||
)
|
||||
raise NotImplementedError("Semantic memory not yet implemented")
|
||||
|
||||
async def get_facts_by_entity(
|
||||
self,
|
||||
entity: str,
|
||||
limit: int = 20,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Get facts related to an entity.
|
||||
|
||||
Args:
|
||||
entity: The entity to search for
|
||||
limit: Maximum facts to return
|
||||
|
||||
Returns:
|
||||
List of facts mentioning the entity
|
||||
"""
|
||||
# Placeholder - will be implemented in #91
|
||||
logger.debug(
|
||||
"get_facts_by_entity called for entity=%s (not yet implemented)",
|
||||
entity,
|
||||
)
|
||||
raise NotImplementedError("Semantic memory not yet implemented")
|
||||
|
||||
async def reinforce_fact(self, fact_id: UUID) -> Fact:
|
||||
"""
|
||||
Reinforce a fact (increase confidence from repeated learning).
|
||||
|
||||
Args:
|
||||
fact_id: The fact to reinforce
|
||||
|
||||
Returns:
|
||||
The updated fact
|
||||
"""
|
||||
# Placeholder - will be implemented in #91
|
||||
logger.debug(
|
||||
"reinforce_fact called for id=%s (not yet implemented)",
|
||||
fact_id,
|
||||
)
|
||||
raise NotImplementedError("Semantic memory not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Procedural Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def record_procedure(self, procedure: ProcedureCreate) -> Procedure:
|
||||
"""
|
||||
Record a new procedure.
|
||||
|
||||
Args:
|
||||
procedure: The procedure data to record
|
||||
|
||||
Returns:
|
||||
The created procedure with ID
|
||||
"""
|
||||
# Placeholder - will be implemented in #92
|
||||
logger.debug(
|
||||
"record_procedure called for name=%s (not yet implemented)",
|
||||
procedure.name,
|
||||
)
|
||||
raise NotImplementedError("Procedural memory not yet implemented")
|
||||
|
||||
async def find_procedures(
|
||||
self,
|
||||
context: str,
|
||||
limit: int = 5,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Find procedures matching the current context.
|
||||
|
||||
Args:
|
||||
context: The context to match against
|
||||
limit: Maximum procedures to return
|
||||
|
||||
Returns:
|
||||
List of matching procedures sorted by success rate
|
||||
"""
|
||||
# Placeholder - will be implemented in #92
|
||||
logger.debug(
|
||||
"find_procedures called for context=%s (not yet implemented)",
|
||||
context[:50],
|
||||
)
|
||||
raise NotImplementedError("Procedural memory not yet implemented")
|
||||
|
||||
async def record_procedure_outcome(
|
||||
self,
|
||||
procedure_id: UUID,
|
||||
success: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Record the outcome of using a procedure.
|
||||
|
||||
Args:
|
||||
procedure_id: The procedure that was used
|
||||
success: Whether the procedure succeeded
|
||||
"""
|
||||
# Placeholder - will be implemented in #92
|
||||
logger.debug(
|
||||
"record_procedure_outcome called for id=%s success=%s (not yet implemented)",
|
||||
procedure_id,
|
||||
success,
|
||||
)
|
||||
raise NotImplementedError("Procedural memory not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Cross-Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def recall(
|
||||
self,
|
||||
query: str,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
limit: int = 10,
|
||||
) -> dict[MemoryType, list[Any]]:
|
||||
"""
|
||||
Recall memories across multiple memory types.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
memory_types: Memory types to search (all if not specified)
|
||||
limit: Maximum results per type
|
||||
|
||||
Returns:
|
||||
Dictionary mapping memory types to results
|
||||
"""
|
||||
# Placeholder - will be implemented in #97 (Component Integration)
|
||||
logger.debug("recall called for query=%s (not yet implemented)", query[:50])
|
||||
raise NotImplementedError("Cross-memory recall not yet implemented")
|
||||
|
||||
async def get_stats(
|
||||
self,
|
||||
memory_type: MemoryType | None = None,
|
||||
) -> list[MemoryStats]:
|
||||
"""
|
||||
Get memory statistics.
|
||||
|
||||
Args:
|
||||
memory_type: Specific type or all if not specified
|
||||
|
||||
Returns:
|
||||
List of statistics for requested memory types
|
||||
"""
|
||||
# Placeholder - will be implemented in #100 (Metrics & Observability)
|
||||
logger.debug("get_stats called (not yet implemented)")
|
||||
raise NotImplementedError("Memory stats not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Lifecycle Operations
|
||||
# =========================================================================
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Initialize the memory manager and its backends.
|
||||
|
||||
Should be called before using the manager.
|
||||
"""
|
||||
if self._initialized:
|
||||
logger.debug("MemoryManager already initialized")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Initializing MemoryManager for scope %s:%s",
|
||||
self._scope.scope_type.value,
|
||||
self._scope.scope_id,
|
||||
)
|
||||
|
||||
# TODO: Initialize backends when implemented
|
||||
|
||||
self._initialized = True
|
||||
logger.info("MemoryManager initialized successfully")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the memory manager and release resources.
|
||||
|
||||
Should be called when done using the manager.
|
||||
"""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Closing MemoryManager for scope %s:%s",
|
||||
self._scope.scope_type.value,
|
||||
self._scope.scope_id,
|
||||
)
|
||||
|
||||
# TODO: Close backends when implemented
|
||||
|
||||
self._initialized = False
|
||||
logger.info("MemoryManager closed successfully")
|
||||
|
||||
async def __aenter__(self) -> "MemoryManager":
|
||||
"""Async context manager entry."""
|
||||
await self.initialize()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Async context manager exit."""
|
||||
await self.close()
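A short caller-side sketch of the facade above. Only the scope plumbing and initialize()/close() are functional at this stage (the per-type operations still raise NotImplementedError until #89-#92 land), so the example limits itself to construction and walking the scope chain; it assumes ScopeContext exposes the parent it was constructed with.

import asyncio
from uuid import uuid4


async def demo_memory_manager() -> None:
    async with MemoryManager.for_session(
        session_id="demo-session",
        agent_instance_id=uuid4(),
        project_id=uuid4(),
    ) as manager:
        # Walk session -> agent_instance -> project -> global.
        scope = manager.scope
        while scope is not None:
            print(f"{scope.scope_type.value}:{scope.scope_id}")
            scope = scope.parent


asyncio.run(demo_memory_manager())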
backend/app/services/memory/mcp/__init__.py (new file, 40 lines)
@@ -0,0 +1,40 @@
|
||||
# app/services/memory/mcp/__init__.py
|
||||
"""
|
||||
MCP Tools for Agent Memory System.
|
||||
|
||||
Exposes memory operations as MCP-compatible tools that agents can invoke:
|
||||
- remember: Store data in memory
|
||||
- recall: Retrieve from memory
|
||||
- forget: Remove from memory
|
||||
- reflect: Analyze patterns
|
||||
- get_memory_stats: Usage statistics
|
||||
- search_procedures: Find relevant procedures
|
||||
- record_outcome: Record task success/failure
|
||||
"""
|
||||
|
||||
from .service import MemoryToolService, get_memory_tool_service
|
||||
from .tools import (
|
||||
MEMORY_TOOL_DEFINITIONS,
|
||||
ForgetArgs,
|
||||
GetMemoryStatsArgs,
|
||||
MemoryToolDefinition,
|
||||
RecallArgs,
|
||||
RecordOutcomeArgs,
|
||||
ReflectArgs,
|
||||
RememberArgs,
|
||||
SearchProceduresArgs,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MEMORY_TOOL_DEFINITIONS",
|
||||
"ForgetArgs",
|
||||
"GetMemoryStatsArgs",
|
||||
"MemoryToolDefinition",
|
||||
"MemoryToolService",
|
||||
"RecallArgs",
|
||||
"RecordOutcomeArgs",
|
||||
"ReflectArgs",
|
||||
"RememberArgs",
|
||||
"SearchProceduresArgs",
|
||||
"get_memory_tool_service",
|
||||
]
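A brief sketch of the package surface exported above: the argument models can be constructed and validated directly, independent of any MCP transport. The semantic-fact payload mirrors the example given in the remember tool description further down; the field values themselves are illustrative.

from app.services.memory.mcp import MEMORY_TOOL_DEFINITIONS, RememberArgs

# One Pydantic model per tool; invalid payloads raise pydantic.ValidationError.
fact = RememberArgs(
    memory_type="semantic",
    content="User preference noted",
    subject="User",
    predicate="prefers",
    object_value="dark mode",
    importance=0.7,
)

# The same models back the registered tool schemas.
assert "remember" in MEMORY_TOOL_DEFINITIONS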
backend/app/services/memory/mcp/service.py (new file, 1086 lines): diff suppressed because it is too large
backend/app/services/memory/mcp/tools.py (new file, 485 lines)
@@ -0,0 +1,485 @@
|
||||
# app/services/memory/mcp/tools.py
|
||||
"""
|
||||
MCP Tool Definitions for Agent Memory System.
|
||||
|
||||
Defines the schema and metadata for memory-related MCP tools.
|
||||
These tools are invoked by AI agents to interact with the memory system.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# OutcomeType alias - uses core Outcome enum from types module for consistency
|
||||
from app.services.memory.types import Outcome as OutcomeType
|
||||
|
||||
|
||||
class MemoryType(str, Enum):
|
||||
"""Types of memory for storage operations."""
|
||||
|
||||
WORKING = "working"
|
||||
EPISODIC = "episodic"
|
||||
SEMANTIC = "semantic"
|
||||
PROCEDURAL = "procedural"
|
||||
|
||||
|
||||
class AnalysisType(str, Enum):
|
||||
"""Types of pattern analysis for the reflect tool."""
|
||||
|
||||
RECENT_PATTERNS = "recent_patterns"
|
||||
SUCCESS_FACTORS = "success_factors"
|
||||
FAILURE_PATTERNS = "failure_patterns"
|
||||
COMMON_PROCEDURES = "common_procedures"
|
||||
LEARNING_PROGRESS = "learning_progress"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Argument Schemas (Pydantic models for validation)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class RememberArgs(BaseModel):
|
||||
"""Arguments for the 'remember' tool."""
|
||||
|
||||
memory_type: MemoryType = Field(
|
||||
...,
|
||||
description="Type of memory to store in: working, episodic, semantic, or procedural",
|
||||
)
|
||||
content: str = Field(
|
||||
...,
|
||||
description="The content to remember. Can be text, facts, or procedure steps.",
|
||||
min_length=1,
|
||||
max_length=10000,
|
||||
)
|
||||
key: str | None = Field(
|
||||
None,
|
||||
description="Optional key for working memory entries. Required for working memory type.",
|
||||
max_length=256,
|
||||
)
|
||||
importance: float = Field(
|
||||
0.5,
|
||||
description="Importance score from 0.0 (low) to 1.0 (critical)",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
ttl_seconds: int | None = Field(
|
||||
None,
|
||||
description="Time-to-live in seconds for working memory. None for permanent storage.",
|
||||
ge=1,
|
||||
le=86400 * 30, # Max 30 days
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional metadata to store with the memory",
|
||||
)
|
||||
# For semantic memory (facts)
|
||||
subject: str | None = Field(
|
||||
None,
|
||||
description="Subject of the fact (for semantic memory)",
|
||||
max_length=256,
|
||||
)
|
||||
predicate: str | None = Field(
|
||||
None,
|
||||
description="Predicate/relationship (for semantic memory)",
|
||||
max_length=256,
|
||||
)
|
||||
object_value: str | None = Field(
|
||||
None,
|
||||
description="Object of the fact (for semantic memory)",
|
||||
max_length=1000,
|
||||
)
|
||||
# For procedural memory
|
||||
trigger: str | None = Field(
|
||||
None,
|
||||
description="Trigger condition for the procedure (for procedural memory)",
|
||||
max_length=500,
|
||||
)
|
||||
steps: list[dict[str, Any]] | None = Field(
|
||||
None,
|
||||
description="Procedure steps as a list of action dictionaries",
|
||||
)
|
||||
|
||||
|
||||
class RecallArgs(BaseModel):
|
||||
"""Arguments for the 'recall' tool."""
|
||||
|
||||
query: str = Field(
|
||||
...,
|
||||
description="Search query to find relevant memories",
|
||||
min_length=1,
|
||||
max_length=1000,
|
||||
)
|
||||
memory_types: list[MemoryType] = Field(
|
||||
default_factory=lambda: [MemoryType.EPISODIC, MemoryType.SEMANTIC],
|
||||
description="Types of memory to search in",
|
||||
)
|
||||
limit: int = Field(
|
||||
10,
|
||||
description="Maximum number of results to return",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
min_relevance: float = Field(
|
||||
0.0,
|
||||
description="Minimum relevance score (0.0-1.0) for results",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
filters: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional filters (e.g., outcome, task_type, date range)",
|
||||
)
|
||||
include_context: bool = Field(
|
||||
True,
|
||||
description="Whether to include surrounding context in results",
|
||||
)
|
||||
|
||||
|
||||
class ForgetArgs(BaseModel):
|
||||
"""Arguments for the 'forget' tool."""
|
||||
|
||||
memory_type: MemoryType = Field(
|
||||
...,
|
||||
description="Type of memory to remove from",
|
||||
)
|
||||
key: str | None = Field(
|
||||
None,
|
||||
description="Key to remove (for working memory)",
|
||||
max_length=256,
|
||||
)
|
||||
memory_id: str | None = Field(
|
||||
None,
|
||||
description="Specific memory ID to remove (for episodic/semantic/procedural)",
|
||||
)
|
||||
pattern: str | None = Field(
|
||||
None,
|
||||
description="Pattern to match for bulk removal (use with caution)",
|
||||
max_length=500,
|
||||
)
|
||||
confirm_bulk: bool = Field(
|
||||
False,
|
||||
description="Must be True to confirm bulk deletion when using pattern",
|
||||
)
|
||||
|
||||
|
||||
class ReflectArgs(BaseModel):
|
||||
"""Arguments for the 'reflect' tool."""
|
||||
|
||||
analysis_type: AnalysisType = Field(
|
||||
...,
|
||||
description="Type of pattern analysis to perform",
|
||||
)
|
||||
scope: str | None = Field(
|
||||
None,
|
||||
description="Optional scope to limit analysis (e.g., task_type, time range)",
|
||||
max_length=500,
|
||||
)
|
||||
depth: int = Field(
|
||||
3,
|
||||
description="Depth of analysis (1=surface, 5=deep)",
|
||||
ge=1,
|
||||
le=5,
|
||||
)
|
||||
include_examples: bool = Field(
|
||||
True,
|
||||
description="Whether to include example memories in the analysis",
|
||||
)
|
||||
max_items: int = Field(
|
||||
10,
|
||||
description="Maximum number of patterns/examples to analyze",
|
||||
ge=1,
|
||||
le=50,
|
||||
)
|
||||
|
||||
|
||||
class GetMemoryStatsArgs(BaseModel):
|
||||
"""Arguments for the 'get_memory_stats' tool."""
|
||||
|
||||
include_breakdown: bool = Field(
|
||||
True,
|
||||
description="Include breakdown by memory type",
|
||||
)
|
||||
include_recent_activity: bool = Field(
|
||||
True,
|
||||
description="Include recent memory activity summary",
|
||||
)
|
||||
time_range_days: int = Field(
|
||||
7,
|
||||
description="Time range for activity analysis in days",
|
||||
ge=1,
|
||||
le=90,
|
||||
)
|
||||
|
||||
|
||||
class SearchProceduresArgs(BaseModel):
|
||||
"""Arguments for the 'search_procedures' tool."""
|
||||
|
||||
trigger: str = Field(
|
||||
...,
|
||||
description="Trigger or situation to find procedures for",
|
||||
min_length=1,
|
||||
max_length=500,
|
||||
)
|
||||
task_type: str | None = Field(
|
||||
None,
|
||||
description="Optional task type to filter procedures",
|
||||
max_length=100,
|
||||
)
|
||||
min_success_rate: float = Field(
|
||||
0.5,
|
||||
description="Minimum success rate (0.0-1.0) for returned procedures",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
limit: int = Field(
|
||||
5,
|
||||
description="Maximum number of procedures to return",
|
||||
ge=1,
|
||||
le=20,
|
||||
)
|
||||
include_steps: bool = Field(
|
||||
True,
|
||||
description="Whether to include detailed steps in the response",
|
||||
)
|
||||
|
||||
|
||||
class RecordOutcomeArgs(BaseModel):
|
||||
"""Arguments for the 'record_outcome' tool."""
|
||||
|
||||
task_type: str = Field(
|
||||
...,
|
||||
description="Type of task that was executed",
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
)
|
||||
outcome: OutcomeType = Field(
|
||||
...,
|
||||
description="Outcome of the task execution",
|
||||
)
|
||||
procedure_id: str | None = Field(
|
||||
None,
|
||||
description="ID of the procedure that was followed (if any)",
|
||||
)
|
||||
context: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Context in which the task was executed",
|
||||
)
|
||||
lessons_learned: str | None = Field(
|
||||
None,
|
||||
description="What was learned from this execution",
|
||||
max_length=2000,
|
||||
)
|
||||
duration_seconds: float | None = Field(
|
||||
None,
|
||||
description="How long the task took to execute",
|
||||
ge=0.0,
|
||||
)
|
||||
error_details: str | None = Field(
|
||||
None,
|
||||
description="Details about any errors encountered (for failures)",
|
||||
max_length=2000,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Definition Structure
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryToolDefinition:
|
||||
"""Definition of an MCP tool for the memory system."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
args_schema: type[BaseModel]
|
||||
input_schema: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Generate input schema from Pydantic model."""
|
||||
if not self.input_schema:
|
||||
self.input_schema = self.args_schema.model_json_schema()
|
||||
|
||||
def to_mcp_format(self) -> dict[str, Any]:
|
||||
"""Convert to MCP tool format."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"inputSchema": self.input_schema,
|
||||
}
|
||||
|
||||
def validate_args(self, args: dict[str, Any]) -> BaseModel:
|
||||
"""Validate and parse arguments."""
|
||||
return self.args_schema.model_validate(args)
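A small sketch of the dataclass above in isolation: the MCP inputSchema is generated from the Pydantic model in __post_init__, and validate_args() returns a typed instance. The definition below is illustrative only; the real registrations follow later in this module.

demo_tool = MemoryToolDefinition(
    name="recall_demo",  # illustrative name, not one of the registered tools
    description="Demo definition showing the schema round-trip.",
    args_schema=RecallArgs,
)

# inputSchema comes from RecallArgs.model_json_schema().
payload = demo_tool.to_mcp_format()
assert "query" in payload["inputSchema"]["properties"]

# Invalid input raises pydantic.ValidationError; valid input returns a RecallArgs.
parsed = demo_tool.validate_args({"query": "database connection", "limit": 3})
assert isinstance(parsed, RecallArgs)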
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Definitions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
REMEMBER_TOOL = MemoryToolDefinition(
|
||||
name="remember",
|
||||
description="""Store information in the agent's memory system.
|
||||
|
||||
Use this tool to:
|
||||
- Store temporary data in working memory (key-value with optional TTL)
|
||||
- Record important events in episodic memory (automatically done on session end)
|
||||
- Store facts/knowledge in semantic memory (subject-predicate-object triples)
|
||||
- Save procedures in procedural memory (trigger conditions and steps)
|
||||
|
||||
Examples:
|
||||
- Working memory: {"memory_type": "working", "key": "current_task", "content": "Implementing auth", "ttl_seconds": 3600}
|
||||
- Semantic fact: {"memory_type": "semantic", "subject": "User", "predicate": "prefers", "object_value": "dark mode", "content": "User preference noted"}
|
||||
- Procedure: {"memory_type": "procedural", "trigger": "When creating a new file", "steps": [{"action": "check_exists"}, {"action": "create"}], "content": "File creation procedure"}
|
||||
""",
|
||||
args_schema=RememberArgs,
|
||||
)
|
||||
|
||||
|
||||
RECALL_TOOL = MemoryToolDefinition(
|
||||
name="recall",
|
||||
description="""Retrieve information from the agent's memory system.
|
||||
|
||||
Use this tool to:
|
||||
- Search for relevant past experiences (episodic)
|
||||
- Look up known facts and knowledge (semantic)
|
||||
- Find applicable procedures for current task (procedural)
|
||||
- Get current session state (working)
|
||||
|
||||
The query supports semantic search - describe what you're looking for in natural language.
|
||||
|
||||
Examples:
|
||||
- {"query": "How did I handle authentication errors before?", "memory_types": ["episodic"]}
|
||||
- {"query": "What are the user's preferences?", "memory_types": ["semantic"], "limit": 5}
|
||||
- {"query": "database connection", "memory_types": ["episodic", "semantic", "procedural"], "filters": {"outcome": "success"}}
|
||||
""",
|
||||
args_schema=RecallArgs,
|
||||
)
|
||||
|
||||
|
||||
FORGET_TOOL = MemoryToolDefinition(
|
||||
name="forget",
|
||||
description="""Remove information from the agent's memory system.
|
||||
|
||||
Use this tool to:
|
||||
- Clear temporary working memory entries
|
||||
- Remove specific memories by ID
|
||||
- Bulk remove memories matching a pattern (requires confirmation)
|
||||
|
||||
WARNING: Deletion is permanent. Use with caution.
|
||||
|
||||
Examples:
|
||||
- Working memory: {"memory_type": "working", "key": "temp_calculation"}
|
||||
- Specific memory: {"memory_type": "episodic", "memory_id": "ep-123"}
|
||||
- Bulk (requires confirm): {"memory_type": "working", "pattern": "cache_*", "confirm_bulk": true}
|
||||
""",
|
||||
args_schema=ForgetArgs,
|
||||
)
|
||||
|
||||
|
||||
REFLECT_TOOL = MemoryToolDefinition(
|
||||
name="reflect",
|
||||
description="""Analyze patterns in the agent's memory to gain insights.
|
||||
|
||||
Use this tool to:
|
||||
- Identify patterns in recent work
|
||||
- Understand what leads to success/failure
|
||||
- Learn from past experiences
|
||||
- Track learning progress over time
|
||||
|
||||
Analysis types:
|
||||
- recent_patterns: What patterns appear in recent work
|
||||
- success_factors: What conditions lead to success
|
||||
- failure_patterns: What causes failures and how to avoid them
|
||||
- common_procedures: Most frequently used procedures
|
||||
- learning_progress: How knowledge has grown over time
|
||||
|
||||
Examples:
|
||||
- {"analysis_type": "success_factors", "scope": "code_review", "depth": 3}
|
||||
- {"analysis_type": "failure_patterns", "include_examples": true, "max_items": 5}
|
||||
""",
|
||||
args_schema=ReflectArgs,
|
||||
)
|
||||
|
||||
|
||||
GET_MEMORY_STATS_TOOL = MemoryToolDefinition(
|
||||
name="get_memory_stats",
|
||||
description="""Get statistics about the agent's memory usage.
|
||||
|
||||
Returns information about:
|
||||
- Total memories stored by type
|
||||
- Storage utilization
|
||||
- Recent activity summary
|
||||
- Memory health indicators
|
||||
|
||||
Use this to understand memory capacity and usage patterns.
|
||||
|
||||
Examples:
|
||||
- {"include_breakdown": true, "include_recent_activity": true}
|
||||
- {"time_range_days": 30, "include_breakdown": true}
|
||||
""",
|
||||
args_schema=GetMemoryStatsArgs,
|
||||
)
|
||||
|
||||
|
||||
SEARCH_PROCEDURES_TOOL = MemoryToolDefinition(
|
||||
name="search_procedures",
|
||||
description="""Find relevant procedures for a given situation.
|
||||
|
||||
Use this tool when you need to:
|
||||
- Find the best way to handle a situation
|
||||
- Look up proven approaches to problems
|
||||
- Get step-by-step guidance for tasks
|
||||
|
||||
Returns procedures ranked by relevance and success rate.
|
||||
|
||||
Examples:
|
||||
- {"trigger": "Deploying to production", "min_success_rate": 0.8}
|
||||
- {"trigger": "Handling merge conflicts", "task_type": "git_operations", "limit": 3}
|
||||
""",
|
||||
args_schema=SearchProceduresArgs,
|
||||
)
|
||||
|
||||
|
||||
RECORD_OUTCOME_TOOL = MemoryToolDefinition(
|
||||
name="record_outcome",
|
||||
description="""Record the outcome of a task execution.
|
||||
|
||||
Use this tool after completing a task to:
|
||||
- Update procedure success/failure rates
|
||||
- Store lessons learned for future reference
|
||||
- Improve procedure recommendations
|
||||
|
||||
This helps the memory system learn from experience.
|
||||
|
||||
Examples:
|
||||
- {"task_type": "code_review", "outcome": "success", "lessons_learned": "Breaking changes caught early"}
|
||||
- {"task_type": "deployment", "outcome": "failure", "error_details": "Database migration timeout", "lessons_learned": "Need to test migrations locally first"}
|
||||
""",
|
||||
args_schema=RecordOutcomeArgs,
|
||||
)
|
||||
|
||||
|
||||
# All tool definitions in a dictionary for easy lookup
|
||||
MEMORY_TOOL_DEFINITIONS: dict[str, MemoryToolDefinition] = {
|
||||
"remember": REMEMBER_TOOL,
|
||||
"recall": RECALL_TOOL,
|
||||
"forget": FORGET_TOOL,
|
||||
"reflect": REFLECT_TOOL,
|
||||
"get_memory_stats": GET_MEMORY_STATS_TOOL,
|
||||
"search_procedures": SEARCH_PROCEDURES_TOOL,
|
||||
"record_outcome": RECORD_OUTCOME_TOOL,
|
||||
}
|
||||
|
||||
|
||||
def get_all_tool_schemas() -> list[dict[str, Any]]:
|
||||
"""Get MCP-formatted schemas for all memory tools."""
|
||||
return [tool.to_mcp_format() for tool in MEMORY_TOOL_DEFINITIONS.values()]
|
||||
|
||||
|
||||
def get_tool_definition(name: str) -> MemoryToolDefinition | None:
|
||||
"""Get a specific tool definition by name."""
|
||||
return MEMORY_TOOL_DEFINITIONS.get(name)
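A caller-side sketch of the lookup helpers above, reusing this module's existing imports. handle_memory_tool_call is a hypothetical wrapper for illustration only; the real dispatch lives in mcp/service.py, whose diff is suppressed earlier in this comparison.

def handle_memory_tool_call(name: str, raw_args: dict[str, Any]) -> BaseModel:
    """Hypothetical dispatch wrapper: validate an incoming MCP call by tool name."""
    tool = get_tool_definition(name)
    if tool is None:
        raise ValueError(f"Unknown memory tool: {name}")
    return tool.validate_args(raw_args)


# Schemas an MCP server would advertise for this tool set.
assert {schema["name"] for schema in get_all_tool_schemas()} == set(MEMORY_TOOL_DEFINITIONS)

# Mirrors the working-memory example in the forget tool description above.
forget_args = handle_memory_tool_call("forget", {"memory_type": "working", "key": "temp_calculation"})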
backend/app/services/memory/metrics/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
|
||||
# app/services/memory/metrics/__init__.py
|
||||
"""Memory Metrics module."""
|
||||
|
||||
from .collector import (
|
||||
MemoryMetrics,
|
||||
get_memory_metrics,
|
||||
record_memory_operation,
|
||||
record_retrieval,
|
||||
reset_memory_metrics,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MemoryMetrics",
|
||||
"get_memory_metrics",
|
||||
"record_memory_operation",
|
||||
"record_retrieval",
|
||||
"reset_memory_metrics",
|
||||
]
backend/app/services/memory/metrics/collector.py (new file, 542 lines)
@@ -0,0 +1,542 @@
|
||||
# app/services/memory/metrics/collector.py
|
||||
"""
|
||||
Memory Metrics Collector
|
||||
|
||||
Collects and exposes metrics for the memory system.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import Counter, defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetricType(str, Enum):
|
||||
"""Types of metrics."""
|
||||
|
||||
COUNTER = "counter"
|
||||
GAUGE = "gauge"
|
||||
HISTOGRAM = "histogram"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetricValue:
|
||||
"""A single metric value."""
|
||||
|
||||
name: str
|
||||
metric_type: MetricType
|
||||
value: float
|
||||
labels: dict[str, str] = field(default_factory=dict)
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
|
||||
@dataclass
|
||||
class HistogramBucket:
|
||||
"""Histogram bucket for distribution metrics."""
|
||||
|
||||
le: float  # Inclusive upper bound ("less than or equal") for this bucket
|
||||
count: int = 0
|
||||
|
||||
|
||||
class MemoryMetrics:
|
||||
"""
|
||||
Collects memory system metrics.
|
||||
|
||||
Metrics tracked:
|
||||
- Memory operations (get/set/delete by type and scope)
|
||||
- Retrieval operations and latencies
|
||||
- Memory item counts by type
|
||||
- Consolidation operations and durations
|
||||
- Cache hit/miss rates
|
||||
- Procedure success rates
|
||||
- Embedding operations
|
||||
"""
|
||||
|
||||
# Maximum samples to keep in histogram (circular buffer)
|
||||
MAX_HISTOGRAM_SAMPLES = 10000
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize MemoryMetrics."""
|
||||
self._counters: dict[str, Counter[str]] = defaultdict(Counter)
|
||||
self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
|
||||
# Use deque with maxlen for bounded memory (circular buffer)
|
||||
self._histograms: dict[str, deque[float]] = defaultdict(
|
||||
lambda: deque(maxlen=self.MAX_HISTOGRAM_SAMPLES)
|
||||
)
|
||||
self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Initialize histogram buckets
|
||||
self._init_histogram_buckets()
|
||||
|
||||
def _init_histogram_buckets(self) -> None:
|
||||
"""Initialize histogram buckets for latency metrics."""
|
||||
# Fast operations (working memory)
|
||||
fast_buckets = [0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, float("inf")]
|
||||
|
||||
# Normal operations (retrieval)
|
||||
normal_buckets = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, float("inf")]
|
||||
|
||||
# Slow operations (consolidation)
|
||||
slow_buckets = [0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, float("inf")]
|
||||
|
||||
self._histogram_buckets["memory_working_latency_seconds"] = [
|
||||
HistogramBucket(le=b) for b in fast_buckets
|
||||
]
|
||||
self._histogram_buckets["memory_retrieval_latency_seconds"] = [
|
||||
HistogramBucket(le=b) for b in normal_buckets
|
||||
]
|
||||
self._histogram_buckets["memory_consolidation_duration_seconds"] = [
|
||||
HistogramBucket(le=b) for b in slow_buckets
|
||||
]
|
||||
self._histogram_buckets["memory_embedding_latency_seconds"] = [
|
||||
HistogramBucket(le=b) for b in normal_buckets
|
||||
]
|
||||
|
||||
# Counter methods - Operations
|
||||
|
||||
async def inc_operations(
|
||||
self,
|
||||
operation: str,
|
||||
memory_type: str,
|
||||
scope: str | None = None,
|
||||
success: bool = True,
|
||||
) -> None:
|
||||
"""Increment memory operation counter."""
|
||||
async with self._lock:
|
||||
labels = f"operation={operation},memory_type={memory_type}"
|
||||
if scope:
|
||||
labels += f",scope={scope}"
|
||||
labels += f",success={str(success).lower()}"
|
||||
self._counters["memory_operations_total"][labels] += 1
|
||||
|
||||
async def inc_retrieval(
|
||||
self,
|
||||
memory_type: str,
|
||||
strategy: str,
|
||||
results_count: int,
|
||||
) -> None:
|
||||
"""Increment retrieval counter."""
|
||||
async with self._lock:
|
||||
labels = f"memory_type={memory_type},strategy={strategy}"
|
||||
self._counters["memory_retrievals_total"][labels] += 1
|
||||
|
||||
# Track result counts as a separate metric
|
||||
self._counters["memory_retrieval_results_total"][labels] += results_count
|
||||
|
||||
async def inc_cache_hit(self, cache_type: str) -> None:
|
||||
"""Increment cache hit counter."""
|
||||
async with self._lock:
|
||||
labels = f"cache_type={cache_type}"
|
||||
self._counters["memory_cache_hits_total"][labels] += 1
|
||||
|
||||
async def inc_cache_miss(self, cache_type: str) -> None:
|
||||
"""Increment cache miss counter."""
|
||||
async with self._lock:
|
||||
labels = f"cache_type={cache_type}"
|
||||
self._counters["memory_cache_misses_total"][labels] += 1
|
||||
|
||||
async def inc_consolidation(
|
||||
self,
|
||||
consolidation_type: str,
|
||||
success: bool = True,
|
||||
) -> None:
|
||||
"""Increment consolidation counter."""
|
||||
async with self._lock:
|
||||
labels = f"type={consolidation_type},success={str(success).lower()}"
|
||||
self._counters["memory_consolidations_total"][labels] += 1
|
||||
|
||||
async def inc_procedure_execution(
|
||||
self,
|
||||
procedure_id: str | None = None,
|
||||
success: bool = True,
|
||||
) -> None:
|
||||
"""Increment procedure execution counter."""
|
||||
async with self._lock:
|
||||
labels = f"success={str(success).lower()}"
|
||||
self._counters["memory_procedure_executions_total"][labels] += 1
|
||||
|
||||
async def inc_embeddings_generated(self, memory_type: str) -> None:
|
||||
"""Increment embeddings generated counter."""
|
||||
async with self._lock:
|
||||
labels = f"memory_type={memory_type}"
|
||||
self._counters["memory_embeddings_generated_total"][labels] += 1
|
||||
|
||||
async def inc_fact_reinforcements(self) -> None:
|
||||
"""Increment fact reinforcement counter."""
|
||||
async with self._lock:
|
||||
self._counters["memory_fact_reinforcements_total"][""] += 1
|
||||
|
||||
async def inc_episodes_recorded(self, outcome: str) -> None:
|
||||
"""Increment episodes recorded counter."""
|
||||
async with self._lock:
|
||||
labels = f"outcome={outcome}"
|
||||
self._counters["memory_episodes_recorded_total"][labels] += 1
|
||||
|
||||
async def inc_anomalies_detected(self, anomaly_type: str) -> None:
|
||||
"""Increment anomaly detection counter."""
|
||||
async with self._lock:
|
||||
labels = f"anomaly_type={anomaly_type}"
|
||||
self._counters["memory_anomalies_detected_total"][labels] += 1
|
||||
|
||||
async def inc_patterns_detected(self, pattern_type: str) -> None:
|
||||
"""Increment pattern detection counter."""
|
||||
async with self._lock:
|
||||
labels = f"pattern_type={pattern_type}"
|
||||
self._counters["memory_patterns_detected_total"][labels] += 1
|
||||
|
||||
async def inc_insights_generated(self, insight_type: str) -> None:
|
||||
"""Increment insight generation counter."""
|
||||
async with self._lock:
|
||||
labels = f"insight_type={insight_type}"
|
||||
self._counters["memory_insights_generated_total"][labels] += 1
|
||||
|
||||
# Gauge methods
|
||||
|
||||
async def set_memory_items_count(
|
||||
self,
|
||||
memory_type: str,
|
||||
scope: str,
|
||||
count: int,
|
||||
) -> None:
|
||||
"""Set memory item count gauge."""
|
||||
async with self._lock:
|
||||
labels = f"memory_type={memory_type},scope={scope}"
|
||||
self._gauges["memory_items_count"][labels] = float(count)
|
||||
|
||||
async def set_memory_size_bytes(
|
||||
self,
|
||||
memory_type: str,
|
||||
scope: str,
|
||||
size_bytes: int,
|
||||
) -> None:
|
||||
"""Set memory size gauge in bytes."""
|
||||
async with self._lock:
|
||||
labels = f"memory_type={memory_type},scope={scope}"
|
||||
self._gauges["memory_size_bytes"][labels] = float(size_bytes)
|
||||
|
||||
async def set_cache_size(self, cache_type: str, size: int) -> None:
|
||||
"""Set cache size gauge."""
|
||||
async with self._lock:
|
||||
labels = f"cache_type={cache_type}"
|
||||
self._gauges["memory_cache_size"][labels] = float(size)
|
||||
|
||||
async def set_procedure_success_rate(
|
||||
self,
|
||||
procedure_name: str,
|
||||
rate: float,
|
||||
) -> None:
|
||||
"""Set procedure success rate gauge (0-1)."""
|
||||
async with self._lock:
|
||||
labels = f"procedure_name={procedure_name}"
|
||||
self._gauges["memory_procedure_success_rate"][labels] = rate
|
||||
|
||||
async def set_active_sessions(self, count: int) -> None:
|
||||
"""Set active working memory sessions gauge."""
|
||||
async with self._lock:
|
||||
self._gauges["memory_active_sessions"][""] = float(count)
|
||||
|
||||
async def set_pending_consolidations(self, count: int) -> None:
|
||||
"""Set pending consolidations gauge."""
|
||||
async with self._lock:
|
||||
self._gauges["memory_pending_consolidations"][""] = float(count)
|
||||
|
||||
# Histogram methods
|
||||
|
||||
async def observe_working_latency(self, latency_seconds: float) -> None:
|
||||
"""Observe working memory operation latency."""
|
||||
async with self._lock:
|
||||
self._observe_histogram("memory_working_latency_seconds", latency_seconds)
|
||||
|
||||
async def observe_retrieval_latency(self, latency_seconds: float) -> None:
|
||||
"""Observe retrieval latency."""
|
||||
async with self._lock:
|
||||
self._observe_histogram("memory_retrieval_latency_seconds", latency_seconds)
|
||||
|
||||
async def observe_consolidation_duration(self, duration_seconds: float) -> None:
|
||||
"""Observe consolidation duration."""
|
||||
async with self._lock:
|
||||
self._observe_histogram(
|
||||
"memory_consolidation_duration_seconds", duration_seconds
|
||||
)
|
||||
|
||||
async def observe_embedding_latency(self, latency_seconds: float) -> None:
|
||||
"""Observe embedding generation latency."""
|
||||
async with self._lock:
|
||||
self._observe_histogram("memory_embedding_latency_seconds", latency_seconds)
|
||||
|
||||
def _observe_histogram(self, name: str, value: float) -> None:
|
||||
"""Record a value in a histogram."""
|
||||
self._histograms[name].append(value)
|
||||
|
||||
# Update buckets
|
||||
if name in self._histogram_buckets:
|
||||
for bucket in self._histogram_buckets[name]:
|
||||
if value <= bucket.le:
|
||||
bucket.count += 1
|
||||
|
||||
# Export methods
|
||||
|
||||
async def get_all_metrics(self) -> list[MetricValue]:
|
||||
"""Get all metrics as MetricValue objects."""
|
||||
metrics: list[MetricValue] = []
|
||||
|
||||
async with self._lock:
|
||||
# Export counters
|
||||
for name, counter in self._counters.items():
|
||||
for labels_str, value in counter.items():
|
||||
labels = self._parse_labels(labels_str)
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=name,
|
||||
metric_type=MetricType.COUNTER,
|
||||
value=float(value),
|
||||
labels=labels,
|
||||
)
|
||||
)
|
||||
|
||||
# Export gauges
|
||||
for name, gauge_dict in self._gauges.items():
|
||||
for labels_str, gauge_value in gauge_dict.items():
|
||||
gauge_labels = self._parse_labels(labels_str)
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=name,
|
||||
metric_type=MetricType.GAUGE,
|
||||
value=gauge_value,
|
||||
labels=gauge_labels,
|
||||
)
|
||||
)
|
||||
|
||||
# Export histogram summaries
|
||||
for name, values in self._histograms.items():
|
||||
if values:
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=f"{name}_count",
|
||||
metric_type=MetricType.COUNTER,
|
||||
value=float(len(values)),
|
||||
)
|
||||
)
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=f"{name}_sum",
|
||||
metric_type=MetricType.COUNTER,
|
||||
value=sum(values),
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
async def get_prometheus_format(self) -> str:
|
||||
"""Export metrics in Prometheus text format."""
|
||||
lines: list[str] = []
|
||||
|
||||
async with self._lock:
|
||||
# Export counters
|
||||
for name, counter in self._counters.items():
|
||||
lines.append(f"# TYPE {name} counter")
|
||||
for labels_str, value in counter.items():
|
||||
if labels_str:
|
||||
lines.append(f"{name}{{{labels_str}}} {value}")
|
||||
else:
|
||||
lines.append(f"{name} {value}")
|
||||
|
||||
# Export gauges
|
||||
for name, gauge_dict in self._gauges.items():
|
||||
lines.append(f"# TYPE {name} gauge")
|
||||
for labels_str, gauge_value in gauge_dict.items():
|
||||
if labels_str:
|
||||
lines.append(f"{name}{{{labels_str}}} {gauge_value}")
|
||||
else:
|
||||
lines.append(f"{name} {gauge_value}")
|
||||
|
||||
# Export histograms
|
||||
for name, buckets in self._histogram_buckets.items():
|
||||
lines.append(f"# TYPE {name} histogram")
|
||||
for bucket in buckets:
|
||||
le_str = "+Inf" if bucket.le == float("inf") else str(bucket.le)
|
||||
lines.append(f'{name}_bucket{{le="{le_str}"}} {bucket.count}')
|
||||
|
||||
if name in self._histograms:
|
||||
values = self._histograms[name]
|
||||
lines.append(f"{name}_count {len(values)}")
|
||||
lines.append(f"{name}_sum {sum(values)}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def get_summary(self) -> dict[str, Any]:
|
||||
"""Get a summary of key metrics."""
|
||||
async with self._lock:
|
||||
total_operations = sum(self._counters["memory_operations_total"].values())
|
||||
successful_operations = sum(
|
||||
v
|
||||
for k, v in self._counters["memory_operations_total"].items()
|
||||
if "success=true" in k
|
||||
)
|
||||
|
||||
total_retrievals = sum(self._counters["memory_retrievals_total"].values())
|
||||
|
||||
total_cache_hits = sum(self._counters["memory_cache_hits_total"].values())
|
||||
total_cache_misses = sum(
|
||||
self._counters["memory_cache_misses_total"].values()
|
||||
)
|
||||
cache_hit_rate = (
|
||||
total_cache_hits / (total_cache_hits + total_cache_misses)
|
||||
if (total_cache_hits + total_cache_misses) > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
total_consolidations = sum(
|
||||
self._counters["memory_consolidations_total"].values()
|
||||
)
|
||||
|
||||
total_episodes = sum(
|
||||
self._counters["memory_episodes_recorded_total"].values()
|
||||
)
|
||||
|
||||
# Calculate average latencies
|
||||
retrieval_latencies = list(
|
||||
self._histograms.get("memory_retrieval_latency_seconds", deque())
|
||||
)
|
||||
avg_retrieval_latency = (
|
||||
sum(retrieval_latencies) / len(retrieval_latencies)
|
||||
if retrieval_latencies
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_operations": total_operations,
|
||||
"successful_operations": successful_operations,
|
||||
"operation_success_rate": (
|
||||
successful_operations / total_operations
|
||||
if total_operations > 0
|
||||
else 1.0
|
||||
),
|
||||
"total_retrievals": total_retrievals,
|
||||
"cache_hit_rate": cache_hit_rate,
|
||||
"total_consolidations": total_consolidations,
|
||||
"total_episodes_recorded": total_episodes,
|
||||
"avg_retrieval_latency_ms": avg_retrieval_latency * 1000,
|
||||
"patterns_detected": sum(
|
||||
self._counters["memory_patterns_detected_total"].values()
|
||||
),
|
||||
"insights_generated": sum(
|
||||
self._counters["memory_insights_generated_total"].values()
|
||||
),
|
||||
"anomalies_detected": sum(
|
||||
self._counters["memory_anomalies_detected_total"].values()
|
||||
),
|
||||
"active_sessions": self._gauges.get("memory_active_sessions", {}).get(
|
||||
"", 0
|
||||
),
|
||||
"pending_consolidations": self._gauges.get(
|
||||
"memory_pending_consolidations", {}
|
||||
).get("", 0),
|
||||
}
|
||||
|
||||
async def get_cache_stats(self) -> dict[str, Any]:
|
||||
"""Get detailed cache statistics."""
|
||||
async with self._lock:
|
||||
stats: dict[str, Any] = {}
|
||||
|
||||
# Get hits/misses by cache type
|
||||
for labels_str, hits in self._counters["memory_cache_hits_total"].items():
|
||||
cache_type = self._parse_labels(labels_str).get("cache_type", "unknown")
|
||||
if cache_type not in stats:
|
||||
stats[cache_type] = {"hits": 0, "misses": 0}
|
||||
stats[cache_type]["hits"] = hits
|
||||
|
||||
for labels_str, misses in self._counters[
|
||||
"memory_cache_misses_total"
|
||||
].items():
|
||||
cache_type = self._parse_labels(labels_str).get("cache_type", "unknown")
|
||||
if cache_type not in stats:
|
||||
stats[cache_type] = {"hits": 0, "misses": 0}
|
||||
stats[cache_type]["misses"] = misses
|
||||
|
||||
# Calculate hit rates
|
||||
for data in stats.values():
|
||||
total = data["hits"] + data["misses"]
|
||||
data["hit_rate"] = data["hits"] / total if total > 0 else 0.0
|
||||
data["total"] = total
|
||||
|
||||
return stats
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset all metrics."""
|
||||
async with self._lock:
|
||||
self._counters.clear()
|
||||
self._gauges.clear()
|
||||
self._histograms.clear()
|
||||
self._init_histogram_buckets()
    def _parse_labels(self, labels_str: str) -> dict[str, str]:
        """Parse labels string into dictionary."""
        if not labels_str:
            return {}

        labels = {}
        for pair in labels_str.split(","):
            if "=" in pair:
                key, value = pair.split("=", 1)
                labels[key.strip()] = value.strip()

        return labels


# Singleton instance
_metrics: MemoryMetrics | None = None
_lock = asyncio.Lock()


async def get_memory_metrics() -> MemoryMetrics:
    """Get the singleton MemoryMetrics instance."""
    global _metrics

    async with _lock:
        if _metrics is None:
            _metrics = MemoryMetrics()
        return _metrics


async def reset_memory_metrics() -> None:
    """Reset the singleton instance (for testing)."""
    global _metrics
    async with _lock:
        _metrics = None


# Convenience functions


async def record_memory_operation(
    operation: str,
    memory_type: str,
    scope: str | None = None,
    success: bool = True,
    latency_ms: float | None = None,
) -> None:
    """Record a memory operation."""
    metrics = await get_memory_metrics()
    await metrics.inc_operations(operation, memory_type, scope, success)

    if latency_ms is not None and memory_type == "working":
        await metrics.observe_working_latency(latency_ms / 1000)


async def record_retrieval(
    memory_type: str,
    strategy: str,
    results_count: int,
    latency_ms: float,
) -> None:
    """Record a retrieval operation."""
    metrics = await get_memory_metrics()
    await metrics.inc_retrieval(memory_type, strategy, results_count)
    await metrics.observe_retrieval_latency(latency_ms / 1000)
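The singleton accessor and the two convenience wrappers above are the intended entry points for callers. A minimal usage sketch, not part of the committed files; the import path is an assumption, since the metrics module's location is not shown in this diff:

# Illustrative sketch only. Adjust the import path to wherever this metrics module lives.
import asyncio

from app.services.memory.metrics import (
    get_memory_metrics,
    record_memory_operation,
    record_retrieval,
)


async def main() -> None:
    # One working-memory write; latency_ms is converted to seconds internally.
    await record_memory_operation(
        operation="set", memory_type="working", scope="session", latency_ms=4.2
    )
    # One retrieval with its result count and latency.
    await record_retrieval(
        memory_type="episodic", strategy="recency", results_count=7, latency_ms=18.0
    )

    metrics = await get_memory_metrics()
    print(await metrics.get_summary())            # headline numbers as a dict
    print(await metrics.get_prometheus_format())  # Prometheus text exposition


asyncio.run(main())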
22
backend/app/services/memory/procedural/__init__.py
Normal file
@@ -0,0 +1,22 @@
# app/services/memory/procedural/__init__.py
"""
Procedural Memory

Learned skills and procedures from successful task patterns.
"""

from .matching import (
    MatchContext,
    MatchResult,
    ProcedureMatcher,
    get_procedure_matcher,
)
from .memory import ProceduralMemory

__all__ = [
    "MatchContext",
    "MatchResult",
    "ProceduralMemory",
    "ProcedureMatcher",
    "get_procedure_matcher",
]
291
backend/app/services/memory/procedural/matching.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# app/services/memory/procedural/matching.py
|
||||
"""
|
||||
Procedure Matching.
|
||||
|
||||
Provides utilities for matching procedures to contexts,
|
||||
ranking procedures by relevance, and suggesting procedures.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from app.services.memory.types import Procedure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchResult:
|
||||
"""Result of a procedure match."""
|
||||
|
||||
procedure: Procedure
|
||||
score: float
|
||||
matched_terms: list[str] = field(default_factory=list)
|
||||
match_type: str = "keyword" # keyword, semantic, pattern
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"procedure_id": str(self.procedure.id),
|
||||
"procedure_name": self.procedure.name,
|
||||
"score": self.score,
|
||||
"matched_terms": self.matched_terms,
|
||||
"match_type": self.match_type,
|
||||
"success_rate": self.procedure.success_rate,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchContext:
|
||||
"""Context for procedure matching."""
|
||||
|
||||
query: str
|
||||
task_type: str | None = None
|
||||
project_id: Any | None = None
|
||||
agent_type_id: Any | None = None
|
||||
max_results: int = 5
|
||||
min_score: float = 0.3
|
||||
require_success_rate: float | None = None
|
||||
|
||||
|
||||
class ProcedureMatcher:
|
||||
"""
|
||||
Matches procedures to contexts using multiple strategies.
|
||||
|
||||
Matching strategies:
|
||||
- Keyword matching on trigger pattern and name
|
||||
- Pattern-based matching using regex
|
||||
- Success rate weighting
|
||||
|
||||
In production, this would be augmented with vector similarity search.
|
||||
"""
|
||||
|
||||
# Common task-related keywords for boosting
|
||||
TASK_KEYWORDS: ClassVar[set[str]] = {
|
||||
"create",
|
||||
"update",
|
||||
"delete",
|
||||
"fix",
|
||||
"implement",
|
||||
"add",
|
||||
"remove",
|
||||
"refactor",
|
||||
"test",
|
||||
"deploy",
|
||||
"configure",
|
||||
"setup",
|
||||
"build",
|
||||
"debug",
|
||||
"optimize",
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the matcher."""
|
||||
self._compiled_patterns: dict[str, re.Pattern[str]] = {}
|
||||
|
||||
def match(
|
||||
self,
|
||||
procedures: list[Procedure],
|
||||
context: MatchContext,
|
||||
) -> list[MatchResult]:
|
||||
"""
|
||||
Match procedures against a context.
|
||||
|
||||
Args:
|
||||
procedures: List of procedures to match
|
||||
context: Matching context
|
||||
|
||||
Returns:
|
||||
List of match results, sorted by score (highest first)
|
||||
"""
|
||||
results: list[MatchResult] = []
|
||||
|
||||
query_terms = self._extract_terms(context.query)
|
||||
query_lower = context.query.lower()
|
||||
|
||||
for procedure in procedures:
|
||||
score, matched = self._calculate_match_score(
|
||||
procedure=procedure,
|
||||
query_terms=query_terms,
|
||||
query_lower=query_lower,
|
||||
context=context,
|
||||
)
|
||||
|
||||
if score >= context.min_score:
|
||||
# Filter out procedures below the required success rate
|
||||
if context.require_success_rate is not None:
|
||||
if procedure.success_rate < context.require_success_rate:
|
||||
continue
|
||||
|
||||
# Boost score based on success rate
|
||||
success_boost = procedure.success_rate * 0.2
|
||||
final_score = min(1.0, score + success_boost)
|
||||
|
||||
results.append(
|
||||
MatchResult(
|
||||
procedure=procedure,
|
||||
score=final_score,
|
||||
matched_terms=matched,
|
||||
match_type="keyword",
|
||||
)
|
||||
)
|
||||
|
||||
# Sort by score descending
|
||||
results.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
return results[: context.max_results]
|
||||
|
||||
def _extract_terms(self, text: str) -> list[str]:
|
||||
"""Extract searchable terms from text."""
|
||||
# Remove special characters and split
|
||||
clean = re.sub(r"[^\w\s-]", " ", text.lower())
|
||||
terms = clean.split()
|
||||
|
||||
# Filter out very short terms
|
||||
return [t for t in terms if len(t) >= 2]
|
||||
|
||||
def _calculate_match_score(
|
||||
self,
|
||||
procedure: Procedure,
|
||||
query_terms: list[str],
|
||||
query_lower: str,
|
||||
context: MatchContext,
|
||||
) -> tuple[float, list[str]]:
|
||||
"""
|
||||
Calculate match score between procedure and query.
|
||||
|
||||
Returns:
|
||||
Tuple of (score, matched_terms)
|
||||
"""
|
||||
score = 0.0
|
||||
matched: list[str] = []
|
||||
|
||||
trigger_lower = procedure.trigger_pattern.lower()
|
||||
name_lower = procedure.name.lower()
|
||||
|
||||
# Exact name match - high score
|
||||
if name_lower in query_lower or query_lower in name_lower:
|
||||
score += 0.5
|
||||
matched.append(f"name:{procedure.name}")
|
||||
|
||||
# Trigger pattern match
|
||||
if trigger_lower in query_lower or query_lower in trigger_lower:
|
||||
score += 0.4
|
||||
matched.append(f"trigger:{procedure.trigger_pattern[:30]}")
|
||||
|
||||
# Term-by-term matching
|
||||
for term in query_terms:
|
||||
if term in trigger_lower:
|
||||
score += 0.1
|
||||
matched.append(term)
|
||||
elif term in name_lower:
|
||||
score += 0.08
|
||||
matched.append(term)
|
||||
|
||||
# Boost for task keywords
|
||||
if term in self.TASK_KEYWORDS:
|
||||
if term in trigger_lower or term in name_lower:
|
||||
score += 0.05
|
||||
|
||||
# Task type match if provided
|
||||
if context.task_type:
|
||||
task_type_lower = context.task_type.lower()
|
||||
if task_type_lower in trigger_lower or task_type_lower in name_lower:
|
||||
score += 0.3
|
||||
matched.append(f"task_type:{context.task_type}")
|
||||
|
||||
# Regex pattern matching on trigger
|
||||
try:
|
||||
pattern = self._get_or_compile_pattern(trigger_lower)
|
||||
if pattern and pattern.search(query_lower):
|
||||
score += 0.25
|
||||
matched.append("pattern_match")
|
||||
except re.error:
|
||||
pass # Invalid regex, skip pattern matching
|
||||
|
||||
return min(1.0, score), matched
|
||||
|
||||
def _get_or_compile_pattern(self, pattern: str) -> re.Pattern[str] | None:
|
||||
"""Get or compile a regex pattern with caching."""
|
||||
if pattern in self._compiled_patterns:
|
||||
return self._compiled_patterns[pattern]
|
||||
|
||||
# Only compile if it looks like a regex pattern
|
||||
if not any(c in pattern for c in r"\.*+?[]{}|()^$"):
|
||||
return None
|
||||
|
||||
try:
|
||||
compiled = re.compile(pattern, re.IGNORECASE)
|
||||
self._compiled_patterns[pattern] = compiled
|
||||
return compiled
|
||||
except re.error:
|
||||
return None
|
||||
|
||||
def rank_by_relevance(
|
||||
self,
|
||||
procedures: list[Procedure],
|
||||
task_type: str,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Rank procedures by relevance to a task type.
|
||||
|
||||
Args:
|
||||
procedures: Procedures to rank
|
||||
task_type: Task type for relevance
|
||||
|
||||
Returns:
|
||||
Procedures sorted by relevance
|
||||
"""
|
||||
context = MatchContext(
|
||||
query=task_type,
|
||||
task_type=task_type,
|
||||
min_score=0.0,
|
||||
max_results=len(procedures),
|
||||
)
|
||||
|
||||
results = self.match(procedures, context)
|
||||
return [r.procedure for r in results]
|
||||
|
||||
def suggest_procedures(
|
||||
self,
|
||||
procedures: list[Procedure],
|
||||
query: str,
|
||||
min_success_rate: float = 0.5,
|
||||
max_suggestions: int = 3,
|
||||
) -> list[MatchResult]:
|
||||
"""
|
||||
Suggest the best procedures for a query.
|
||||
|
||||
Only suggests procedures with sufficient success rate.
|
||||
|
||||
Args:
|
||||
procedures: Available procedures
|
||||
query: Query/context
|
||||
min_success_rate: Minimum success rate to suggest
|
||||
max_suggestions: Maximum suggestions
|
||||
|
||||
Returns:
|
||||
List of procedure suggestions
|
||||
"""
|
||||
context = MatchContext(
|
||||
query=query,
|
||||
max_results=max_suggestions,
|
||||
min_score=0.2,
|
||||
require_success_rate=min_success_rate,
|
||||
)
|
||||
|
||||
return self.match(procedures, context)
|
||||
|
||||
|
||||
# Singleton matcher instance
|
||||
_matcher: ProcedureMatcher | None = None
|
||||
|
||||
|
||||
def get_procedure_matcher() -> ProcedureMatcher:
|
||||
"""Get the singleton procedure matcher instance."""
|
||||
global _matcher
|
||||
if _matcher is None:
|
||||
_matcher = ProcedureMatcher()
|
||||
return _matcher
|
||||
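MatchContext drives all three entry points (match, rank_by_relevance, suggest_procedures). A minimal sketch of scoring a query against already-loaded procedures, not part of the committed files; the empty procedures list is a placeholder that would normally come from ProceduralMemory:

# Illustrative sketch only.
from app.services.memory.procedural import MatchContext, get_procedure_matcher
from app.services.memory.types import Procedure

procedures: list[Procedure] = []  # assume loaded via ProceduralMemory.find_matching()

matcher = get_procedure_matcher()
context = MatchContext(
    query="fix failing unit tests in the auth module",
    task_type="debug",
    max_results=3,
)

for result in matcher.match(procedures, context):
    # MatchResult carries the score, matched terms, and the procedure's success rate.
    print(result.to_dict())

# Only surface procedures that have proven themselves at least 70% of the time.
suggestions = matcher.suggest_procedures(
    procedures, query="deploy the backend", min_success_rate=0.7
)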
749
backend/app/services/memory/procedural/memory.py
Normal file
@@ -0,0 +1,749 @@
|
||||
# app/services/memory/procedural/memory.py
|
||||
"""
|
||||
Procedural Memory Implementation.
|
||||
|
||||
Provides storage and retrieval for learned procedures (skills)
|
||||
derived from successful task execution patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, desc, or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.procedure import Procedure as ProcedureModel
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.types import Procedure, ProcedureCreate, RetrievalResult, Step
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _escape_like_pattern(pattern: str) -> str:
|
||||
"""
|
||||
Escape SQL LIKE/ILIKE special characters to prevent pattern injection.
|
||||
|
||||
Characters escaped:
|
||||
- % (matches zero or more characters)
|
||||
- _ (matches exactly one character)
|
||||
- \\ (escape character itself)
|
||||
|
||||
Args:
|
||||
pattern: Raw search pattern from user input
|
||||
|
||||
Returns:
|
||||
Escaped pattern safe for use in LIKE/ILIKE queries
|
||||
"""
|
||||
# Escape backslash first, then the wildcards
|
||||
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
|
||||
|
||||
def _model_to_procedure(model: ProcedureModel) -> Procedure:
|
||||
"""Convert SQLAlchemy model to Procedure dataclass."""
|
||||
return Procedure(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
|
||||
name=model.name, # type: ignore[arg-type]
|
||||
trigger_pattern=model.trigger_pattern, # type: ignore[arg-type]
|
||||
steps=model.steps or [], # type: ignore[arg-type]
|
||||
success_count=model.success_count, # type: ignore[arg-type]
|
||||
failure_count=model.failure_count, # type: ignore[arg-type]
|
||||
last_used=model.last_used, # type: ignore[arg-type]
|
||||
embedding=None, # Don't expose raw embedding
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class ProceduralMemory:
|
||||
"""
|
||||
Procedural Memory Service.
|
||||
|
||||
Provides procedure storage and retrieval:
|
||||
- Record procedures from successful task patterns
|
||||
- Find matching procedures by trigger pattern
|
||||
- Track success/failure rates
|
||||
- Get best procedure for a task type
|
||||
- Update procedure steps
|
||||
|
||||
Performance target: <50ms P95 for matching
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize procedural memory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic matching
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._settings = get_memory_settings()
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> "ProceduralMemory":
|
||||
"""
|
||||
Factory method to create ProceduralMemory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
|
||||
Returns:
|
||||
Configured ProceduralMemory instance
|
||||
"""
|
||||
return cls(session=session, embedding_generator=embedding_generator)
|
||||
|
||||
# =========================================================================
|
||||
# Procedure Recording
|
||||
# =========================================================================
|
||||
|
||||
async def record_procedure(self, procedure: ProcedureCreate) -> Procedure:
|
||||
"""
|
||||
Record a new procedure or update an existing one.
|
||||
|
||||
If a procedure with the same name exists in the same scope,
|
||||
its steps will be updated and success count incremented.
|
||||
|
||||
Args:
|
||||
procedure: Procedure data to record
|
||||
|
||||
Returns:
|
||||
The created or updated procedure
|
||||
"""
|
||||
# Check for existing procedure with same name
|
||||
existing = await self._find_existing_procedure(
|
||||
project_id=procedure.project_id,
|
||||
agent_type_id=procedure.agent_type_id,
|
||||
name=procedure.name,
|
||||
)
|
||||
|
||||
if existing is not None:
|
||||
# Update existing procedure
|
||||
return await self._update_existing_procedure(
|
||||
existing=existing,
|
||||
new_steps=procedure.steps,
|
||||
new_trigger=procedure.trigger_pattern,
|
||||
)
|
||||
|
||||
# Create new procedure
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Generate embedding if possible
|
||||
embedding = None
|
||||
if self._embedding_generator is not None:
|
||||
embedding_text = self._create_embedding_text(procedure)
|
||||
embedding = await self._embedding_generator.generate(embedding_text)
|
||||
|
||||
model = ProcedureModel(
|
||||
project_id=procedure.project_id,
|
||||
agent_type_id=procedure.agent_type_id,
|
||||
name=procedure.name,
|
||||
trigger_pattern=procedure.trigger_pattern,
|
||||
steps=procedure.steps,
|
||||
success_count=1, # New procedures start with 1 success (they worked)
|
||||
failure_count=0,
|
||||
last_used=now,
|
||||
embedding=embedding,
|
||||
)
|
||||
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
|
||||
logger.info(
|
||||
f"Recorded new procedure: {procedure.name} with {len(procedure.steps)} steps"
|
||||
)
|
||||
|
||||
return _model_to_procedure(model)
|
||||
|
||||
async def _find_existing_procedure(
|
||||
self,
|
||||
project_id: UUID | None,
|
||||
agent_type_id: UUID | None,
|
||||
name: str,
|
||||
) -> ProcedureModel | None:
|
||||
"""Find an existing procedure with the same name in the same scope."""
|
||||
query = select(ProcedureModel).where(ProcedureModel.name == name)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(ProcedureModel.project_id == project_id)
|
||||
else:
|
||||
query = query.where(ProcedureModel.project_id.is_(None))
|
||||
|
||||
if agent_type_id is not None:
|
||||
query = query.where(ProcedureModel.agent_type_id == agent_type_id)
|
||||
else:
|
||||
query = query.where(ProcedureModel.agent_type_id.is_(None))
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _update_existing_procedure(
|
||||
self,
|
||||
existing: ProcedureModel,
|
||||
new_steps: list[dict[str, Any]],
|
||||
new_trigger: str,
|
||||
) -> Procedure:
|
||||
"""Update an existing procedure with new steps."""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Merge steps intelligently - keep existing order, add new steps
|
||||
merged_steps = self._merge_steps(
|
||||
existing.steps or [], # type: ignore[arg-type]
|
||||
new_steps,
|
||||
)
|
||||
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == existing.id)
|
||||
.values(
|
||||
steps=merged_steps,
|
||||
trigger_pattern=new_trigger,
|
||||
success_count=ProcedureModel.success_count + 1,
|
||||
last_used=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Updated existing procedure: {existing.name}")
|
||||
|
||||
return _model_to_procedure(updated_model)
|
||||
|
||||
def _merge_steps(
|
||||
self,
|
||||
existing_steps: list[dict[str, Any]],
|
||||
new_steps: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Merge steps from a new execution with existing steps."""
|
||||
if not existing_steps:
|
||||
return new_steps
|
||||
if not new_steps:
|
||||
return existing_steps
|
||||
|
||||
# For now, use the new steps if they differ significantly
|
||||
# In production, this could use more sophisticated merging
|
||||
if len(new_steps) != len(existing_steps):
|
||||
# If structure changed, prefer newer steps
|
||||
return new_steps
|
||||
|
||||
# Merge step-by-step, preferring new data where available
|
||||
merged = []
|
||||
for i, new_step in enumerate(new_steps):
|
||||
if i < len(existing_steps):
|
||||
# Merge with existing step
|
||||
step = {**existing_steps[i], **new_step}
|
||||
else:
|
||||
step = new_step
|
||||
merged.append(step)
|
||||
|
||||
return merged
|
||||
|
||||
def _create_embedding_text(self, procedure: ProcedureCreate) -> str:
|
||||
"""Create text for embedding from procedure data."""
|
||||
steps_text = " ".join(step.get("action", "") for step in procedure.steps)
|
||||
return f"{procedure.name} {procedure.trigger_pattern} {steps_text}"
|
||||
|
||||
# =========================================================================
|
||||
# Procedure Retrieval
|
||||
# =========================================================================
|
||||
|
||||
async def find_matching(
|
||||
self,
|
||||
context: str,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
limit: int = 5,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Find procedures matching the given context.
|
||||
|
||||
Args:
|
||||
context: Context/trigger to match against
|
||||
project_id: Optional project to search within
|
||||
agent_type_id: Optional agent type filter
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of matching procedures
|
||||
"""
|
||||
result = await self._find_matching_with_metadata(
|
||||
context=context,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=limit,
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def _find_matching_with_metadata(
|
||||
self,
|
||||
context: str,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
limit: int = 5,
|
||||
) -> RetrievalResult[Procedure]:
|
||||
"""Find matching procedures with full result metadata."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Build base query - prioritize by success rate
|
||||
stmt = (
|
||||
select(ProcedureModel)
|
||||
.order_by(
|
||||
desc(
|
||||
ProcedureModel.success_count
|
||||
/ (ProcedureModel.success_count + ProcedureModel.failure_count + 1)
|
||||
),
|
||||
desc(ProcedureModel.last_used),
|
||||
)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
# Apply scope filters
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
# Text-based matching on trigger pattern and name
|
||||
# TODO: Implement proper vector similarity search when pgvector is integrated
|
||||
search_terms = context.lower().split()[:5] # Limit to 5 terms
|
||||
if search_terms:
|
||||
conditions = []
|
||||
for term in search_terms:
|
||||
# Escape SQL wildcards to prevent pattern injection
|
||||
escaped_term = _escape_like_pattern(term)
|
||||
term_pattern = f"%{escaped_term}%"
|
||||
conditions.append(
|
||||
or_(
|
||||
ProcedureModel.trigger_pattern.ilike(term_pattern),
|
||||
ProcedureModel.name.ilike(term_pattern),
|
||||
)
|
||||
)
|
||||
if conditions:
|
||||
stmt = stmt.where(or_(*conditions))
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_procedure(m) for m in models],
|
||||
total_count=len(models),
|
||||
query=context,
|
||||
retrieval_type="procedural",
|
||||
latency_ms=latency_ms,
|
||||
metadata={"project_id": str(project_id) if project_id else None},
|
||||
)
|
||||
|
||||
async def get_best_procedure(
|
||||
self,
|
||||
task_type: str,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
min_success_rate: float = 0.5,
|
||||
min_uses: int = 1,
|
||||
) -> Procedure | None:
|
||||
"""
|
||||
Get the best procedure for a given task type.
|
||||
|
||||
Returns the procedure with the highest success rate that
|
||||
meets the minimum thresholds.
|
||||
|
||||
Args:
|
||||
task_type: Task type to find procedure for
|
||||
project_id: Optional project scope
|
||||
agent_type_id: Optional agent type scope
|
||||
min_success_rate: Minimum required success rate
|
||||
min_uses: Minimum number of uses required
|
||||
|
||||
Returns:
|
||||
Best matching procedure or None
|
||||
"""
|
||||
# Escape SQL wildcards to prevent pattern injection
|
||||
escaped_task_type = _escape_like_pattern(task_type)
|
||||
task_type_pattern = f"%{escaped_task_type}%"
|
||||
|
||||
# Build query for procedures matching task type
|
||||
stmt = (
|
||||
select(ProcedureModel)
|
||||
.where(
|
||||
and_(
|
||||
(ProcedureModel.success_count + ProcedureModel.failure_count)
|
||||
>= min_uses,
|
||||
or_(
|
||||
ProcedureModel.trigger_pattern.ilike(task_type_pattern),
|
||||
ProcedureModel.name.ilike(task_type_pattern),
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(
|
||||
desc(
|
||||
ProcedureModel.success_count
|
||||
/ (ProcedureModel.success_count + ProcedureModel.failure_count + 1)
|
||||
),
|
||||
desc(ProcedureModel.last_used),
|
||||
)
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
# Apply scope filters
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Filter by success rate in Python (SQLAlchemy division in WHERE is complex)
|
||||
for model in models:
|
||||
success = float(model.success_count)
|
||||
failure = float(model.failure_count)
|
||||
total = success + failure
|
||||
if total > 0 and (success / total) >= min_success_rate:
|
||||
logger.debug(
|
||||
f"Found best procedure for '{task_type}': {model.name} "
|
||||
f"(success_rate={success / total:.2%})"
|
||||
)
|
||||
return _model_to_procedure(model)
|
||||
|
||||
return None
|
||||
|
||||
async def get_by_id(self, procedure_id: UUID) -> Procedure | None:
|
||||
"""Get a procedure by ID."""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
return _model_to_procedure(model) if model else None
|
||||
|
||||
# =========================================================================
|
||||
# Outcome Recording
|
||||
# =========================================================================
|
||||
|
||||
async def record_outcome(
|
||||
self,
|
||||
procedure_id: UUID,
|
||||
success: bool,
|
||||
) -> Procedure:
|
||||
"""
|
||||
Record the outcome of using a procedure.
|
||||
|
||||
Updates the success or failure count and last_used timestamp.
|
||||
|
||||
Args:
|
||||
procedure_id: Procedure that was used
|
||||
success: Whether the procedure succeeded
|
||||
|
||||
Returns:
|
||||
Updated procedure
|
||||
|
||||
Raises:
|
||||
ValueError: If procedure not found
|
||||
"""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"Procedure not found: {procedure_id}")
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
if success:
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == procedure_id)
|
||||
.values(
|
||||
success_count=ProcedureModel.success_count + 1,
|
||||
last_used=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
else:
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == procedure_id)
|
||||
.values(
|
||||
failure_count=ProcedureModel.failure_count + 1,
|
||||
last_used=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
outcome = "success" if success else "failure"
|
||||
logger.info(
|
||||
f"Recorded {outcome} for procedure {procedure_id}: "
|
||||
f"success_rate={updated_model.success_rate:.2%}"
|
||||
)
|
||||
|
||||
return _model_to_procedure(updated_model)
|
||||
|
||||
# =========================================================================
|
||||
# Step Management
|
||||
# =========================================================================
|
||||
|
||||
async def update_steps(
|
||||
self,
|
||||
procedure_id: UUID,
|
||||
steps: list[Step],
|
||||
) -> Procedure:
|
||||
"""
|
||||
Update the steps of a procedure.
|
||||
|
||||
Args:
|
||||
procedure_id: Procedure to update
|
||||
steps: New steps
|
||||
|
||||
Returns:
|
||||
Updated procedure
|
||||
|
||||
Raises:
|
||||
ValueError: If procedure not found
|
||||
"""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"Procedure not found: {procedure_id}")
|
||||
|
||||
# Convert Step objects to dictionaries
|
||||
steps_dict = [
|
||||
{
|
||||
"order": step.order,
|
||||
"action": step.action,
|
||||
"parameters": step.parameters,
|
||||
"expected_outcome": step.expected_outcome,
|
||||
"fallback_action": step.fallback_action,
|
||||
}
|
||||
for step in steps
|
||||
]
|
||||
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == procedure_id)
|
||||
.values(
|
||||
steps=steps_dict,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Updated steps for procedure {procedure_id}: {len(steps)} steps")
|
||||
|
||||
return _model_to_procedure(updated_model)
|
||||
|
||||
# =========================================================================
|
||||
# Statistics & Management
|
||||
# =========================================================================
|
||||
|
||||
async def get_stats(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about procedural memory.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to get stats for
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics
|
||||
"""
|
||||
query = select(ProcedureModel)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
if not models:
|
||||
return {
|
||||
"total_procedures": 0,
|
||||
"avg_success_rate": 0.0,
|
||||
"avg_steps_count": 0.0,
|
||||
"total_uses": 0,
|
||||
"high_success_count": 0,
|
||||
"low_success_count": 0,
|
||||
}
|
||||
|
||||
success_rates = [m.success_rate for m in models]
|
||||
step_counts = [len(m.steps or []) for m in models]
|
||||
total_uses = sum(m.total_uses for m in models)
|
||||
|
||||
return {
|
||||
"total_procedures": len(models),
|
||||
"avg_success_rate": sum(success_rates) / len(success_rates),
|
||||
"avg_steps_count": sum(step_counts) / len(step_counts),
|
||||
"total_uses": total_uses,
|
||||
"high_success_count": sum(1 for r in success_rates if r >= 0.8),
|
||||
"low_success_count": sum(1 for r in success_rates if r < 0.5),
|
||||
}
|
||||
|
||||
async def count(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Count procedures in scope.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to count for
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
Number of procedures
|
||||
"""
|
||||
query = select(ProcedureModel)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return len(list(result.scalars().all()))
|
||||
|
||||
async def delete(self, procedure_id: UUID) -> bool:
|
||||
"""
|
||||
Delete a procedure.
|
||||
|
||||
Args:
|
||||
procedure_id: Procedure to delete
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return False
|
||||
|
||||
await self._session.delete(model)
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Deleted procedure {procedure_id}")
|
||||
return True
|
||||
|
||||
async def get_procedures_by_success_rate(
|
||||
self,
|
||||
min_rate: float = 0.0,
|
||||
max_rate: float = 1.0,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Get procedures within a success rate range.
|
||||
|
||||
Args:
|
||||
min_rate: Minimum success rate
|
||||
max_rate: Maximum success rate
|
||||
project_id: Optional project scope
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of procedures
|
||||
"""
|
||||
query = (
|
||||
select(ProcedureModel)
|
||||
.order_by(desc(ProcedureModel.last_used))
|
||||
.limit(limit * 2) # Fetch more since we filter in Python
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Filter by success rate in Python
|
||||
filtered = [m for m in models if min_rate <= m.success_rate <= max_rate][:limit]
|
||||
|
||||
return [_model_to_procedure(m) for m in filtered]
|
||||
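The typical flow is record_procedure after a successful run, find_matching (or get_best_procedure) when a similar task arrives, and record_outcome afterwards so success rates stay honest. A hedged sketch of that loop, not part of the committed files; the ProcedureCreate field names are inferred from how record_procedure reads them:

# Illustrative sketch only. `session` comes from the application's session factory.
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.memory.procedural import ProceduralMemory
from app.services.memory.types import ProcedureCreate


async def learn_and_reuse(session: AsyncSession) -> None:
    memory = await ProceduralMemory.create(session)

    # Distil a successful task run into a procedure (global scope: both IDs None).
    await memory.record_procedure(
        ProcedureCreate(
            project_id=None,
            agent_type_id=None,
            name="run backend tests",
            trigger_pattern="run tests|pytest failure",
            steps=[{"order": 1, "action": "make test-backend"}],
        )
    )

    # Later: match by context, use the top hit, and report how it went.
    matches = await memory.find_matching("run the backend test suite", limit=3)
    if matches:
        await memory.record_outcome(matches[0].id, success=True)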
38
backend/app/services/memory/reflection/__init__.py
Normal file
@@ -0,0 +1,38 @@
# app/services/memory/reflection/__init__.py
"""
Memory Reflection Layer.

Analyzes patterns in agent experiences to generate actionable insights.
"""

from .service import (
    MemoryReflection,
    ReflectionConfig,
    get_memory_reflection,
)
from .types import (
    Anomaly,
    AnomalyType,
    Factor,
    FactorType,
    Insight,
    InsightType,
    Pattern,
    PatternType,
    TimeRange,
)

__all__ = [
    "Anomaly",
    "AnomalyType",
    "Factor",
    "FactorType",
    "Insight",
    "InsightType",
    "MemoryReflection",
    "Pattern",
    "PatternType",
    "ReflectionConfig",
    "TimeRange",
    "get_memory_reflection",
]
1451
backend/app/services/memory/reflection/service.py
Normal file
File diff suppressed because it is too large
304
backend/app/services/memory/reflection/types.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# app/services/memory/reflection/types.py
|
||||
"""
|
||||
Memory Reflection Types.
|
||||
|
||||
Type definitions for pattern detection, anomaly detection, and insights.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class PatternType(str, Enum):
|
||||
"""Types of patterns detected in episodic memory."""
|
||||
|
||||
RECURRING_SUCCESS = "recurring_success"
|
||||
RECURRING_FAILURE = "recurring_failure"
|
||||
ACTION_SEQUENCE = "action_sequence"
|
||||
CONTEXT_CORRELATION = "context_correlation"
|
||||
TEMPORAL = "temporal"
|
||||
EFFICIENCY = "efficiency"
|
||||
|
||||
|
||||
class FactorType(str, Enum):
|
||||
"""Types of factors contributing to outcomes."""
|
||||
|
||||
ACTION = "action"
|
||||
CONTEXT = "context"
|
||||
TIMING = "timing"
|
||||
RESOURCE = "resource"
|
||||
PRECEDING_STATE = "preceding_state"
|
||||
|
||||
|
||||
class AnomalyType(str, Enum):
|
||||
"""Types of anomalies detected."""
|
||||
|
||||
UNUSUAL_DURATION = "unusual_duration"
|
||||
UNEXPECTED_OUTCOME = "unexpected_outcome"
|
||||
UNUSUAL_TOKEN_USAGE = "unusual_token_usage"
|
||||
UNUSUAL_FAILURE_RATE = "unusual_failure_rate"
|
||||
UNUSUAL_ACTION_PATTERN = "unusual_action_pattern"
|
||||
|
||||
|
||||
class InsightType(str, Enum):
|
||||
"""Types of insights generated."""
|
||||
|
||||
OPTIMIZATION = "optimization"
|
||||
WARNING = "warning"
|
||||
LEARNING = "learning"
|
||||
RECOMMENDATION = "recommendation"
|
||||
TREND = "trend"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeRange:
|
||||
"""Time range for reflection analysis."""
|
||||
|
||||
start: datetime
|
||||
end: datetime
|
||||
|
||||
    @classmethod
    def last_hours(cls, hours: int = 24) -> "TimeRange":
        """Create time range for last N hours."""
        from datetime import timedelta

        end = _utcnow()
        start = end - timedelta(hours=hours)
        return cls(start=start, end=end)
|
||||
|
||||
@classmethod
|
||||
def last_days(cls, days: int = 7) -> "TimeRange":
|
||||
"""Create time range for last N days."""
|
||||
from datetime import timedelta
|
||||
|
||||
end = _utcnow()
|
||||
start = end - timedelta(days=days)
|
||||
return cls(start=start, end=end)
|
||||
|
||||
@property
|
||||
def duration_hours(self) -> float:
|
||||
"""Get duration in hours."""
|
||||
return (self.end - self.start).total_seconds() / 3600
|
||||
|
||||
@property
|
||||
def duration_days(self) -> float:
|
||||
"""Get duration in days."""
|
||||
return (self.end - self.start).total_seconds() / 86400
|
||||
|
||||
|
||||
@dataclass
|
||||
class Pattern:
|
||||
"""A detected pattern in episodic memory."""
|
||||
|
||||
id: UUID
|
||||
pattern_type: PatternType
|
||||
name: str
|
||||
description: str
|
||||
confidence: float
|
||||
occurrence_count: int
|
||||
episode_ids: list[UUID]
|
||||
first_seen: datetime
|
||||
last_seen: datetime
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def frequency(self) -> float:
|
||||
"""Calculate pattern frequency per day."""
|
||||
duration_days = (self.last_seen - self.first_seen).total_seconds() / 86400
|
||||
if duration_days < 1:
|
||||
duration_days = 1
|
||||
return self.occurrence_count / duration_days
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"pattern_type": self.pattern_type.value,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"confidence": self.confidence,
|
||||
"occurrence_count": self.occurrence_count,
|
||||
"episode_ids": [str(eid) for eid in self.episode_ids],
|
||||
"first_seen": self.first_seen.isoformat(),
|
||||
"last_seen": self.last_seen.isoformat(),
|
||||
"frequency": self.frequency,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Factor:
|
||||
"""A factor contributing to success or failure."""
|
||||
|
||||
id: UUID
|
||||
factor_type: FactorType
|
||||
name: str
|
||||
description: str
|
||||
impact_score: float
|
||||
correlation: float
|
||||
sample_size: int
|
||||
positive_examples: list[UUID]
|
||||
negative_examples: list[UUID]
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def net_impact(self) -> float:
|
||||
"""Calculate net impact considering sample size."""
|
||||
# Weight impact by sample confidence
|
||||
confidence_weight = min(1.0, self.sample_size / 20)
|
||||
return self.impact_score * self.correlation * confidence_weight
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"factor_type": self.factor_type.value,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"impact_score": self.impact_score,
|
||||
"correlation": self.correlation,
|
||||
"sample_size": self.sample_size,
|
||||
"positive_examples": [str(eid) for eid in self.positive_examples],
|
||||
"negative_examples": [str(eid) for eid in self.negative_examples],
|
||||
"net_impact": self.net_impact,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Anomaly:
|
||||
"""An anomaly detected in memory patterns."""
|
||||
|
||||
id: UUID
|
||||
anomaly_type: AnomalyType
|
||||
description: str
|
||||
severity: float
|
||||
episode_ids: list[UUID]
|
||||
detected_at: datetime
|
||||
baseline_value: float
|
||||
observed_value: float
|
||||
deviation_factor: float
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def is_critical(self) -> bool:
|
||||
"""Check if anomaly is critical (severity > 0.8)."""
|
||||
return self.severity > 0.8
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"anomaly_type": self.anomaly_type.value,
|
||||
"description": self.description,
|
||||
"severity": self.severity,
|
||||
"episode_ids": [str(eid) for eid in self.episode_ids],
|
||||
"detected_at": self.detected_at.isoformat(),
|
||||
"baseline_value": self.baseline_value,
|
||||
"observed_value": self.observed_value,
|
||||
"deviation_factor": self.deviation_factor,
|
||||
"is_critical": self.is_critical,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Insight:
|
||||
"""An actionable insight generated from reflection."""
|
||||
|
||||
id: UUID
|
||||
insight_type: InsightType
|
||||
title: str
|
||||
description: str
|
||||
priority: float
|
||||
confidence: float
|
||||
source_patterns: list[UUID]
|
||||
source_factors: list[UUID]
|
||||
source_anomalies: list[UUID]
|
||||
recommended_actions: list[str]
|
||||
generated_at: datetime
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def actionable_score(self) -> float:
|
||||
"""Calculate how actionable this insight is."""
|
||||
action_weight = min(1.0, len(self.recommended_actions) / 3)
|
||||
return self.priority * self.confidence * action_weight
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"insight_type": self.insight_type.value,
|
||||
"title": self.title,
|
||||
"description": self.description,
|
||||
"priority": self.priority,
|
||||
"confidence": self.confidence,
|
||||
"source_patterns": [str(pid) for pid in self.source_patterns],
|
||||
"source_factors": [str(fid) for fid in self.source_factors],
|
||||
"source_anomalies": [str(aid) for aid in self.source_anomalies],
|
||||
"recommended_actions": self.recommended_actions,
|
||||
"generated_at": self.generated_at.isoformat(),
|
||||
"actionable_score": self.actionable_score,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReflectionResult:
|
||||
"""Result of a reflection operation."""
|
||||
|
||||
patterns: list[Pattern]
|
||||
factors: list[Factor]
|
||||
anomalies: list[Anomaly]
|
||||
insights: list[Insight]
|
||||
time_range: TimeRange
|
||||
episodes_analyzed: int
|
||||
analysis_duration_seconds: float
|
||||
generated_at: datetime = field(default_factory=_utcnow)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"patterns": [p.to_dict() for p in self.patterns],
|
||||
"factors": [f.to_dict() for f in self.factors],
|
||||
"anomalies": [a.to_dict() for a in self.anomalies],
|
||||
"insights": [i.to_dict() for i in self.insights],
|
||||
"time_range": {
|
||||
"start": self.time_range.start.isoformat(),
|
||||
"end": self.time_range.end.isoformat(),
|
||||
"duration_hours": self.time_range.duration_hours,
|
||||
},
|
||||
"episodes_analyzed": self.episodes_analyzed,
|
||||
"analysis_duration_seconds": self.analysis_duration_seconds,
|
||||
"generated_at": self.generated_at.isoformat(),
|
||||
}
|
||||
|
||||
@property
|
||||
def summary(self) -> str:
|
||||
"""Generate a summary of the reflection results."""
|
||||
lines = [
|
||||
f"Reflection Analysis ({self.time_range.duration_days:.1f} days)",
|
||||
f"Episodes analyzed: {self.episodes_analyzed}",
|
||||
"",
|
||||
f"Patterns detected: {len(self.patterns)}",
|
||||
f"Success/failure factors: {len(self.factors)}",
|
||||
f"Anomalies found: {len(self.anomalies)}",
|
||||
f"Insights generated: {len(self.insights)}",
|
||||
]
|
||||
|
||||
if self.insights:
|
||||
lines.append("")
|
||||
lines.append("Top insights:")
|
||||
for insight in sorted(self.insights, key=lambda i: -i.priority)[:3]:
|
||||
lines.append(f" - [{insight.insight_type.value}] {insight.title}")
|
||||
|
||||
return "\n".join(lines)
|
||||
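TimeRange is the window every reflection query runs over, and ReflectionResult.summary is the human-readable recap of what was found. A small sketch of the window helpers, not part of the committed files:

# Illustrative sketch only.
from app.services.memory.reflection import TimeRange

window = TimeRange.last_days(7)
print(window.duration_hours)  # 168.0
print(window.duration_days)   # 7.0

# MemoryReflection (service.py, diff suppressed above) would analyze episodes inside
# such a window and hand back a ReflectionResult, whose .summary property renders
# the counts and top insights as plain text.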
33
backend/app/services/memory/scoping/__init__.py
Normal file
@@ -0,0 +1,33 @@
# app/services/memory/scoping/__init__.py
"""
Memory Scoping

Hierarchical scoping for memory with inheritance:
Global -> Project -> Agent Type -> Agent Instance -> Session
"""

from .resolver import (
    ResolutionOptions,
    ResolutionResult,
    ScopeFilter,
    ScopeResolver,
    get_scope_resolver,
)
from .scope import (
    ScopeInfo,
    ScopeManager,
    ScopePolicy,
    get_scope_manager,
)

__all__ = [
    "ResolutionOptions",
    "ResolutionResult",
    "ScopeFilter",
    "ScopeInfo",
    "ScopeManager",
    "ScopePolicy",
    "ScopeResolver",
    "get_scope_manager",
    "get_scope_resolver",
]
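The resolver exported here walks a scope chain and aggregates per-scope results through a caller-supplied fetcher, as the resolver.py diff below shows. A hedged sketch of that call shape, not part of the committed files; the fetcher is a stand-in and the create_scope_from_ids arguments would be real UUIDs in practice:

# Illustrative sketch only.
from app.services.memory.scoping import (
    ResolutionOptions,
    get_scope_manager,
    get_scope_resolver,
)
from app.services.memory.types import ScopeContext

scope = get_scope_manager().create_scope_from_ids(
    project_id=None,
    agent_type_id=None,
    agent_instance_id=None,
    session_id=None,
)


def fetch_for_scope(query_scope: ScopeContext, limit: int) -> list[str]:
    # Placeholder: return whatever memories this scope owns, up to `limit` items.
    return []


result = get_scope_resolver().resolve(
    scope,
    fetcher=fetch_for_scope,
    options=ResolutionOptions(max_inheritance_depth=3, deduplicate=False),
)
print(result.own_count, result.inherited_count, result.total_count)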
390
backend/app/services/memory/scoping/resolver.py
Normal file
@@ -0,0 +1,390 @@
|
||||
# app/services/memory/scoping/resolver.py
|
||||
"""
|
||||
Scope Resolution.
|
||||
|
||||
Provides utilities for resolving memory queries across scope hierarchies,
|
||||
implementing inheritance and aggregation of memories from parent scopes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from app.services.memory.types import ScopeContext, ScopeLevel
|
||||
|
||||
from .scope import ScopeManager, get_scope_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolutionResult[T]:
|
||||
"""Result of a scope resolution."""
|
||||
|
||||
items: list[T]
|
||||
sources: list[ScopeContext]
|
||||
total_from_each: dict[str, int] = field(default_factory=dict)
|
||||
inherited_count: int = 0
|
||||
own_count: int = 0
|
||||
|
||||
@property
|
||||
def total_count(self) -> int:
|
||||
"""Get total items from all sources."""
|
||||
return len(self.items)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolutionOptions:
|
||||
"""Options for scope resolution."""
|
||||
|
||||
include_inherited: bool = True
|
||||
max_inheritance_depth: int = 5
|
||||
limit_per_scope: int = 100
|
||||
total_limit: int = 500
|
||||
deduplicate: bool = True
|
||||
deduplicate_key: str | None = None # Field to use for deduplication
|
||||
|
||||
|
||||
class ScopeResolver:
|
||||
"""
|
||||
Resolves memory queries across scope hierarchies.
|
||||
|
||||
Features:
|
||||
- Traverse scope hierarchy for inherited memories
|
||||
- Aggregate results from multiple scope levels
|
||||
- Apply access control policies
|
||||
- Support deduplication across scopes
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: ScopeManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the resolver.
|
||||
|
||||
Args:
|
||||
manager: Scope manager to use (defaults to singleton)
|
||||
"""
|
||||
self._manager = manager or get_scope_manager()
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
fetcher: Callable[[ScopeContext, int], list[T]],
|
||||
options: ResolutionOptions | None = None,
|
||||
) -> ResolutionResult[T]:
|
||||
"""
|
||||
Resolve memories for a scope, including inherited memories.
|
||||
|
||||
Args:
|
||||
scope: Starting scope
|
||||
fetcher: Function to fetch items for a scope (scope, limit) -> items
|
||||
options: Resolution options
|
||||
|
||||
Returns:
|
||||
Resolution result with items from all scopes
|
||||
"""
|
||||
opts = options or ResolutionOptions()
|
||||
|
||||
all_items: list[T] = []
|
||||
sources: list[ScopeContext] = []
|
||||
counts: dict[str, int] = {}
|
||||
seen_keys: set[Any] = set()
|
||||
|
||||
# Collect scopes to query (starting from current, going up to ancestors)
|
||||
scopes_to_query = self._collect_queryable_scopes(
|
||||
scope=scope,
|
||||
max_depth=opts.max_inheritance_depth if opts.include_inherited else 0,
|
||||
)
|
||||
|
||||
own_count = 0
|
||||
inherited_count = 0
|
||||
remaining_limit = opts.total_limit
|
||||
|
||||
for i, query_scope in enumerate(scopes_to_query):
|
||||
if remaining_limit <= 0:
|
||||
break
|
||||
|
||||
# Check access policy
|
||||
policy = self._manager.get_policy(query_scope)
|
||||
if not policy.allows_read():
|
||||
continue
|
||||
if i > 0 and not policy.allows_inherit():
|
||||
continue
|
||||
|
||||
# Fetch items for this scope
|
||||
scope_limit = min(opts.limit_per_scope, remaining_limit)
|
||||
items = fetcher(query_scope, scope_limit)
|
||||
|
||||
# Apply deduplication
|
||||
if opts.deduplicate and opts.deduplicate_key:
|
||||
items = self._deduplicate(items, opts.deduplicate_key, seen_keys)
|
||||
|
||||
if items:
|
||||
all_items.extend(items)
|
||||
sources.append(query_scope)
|
||||
key = f"{query_scope.scope_type.value}:{query_scope.scope_id}"
|
||||
counts[key] = len(items)
|
||||
|
||||
if i == 0:
|
||||
own_count = len(items)
|
||||
else:
|
||||
inherited_count += len(items)
|
||||
|
||||
remaining_limit -= len(items)
|
||||
|
||||
logger.debug(
|
||||
f"Resolved {len(all_items)} items from {len(sources)} scopes "
|
||||
f"(own={own_count}, inherited={inherited_count})"
|
||||
)
|
||||
|
||||
return ResolutionResult(
|
||||
items=all_items[: opts.total_limit],
|
||||
sources=sources,
|
||||
total_from_each=counts,
|
||||
own_count=own_count,
|
||||
inherited_count=inherited_count,
|
||||
)
|
||||
|
||||
def _collect_queryable_scopes(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
max_depth: int,
|
||||
) -> list[ScopeContext]:
|
||||
"""Collect scopes to query, from current to ancestors."""
|
||||
scopes: list[ScopeContext] = [scope]
|
||||
|
||||
if max_depth <= 0:
|
||||
return scopes
|
||||
|
||||
current = scope.parent
|
||||
depth = 0
|
||||
|
||||
while current is not None and depth < max_depth:
|
||||
scopes.append(current)
|
||||
current = current.parent
|
||||
depth += 1
|
||||
|
||||
return scopes
|
||||
|
||||
def _deduplicate(
|
||||
self,
|
||||
items: list[T],
|
||||
key_field: str,
|
||||
seen_keys: set[Any],
|
||||
) -> list[T]:
|
||||
"""Remove duplicate items based on a key field."""
|
||||
unique: list[T] = []
|
||||
|
||||
for item in items:
|
||||
key = getattr(item, key_field, None)
|
||||
if key is None:
|
||||
# If no key, include the item
|
||||
unique.append(item)
|
||||
elif key not in seen_keys:
|
||||
seen_keys.add(key)
|
||||
unique.append(item)
|
||||
|
||||
return unique
|
||||
|
||||
def get_visible_scopes(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
) -> list[ScopeContext]:
|
||||
"""
|
||||
Get all scopes visible from a given scope.
|
||||
|
||||
A scope can see itself and all its ancestors (if inheritance allowed).
|
||||
|
||||
Args:
|
||||
scope: Starting scope
|
||||
|
||||
Returns:
|
||||
List of visible scopes (from most specific to most general)
|
||||
"""
|
||||
visible = [scope]
|
||||
|
||||
current = scope.parent
|
||||
while current is not None:
|
||||
policy = self._manager.get_policy(current)
|
||||
if policy.allows_inherit():
|
||||
visible.append(current)
|
||||
else:
|
||||
break # Stop at first non-inheritable scope
|
||||
current = current.parent
|
||||
|
||||
return visible
|
||||
|
||||
def find_write_scope(
|
||||
self,
|
||||
target_level: ScopeLevel,
|
||||
scope: ScopeContext,
|
||||
) -> ScopeContext | None:
|
||||
"""
|
||||
Find the appropriate scope for writing at a target level.
|
||||
|
||||
Walks up the hierarchy to find a scope at the target level
|
||||
that allows writing.
|
||||
|
||||
Args:
|
||||
target_level: Desired scope level
|
||||
scope: Starting scope
|
||||
|
||||
Returns:
|
||||
Scope to write to, or None if not found/not allowed
|
||||
"""
|
||||
# First check if current scope is at target level
|
||||
if scope.scope_type == target_level:
|
||||
policy = self._manager.get_policy(scope)
|
||||
return scope if policy.allows_write() else None
|
||||
|
||||
# Check ancestors
|
||||
current = scope.parent
|
||||
while current is not None:
|
||||
if current.scope_type == target_level:
|
||||
policy = self._manager.get_policy(current)
|
||||
return current if policy.allows_write() else None
|
||||
current = current.parent
|
||||
|
||||
return None
|
||||
|
||||
def resolve_scope_from_memory(
|
||||
self,
|
||||
memory_type: str,
|
||||
project_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
agent_instance_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> tuple[ScopeContext, ScopeLevel]:
|
||||
"""
|
||||
Resolve the appropriate scope for a memory operation.
|
||||
|
||||
Different memory types have different scope requirements:
|
||||
- working: Session or Agent Instance
|
||||
- episodic: Agent Instance or Project
|
||||
- semantic: Project or Global
|
||||
- procedural: Agent Type or Project
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
project_id: Project ID
|
||||
agent_type_id: Agent type ID
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
Tuple of (scope context, recommended level)
|
||||
"""
|
||||
# Build full scope chain
|
||||
scope = self._manager.create_scope_from_ids(
|
||||
project_id=project_id if project_id else None, # type: ignore[arg-type]
|
||||
agent_type_id=agent_type_id if agent_type_id else None, # type: ignore[arg-type]
|
||||
agent_instance_id=agent_instance_id if agent_instance_id else None, # type: ignore[arg-type]
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Determine recommended level based on memory type
|
||||
recommended = self._get_recommended_level(memory_type)
|
||||
|
||||
return scope, recommended
|
||||
|
||||
def _get_recommended_level(self, memory_type: str) -> ScopeLevel:
|
||||
"""Get recommended scope level for a memory type."""
|
||||
recommendations = {
|
||||
"working": ScopeLevel.SESSION,
|
||||
"episodic": ScopeLevel.AGENT_INSTANCE,
|
||||
"semantic": ScopeLevel.PROJECT,
|
||||
"procedural": ScopeLevel.AGENT_TYPE,
|
||||
}
|
||||
return recommendations.get(memory_type, ScopeLevel.PROJECT)
|
||||
|
||||
def validate_write_access(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
memory_type: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that writing is allowed for the given scope and memory type.
|
||||
|
||||
Args:
|
||||
scope: Scope to validate
|
||||
memory_type: Type of memory to write
|
||||
|
||||
Returns:
|
||||
True if write is allowed
|
||||
"""
|
||||
policy = self._manager.get_policy(scope)
|
||||
|
||||
if not policy.allows_write():
|
||||
return False
|
||||
|
||||
if not policy.allows_memory_type(memory_type):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_scope_chain(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
) -> list[tuple[ScopeLevel, str]]:
|
||||
"""
|
||||
Get the scope chain as a list of (level, id) tuples.
|
||||
|
||||
Args:
|
||||
scope: Scope to get chain for
|
||||
|
||||
Returns:
|
||||
List of (level, id) tuples from root to leaf
|
||||
"""
|
||||
chain: list[tuple[ScopeLevel, str]] = []
|
||||
|
||||
# Get full hierarchy
|
||||
hierarchy = scope.get_hierarchy()
|
||||
for ctx in hierarchy:
|
||||
chain.append((ctx.scope_type, ctx.scope_id))
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScopeFilter:
|
||||
"""Filter for querying across scopes."""
|
||||
|
||||
scope_types: list[ScopeLevel] | None = None
|
||||
project_ids: list[str] | None = None
|
||||
agent_type_ids: list[str] | None = None
|
||||
include_global: bool = True
|
||||
|
||||
def matches(self, scope: ScopeContext) -> bool:
|
||||
"""Check if a scope matches this filter."""
|
||||
if self.scope_types and scope.scope_type not in self.scope_types:
|
||||
return False
|
||||
|
||||
if scope.scope_type == ScopeLevel.GLOBAL:
|
||||
return self.include_global
|
||||
|
||||
if scope.scope_type == ScopeLevel.PROJECT:
|
||||
if self.project_ids and scope.scope_id not in self.project_ids:
|
||||
return False
|
||||
|
||||
if scope.scope_type == ScopeLevel.AGENT_TYPE:
|
||||
if self.agent_type_ids and scope.scope_id not in self.agent_type_ids:
|
||||
return False
|
||||
|
||||
return True
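A minimal sketch, not part of the diff, of how ScopeFilter.matches behaves; the id strings are hypothetical.

def _demo_scope_filter() -> None:
    only_my_project = ScopeFilter(
        scope_types=[ScopeLevel.GLOBAL, ScopeLevel.PROJECT],
        project_ids=["proj-123"],
    )
    project = ScopeContext(scope_type=ScopeLevel.PROJECT, scope_id="proj-123", parent=None)
    other = ScopeContext(scope_type=ScopeLevel.PROJECT, scope_id="proj-456", parent=None)
    session = ScopeContext(scope_type=ScopeLevel.SESSION, scope_id="sess-1", parent=None)

    assert only_my_project.matches(project)      # listed project passes
    assert not only_my_project.matches(other)    # project not in the allow-list
    assert not only_my_project.matches(session)  # SESSION is not in scope_types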
|
||||
|
||||
|
||||
# Singleton resolver instance
|
||||
_resolver: ScopeResolver | None = None
|
||||
|
||||
|
||||
def get_scope_resolver() -> ScopeResolver:
|
||||
"""Get the singleton scope resolver instance."""
|
||||
global _resolver
|
||||
if _resolver is None:
|
||||
_resolver = ScopeResolver()
|
||||
return _resolver
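Illustrative usage sketch, not part of the diff: choosing where to write a "semantic" memory. Semantic facts recommend the PROJECT level (see _get_recommended_level), so the resolver walks up from the session scope; the id strings are hypothetical.

def _demo_write_scope() -> None:
    resolver = get_scope_resolver()
    scope, recommended = resolver.resolve_scope_from_memory(
        memory_type="semantic",
        project_id="proj-123",
        session_id="sess-1",
    )
    target = resolver.find_write_scope(recommended, scope)
    if target is not None and resolver.validate_write_access(target, "semantic"):
        # target is the PROJECT scope; the default project policy allows writes.
        print(f"write to {target.scope_type.value}:{target.scope_id}")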
|
||||
472
backend/app/services/memory/scoping/scope.py
Normal file
472
backend/app/services/memory/scoping/scope.py
Normal file
@@ -0,0 +1,472 @@
|
||||
# app/services/memory/scoping/scope.py
|
||||
"""
|
||||
Scope Management.
|
||||
|
||||
Provides utilities for managing memory scopes with hierarchical inheritance:
|
||||
Global -> Project -> Agent Type -> Agent Instance -> Session
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.types import ScopeContext, ScopeLevel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScopePolicy:
|
||||
"""Access control policy for a scope."""
|
||||
|
||||
scope_type: ScopeLevel
|
||||
scope_id: str
|
||||
can_read: bool = True
|
||||
can_write: bool = True
|
||||
can_inherit: bool = True
|
||||
allowed_memory_types: list[str] = field(default_factory=lambda: ["all"])
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def allows_read(self) -> bool:
|
||||
"""Check if reading is allowed."""
|
||||
return self.can_read
|
||||
|
||||
def allows_write(self) -> bool:
|
||||
"""Check if writing is allowed."""
|
||||
return self.can_write
|
||||
|
||||
def allows_inherit(self) -> bool:
|
||||
"""Check if inheritance from parent is allowed."""
|
||||
return self.can_inherit
|
||||
|
||||
def allows_memory_type(self, memory_type: str) -> bool:
|
||||
"""Check if a specific memory type is allowed."""
|
||||
return (
|
||||
"all" in self.allowed_memory_types
|
||||
or memory_type in self.allowed_memory_types
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScopeInfo:
|
||||
"""Information about a scope including its hierarchy."""
|
||||
|
||||
context: ScopeContext
|
||||
policy: ScopePolicy
|
||||
parent_info: "ScopeInfo | None" = None
|
||||
child_count: int = 0
|
||||
memory_count: int = 0
|
||||
|
||||
@property
|
||||
def depth(self) -> int:
|
||||
"""Get the depth of this scope in the hierarchy."""
|
||||
count = 0
|
||||
current = self.parent_info
|
||||
while current is not None:
|
||||
count += 1
|
||||
current = current.parent_info
|
||||
return count
|
||||
|
||||
|
||||
class ScopeManager:
|
||||
"""
|
||||
Manages memory scopes and their hierarchies.
|
||||
|
||||
Provides:
|
||||
- Scope creation and validation
|
||||
- Hierarchy management
|
||||
- Access control policy management
|
||||
- Scope inheritance rules
|
||||
"""
|
||||
|
||||
# Order of scope levels from root to leaf
|
||||
SCOPE_ORDER: ClassVar[list[ScopeLevel]] = [
|
||||
ScopeLevel.GLOBAL,
|
||||
ScopeLevel.PROJECT,
|
||||
ScopeLevel.AGENT_TYPE,
|
||||
ScopeLevel.AGENT_INSTANCE,
|
||||
ScopeLevel.SESSION,
|
||||
]
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the scope manager."""
|
||||
# In-memory policy cache (would be backed by database in production)
|
||||
self._policies: dict[str, ScopePolicy] = {}
|
||||
self._default_policies = self._create_default_policies()
|
||||
|
||||
def _create_default_policies(self) -> dict[ScopeLevel, ScopePolicy]:
|
||||
"""Create default policies for each scope level."""
|
||||
return {
|
||||
ScopeLevel.GLOBAL: ScopePolicy(
|
||||
scope_type=ScopeLevel.GLOBAL,
|
||||
scope_id="global",
|
||||
can_read=True,
|
||||
can_write=False, # Global writes require special permission
|
||||
can_inherit=True,
|
||||
),
|
||||
ScopeLevel.PROJECT: ScopePolicy(
|
||||
scope_type=ScopeLevel.PROJECT,
|
||||
scope_id="default",
|
||||
can_read=True,
|
||||
can_write=True,
|
||||
can_inherit=True,
|
||||
),
|
||||
ScopeLevel.AGENT_TYPE: ScopePolicy(
|
||||
scope_type=ScopeLevel.AGENT_TYPE,
|
||||
scope_id="default",
|
||||
can_read=True,
|
||||
can_write=True,
|
||||
can_inherit=True,
|
||||
),
|
||||
ScopeLevel.AGENT_INSTANCE: ScopePolicy(
|
||||
scope_type=ScopeLevel.AGENT_INSTANCE,
|
||||
scope_id="default",
|
||||
can_read=True,
|
||||
can_write=True,
|
||||
can_inherit=True,
|
||||
),
|
||||
ScopeLevel.SESSION: ScopePolicy(
|
||||
scope_type=ScopeLevel.SESSION,
|
||||
scope_id="default",
|
||||
can_read=True,
|
||||
can_write=True,
|
||||
can_inherit=True,
|
||||
allowed_memory_types=["working"], # Sessions only allow working memory
|
||||
),
|
||||
}
|
||||
|
||||
def create_scope(
|
||||
self,
|
||||
scope_type: ScopeLevel,
|
||||
scope_id: str,
|
||||
parent: ScopeContext | None = None,
|
||||
) -> ScopeContext:
|
||||
"""
|
||||
Create a new scope context.
|
||||
|
||||
Args:
|
||||
scope_type: Level of the scope
|
||||
scope_id: Unique identifier within the level
|
||||
parent: Optional parent scope
|
||||
|
||||
Returns:
|
||||
Created scope context
|
||||
|
||||
Raises:
|
||||
ValueError: If scope hierarchy is invalid
|
||||
"""
|
||||
# Validate hierarchy
|
||||
if parent is not None:
|
||||
self._validate_parent_child(parent.scope_type, scope_type)
|
||||
|
||||
# For non-global scopes without parent, auto-create parent chain
|
||||
if parent is None and scope_type != ScopeLevel.GLOBAL:
|
||||
parent = self._create_parent_chain(scope_type, scope_id)
|
||||
|
||||
context = ScopeContext(
|
||||
scope_type=scope_type,
|
||||
scope_id=scope_id,
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
logger.debug(f"Created scope: {scope_type.value}:{scope_id}")
|
||||
return context
|
||||
|
||||
def _validate_parent_child(
|
||||
self,
|
||||
parent_type: ScopeLevel,
|
||||
child_type: ScopeLevel,
|
||||
) -> None:
|
||||
"""Validate that parent-child relationship is valid."""
|
||||
parent_idx = self.SCOPE_ORDER.index(parent_type)
|
||||
child_idx = self.SCOPE_ORDER.index(child_type)
|
||||
|
||||
if child_idx <= parent_idx:
|
||||
raise ValueError(
|
||||
f"Invalid scope hierarchy: {child_type.value} cannot be child of {parent_type.value}"
|
||||
)
|
||||
|
||||
# Allow skipping levels (e.g., PROJECT -> SESSION is valid)
|
||||
# This enables flexible scope structures
|
||||
|
||||
def _create_parent_chain(
|
||||
self,
|
||||
target_type: ScopeLevel,
|
||||
scope_id: str,
|
||||
) -> ScopeContext:
|
||||
"""Create parent scope chain up to target type."""
|
||||
target_idx = self.SCOPE_ORDER.index(target_type)
|
||||
|
||||
# Start from global and build chain
|
||||
current: ScopeContext | None = None
|
||||
|
||||
for i in range(target_idx):
|
||||
level = self.SCOPE_ORDER[i]
|
||||
if level == ScopeLevel.GLOBAL:
|
||||
level_id = "global"
|
||||
else:
|
||||
# Use a default ID for intermediate levels
|
||||
level_id = f"default_{level.value}"
|
||||
|
||||
current = ScopeContext(
|
||||
scope_type=level,
|
||||
scope_id=level_id,
|
||||
parent=current,
|
||||
)
|
||||
|
||||
return current # type: ignore[return-value]
|
||||
|
||||
def create_scope_from_ids(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> ScopeContext:
|
||||
"""
|
||||
Create a scope context from individual IDs.
|
||||
|
||||
Automatically determines the most specific scope level
|
||||
based on provided IDs.
|
||||
|
||||
Args:
|
||||
project_id: Project UUID
|
||||
agent_type_id: Agent type UUID
|
||||
agent_instance_id: Agent instance UUID
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Scope context for the most specific level
|
||||
"""
|
||||
# Build scope chain from most general to most specific
|
||||
current: ScopeContext = ScopeContext(
|
||||
scope_type=ScopeLevel.GLOBAL,
|
||||
scope_id="global",
|
||||
parent=None,
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
current = ScopeContext(
|
||||
scope_type=ScopeLevel.PROJECT,
|
||||
scope_id=str(project_id),
|
||||
parent=current,
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
current = ScopeContext(
|
||||
scope_type=ScopeLevel.AGENT_TYPE,
|
||||
scope_id=str(agent_type_id),
|
||||
parent=current,
|
||||
)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
current = ScopeContext(
|
||||
scope_type=ScopeLevel.AGENT_INSTANCE,
|
||||
scope_id=str(agent_instance_id),
|
||||
parent=current,
|
||||
)
|
||||
|
||||
if session_id is not None:
|
||||
current = ScopeContext(
|
||||
scope_type=ScopeLevel.SESSION,
|
||||
scope_id=session_id,
|
||||
parent=current,
|
||||
)
|
||||
|
||||
return current
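Illustrative sketch, not part of the diff: the chain create_scope_from_ids builds when agent_type_id is omitted; the UUIDs are hypothetical.

def _demo_scope_chain() -> None:
    from uuid import uuid4

    manager = ScopeManager()
    scope = manager.create_scope_from_ids(
        project_id=uuid4(),
        agent_instance_id=uuid4(),
        session_id="sess-42",
    )
    levels: list[ScopeLevel] = []
    current: ScopeContext | None = scope
    while current is not None:
        levels.append(current.scope_type)
        current = current.parent
    # AGENT_TYPE is skipped because no agent_type_id was supplied.
    assert levels == [
        ScopeLevel.SESSION,
        ScopeLevel.AGENT_INSTANCE,
        ScopeLevel.PROJECT,
        ScopeLevel.GLOBAL,
    ]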
|
||||
|
||||
def get_policy(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
) -> ScopePolicy:
|
||||
"""
|
||||
Get the access policy for a scope.
|
||||
|
||||
Args:
|
||||
scope: Scope to get policy for
|
||||
|
||||
Returns:
|
||||
Policy for the scope
|
||||
"""
|
||||
key = self._scope_key(scope)
|
||||
|
||||
if key in self._policies:
|
||||
return self._policies[key]
|
||||
|
||||
# Return default policy for the scope level
|
||||
return self._default_policies.get(
|
||||
scope.scope_type,
|
||||
ScopePolicy(
|
||||
scope_type=scope.scope_type,
|
||||
scope_id=scope.scope_id,
|
||||
),
|
||||
)
|
||||
|
||||
def set_policy(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
policy: ScopePolicy,
|
||||
) -> None:
|
||||
"""
|
||||
Set the access policy for a scope.
|
||||
|
||||
Args:
|
||||
scope: Scope to set policy for
|
||||
policy: Policy to apply
|
||||
"""
|
||||
key = self._scope_key(scope)
|
||||
self._policies[key] = policy
|
||||
logger.info(f"Set policy for scope {key}")
|
||||
|
||||
def _scope_key(self, scope: ScopeContext) -> str:
|
||||
"""Generate a unique key for a scope."""
|
||||
return f"{scope.scope_type.value}:{scope.scope_id}"
|
||||
|
||||
def get_scope_depth(self, scope_type: ScopeLevel) -> int:
|
||||
"""Get the depth of a scope level in the hierarchy."""
|
||||
return self.SCOPE_ORDER.index(scope_type)
|
||||
|
||||
def get_parent_level(self, scope_type: ScopeLevel) -> ScopeLevel | None:
|
||||
"""Get the parent scope level for a given level."""
|
||||
idx = self.SCOPE_ORDER.index(scope_type)
|
||||
if idx == 0:
|
||||
return None
|
||||
return self.SCOPE_ORDER[idx - 1]
|
||||
|
||||
def get_child_level(self, scope_type: ScopeLevel) -> ScopeLevel | None:
|
||||
"""Get the child scope level for a given level."""
|
||||
idx = self.SCOPE_ORDER.index(scope_type)
|
||||
if idx >= len(self.SCOPE_ORDER) - 1:
|
||||
return None
|
||||
return self.SCOPE_ORDER[idx + 1]
|
||||
|
||||
def is_ancestor(
|
||||
self,
|
||||
potential_ancestor: ScopeContext,
|
||||
descendant: ScopeContext,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if one scope is an ancestor of another.
|
||||
|
||||
Args:
|
||||
potential_ancestor: Scope to check as ancestor
|
||||
descendant: Scope to check as descendant
|
||||
|
||||
Returns:
|
||||
True if ancestor relationship exists
|
||||
"""
|
||||
current = descendant.parent
|
||||
while current is not None:
|
||||
if (
|
||||
current.scope_type == potential_ancestor.scope_type
|
||||
and current.scope_id == potential_ancestor.scope_id
|
||||
):
|
||||
return True
|
||||
current = current.parent
|
||||
return False
|
||||
|
||||
def get_common_ancestor(
|
||||
self,
|
||||
scope_a: ScopeContext,
|
||||
scope_b: ScopeContext,
|
||||
) -> ScopeContext | None:
|
||||
"""
|
||||
Find the nearest common ancestor of two scopes.
|
||||
|
||||
Args:
|
||||
scope_a: First scope
|
||||
scope_b: Second scope
|
||||
|
||||
Returns:
|
||||
Common ancestor or None if none exists
|
||||
"""
|
||||
# Get ancestors of scope_a
|
||||
ancestors_a: set[str] = set()
|
||||
current: ScopeContext | None = scope_a
|
||||
while current is not None:
|
||||
ancestors_a.add(self._scope_key(current))
|
||||
current = current.parent
|
||||
|
||||
# Find first ancestor of scope_b that's in ancestors_a
|
||||
current = scope_b
|
||||
while current is not None:
|
||||
if self._scope_key(current) in ancestors_a:
|
||||
return current
|
||||
current = current.parent
|
||||
|
||||
return None
|
||||
|
||||
def can_access(
|
||||
self,
|
||||
accessor_scope: ScopeContext,
|
||||
target_scope: ScopeContext,
|
||||
operation: str = "read",
|
||||
) -> bool:
|
||||
"""
|
||||
Check if accessor scope can access target scope.
|
||||
|
||||
Access rules:
|
||||
- A scope can always access itself
|
||||
- A scope can access ancestors (if inheritance allowed)
|
||||
- A scope CANNOT access descendants (privacy)
|
||||
- Sibling scopes cannot access each other
|
||||
|
||||
Args:
|
||||
accessor_scope: Scope attempting access
|
||||
target_scope: Scope being accessed
|
||||
operation: Type of operation (read/write)
|
||||
|
||||
Returns:
|
||||
True if access is allowed
|
||||
"""
|
||||
# Same scope - always allowed
|
||||
if (
|
||||
accessor_scope.scope_type == target_scope.scope_type
|
||||
and accessor_scope.scope_id == target_scope.scope_id
|
||||
):
|
||||
policy = self.get_policy(target_scope)
|
||||
if operation == "write":
|
||||
return policy.allows_write()
|
||||
return policy.allows_read()
|
||||
|
||||
# Check if target is ancestor (inheritance)
|
||||
if self.is_ancestor(target_scope, accessor_scope):
|
||||
policy = self.get_policy(target_scope)
|
||||
if not policy.allows_inherit():
|
||||
return False
|
||||
if operation == "write":
|
||||
return policy.allows_write()
|
||||
return policy.allows_read()
|
||||
|
||||
# Check if accessor is ancestor of target (downward access)
|
||||
# This is NOT allowed - parents cannot access children's memories
|
||||
if self.is_ancestor(accessor_scope, target_scope):
|
||||
return False
|
||||
|
||||
# Sibling scopes cannot access each other
|
||||
return False
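Illustrative sketch, not part of the diff: the access rules above exercised with the default policies; the id strings are hypothetical.

def _demo_access_rules() -> None:
    manager = ScopeManager()
    project = manager.create_scope(ScopeLevel.PROJECT, "proj-1")
    session_a = manager.create_scope(ScopeLevel.SESSION, "sess-a", parent=project)
    session_b = manager.create_scope(ScopeLevel.SESSION, "sess-b", parent=project)

    assert manager.can_access(session_a, project)        # upward read via inheritance
    assert not manager.can_access(project, session_a)    # no downward access (privacy)
    assert not manager.can_access(session_a, session_b)  # siblings are isolated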
|
||||
|
||||
|
||||
# Singleton manager instance with thread-safe initialization
|
||||
_manager: ScopeManager | None = None
|
||||
_manager_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_scope_manager() -> ScopeManager:
|
||||
"""Get the singleton scope manager instance (thread-safe)."""
|
||||
global _manager
|
||||
if _manager is None:
|
||||
with _manager_lock:
|
||||
# Double-checked locking pattern
|
||||
if _manager is None:
|
||||
_manager = ScopeManager()
|
||||
return _manager
|
||||
|
||||
|
||||
def reset_scope_manager() -> None:
|
||||
"""Reset the scope manager singleton (for testing)."""
|
||||
global _manager
|
||||
with _manager_lock:
|
||||
_manager = None
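Illustrative usage sketch, not part of the diff: overriding the default policy for one project so it becomes read-only; the project id is hypothetical.

def _demo_policy_override() -> None:
    manager = get_scope_manager()
    scope = manager.create_scope(ScopeLevel.PROJECT, "archived-project")
    manager.set_policy(
        scope,
        ScopePolicy(
            scope_type=ScopeLevel.PROJECT,
            scope_id="archived-project",
            can_write=False,  # reads and inheritance stay enabled
        ),
    )
    assert manager.get_policy(scope).allows_read()
    assert not manager.get_policy(scope).allows_write()
    reset_scope_manager()  # test helper: drop the singleton again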
|
||||
27
backend/app/services/memory/semantic/__init__.py
Normal file
27
backend/app/services/memory/semantic/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# app/services/memory/semantic/__init__.py
|
||||
"""
|
||||
Semantic Memory
|
||||
|
||||
Fact storage with triple format (subject, predicate, object)
|
||||
and semantic search capabilities.
|
||||
"""
|
||||
|
||||
from .extraction import (
|
||||
ExtractedFact,
|
||||
ExtractionContext,
|
||||
FactExtractor,
|
||||
get_fact_extractor,
|
||||
)
|
||||
from .memory import SemanticMemory
|
||||
from .verification import FactConflict, FactVerifier, VerificationResult
|
||||
|
||||
__all__ = [
|
||||
"ExtractedFact",
|
||||
"ExtractionContext",
|
||||
"FactConflict",
|
||||
"FactExtractor",
|
||||
"FactVerifier",
|
||||
"SemanticMemory",
|
||||
"VerificationResult",
|
||||
"get_fact_extractor",
|
||||
]
|
||||
313
backend/app/services/memory/semantic/extraction.py
Normal file
313
backend/app/services/memory/semantic/extraction.py
Normal file
@@ -0,0 +1,313 @@
|
||||
# app/services/memory/semantic/extraction.py
|
||||
"""
|
||||
Fact Extraction from Episodes.
|
||||
|
||||
Provides utilities for extracting semantic facts (subject-predicate-object triples)
|
||||
from episodic memories and other text sources.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from app.services.memory.types import Episode, FactCreate, Outcome
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionContext:
|
||||
"""Context for fact extraction."""
|
||||
|
||||
project_id: Any | None = None
|
||||
source_episode_id: Any | None = None
|
||||
min_confidence: float = 0.5
|
||||
max_facts_per_source: int = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedFact:
|
||||
"""A fact extracted from text before storage."""
|
||||
|
||||
subject: str
|
||||
predicate: str
|
||||
object: str
|
||||
confidence: float
|
||||
source_text: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_fact_create(
|
||||
self,
|
||||
project_id: Any | None = None,
|
||||
source_episode_ids: list[Any] | None = None,
|
||||
) -> FactCreate:
|
||||
"""Convert to FactCreate for storage."""
|
||||
return FactCreate(
|
||||
subject=self.subject,
|
||||
predicate=self.predicate,
|
||||
object=self.object,
|
||||
confidence=self.confidence,
|
||||
project_id=project_id,
|
||||
source_episode_ids=source_episode_ids or [],
|
||||
)
|
||||
|
||||
|
||||
class FactExtractor:
|
||||
"""
|
||||
Extracts facts from episodes and text.
|
||||
|
||||
This is a rule-based extractor. In production, this would be
|
||||
replaced or augmented with LLM-based extraction for better accuracy.
|
||||
"""
|
||||
|
||||
# Common predicates we can detect
|
||||
PREDICATE_PATTERNS: ClassVar[dict[str, str]] = {
|
||||
"uses": r"(?:uses?|using|utilizes?)",
|
||||
"requires": r"(?:requires?|needs?|depends?\s+on)",
|
||||
"is_a": r"(?:is\s+a|is\s+an|are\s+a|are)",
|
||||
"has": r"(?:has|have|contains?)",
|
||||
"part_of": r"(?:part\s+of|belongs?\s+to|member\s+of)",
|
||||
"causes": r"(?:causes?|leads?\s+to|results?\s+in)",
|
||||
"prevents": r"(?:prevents?|avoids?|stops?)",
|
||||
"solves": r"(?:solves?|fixes?|resolves?)",
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize extractor."""
|
||||
self._compiled_patterns = {
|
||||
pred: re.compile(pattern, re.IGNORECASE)
|
||||
for pred, pattern in self.PREDICATE_PATTERNS.items()
|
||||
}
|
||||
|
||||
def extract_from_episode(
|
||||
self,
|
||||
episode: Episode,
|
||||
context: ExtractionContext | None = None,
|
||||
) -> list[ExtractedFact]:
|
||||
"""
|
||||
Extract facts from an episode.
|
||||
|
||||
Args:
|
||||
episode: Episode to extract from
|
||||
context: Optional extraction context
|
||||
|
||||
Returns:
|
||||
List of extracted facts
|
||||
"""
|
||||
ctx = context or ExtractionContext()
|
||||
facts: list[ExtractedFact] = []
|
||||
|
||||
# Extract from task description
|
||||
task_facts = self._extract_from_text(
|
||||
episode.task_description,
|
||||
source_prefix=episode.task_type,
|
||||
)
|
||||
facts.extend(task_facts)
|
||||
|
||||
# Extract from lessons learned
|
||||
for lesson in episode.lessons_learned:
|
||||
lesson_facts = self._extract_from_lesson(lesson, episode)
|
||||
facts.extend(lesson_facts)
|
||||
|
||||
# Extract outcome-based facts
|
||||
outcome_facts = self._extract_outcome_facts(episode)
|
||||
facts.extend(outcome_facts)
|
||||
|
||||
# Limit and filter
|
||||
facts = [f for f in facts if f.confidence >= ctx.min_confidence]
|
||||
facts = facts[: ctx.max_facts_per_source]
|
||||
|
||||
logger.debug(f"Extracted {len(facts)} facts from episode {episode.id}")
|
||||
|
||||
return facts
|
||||
|
||||
def _extract_from_text(
|
||||
self,
|
||||
text: str,
|
||||
source_prefix: str = "",
|
||||
) -> list[ExtractedFact]:
|
||||
"""Extract facts from free-form text using pattern matching."""
|
||||
facts: list[ExtractedFact] = []
|
||||
|
||||
if not text or len(text) < 10:
|
||||
return facts
|
||||
|
||||
# Split into sentences
|
||||
sentences = re.split(r"[.!?]+", text)
|
||||
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if len(sentence) < 10:
|
||||
continue
|
||||
|
||||
# Try to match predicate patterns
|
||||
for predicate, pattern in self._compiled_patterns.items():
|
||||
match = pattern.search(sentence)
|
||||
if match:
|
||||
# Extract subject (text before predicate)
|
||||
subject = sentence[: match.start()].strip()
|
||||
# Extract object (text after predicate)
|
||||
obj = sentence[match.end() :].strip()
|
||||
|
||||
if len(subject) > 2 and len(obj) > 2:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=subject[:200], # Limit length
|
||||
predicate=predicate,
|
||||
object=obj[:500],
|
||||
confidence=0.6, # Medium confidence for pattern matching
|
||||
source_text=sentence,
|
||||
)
|
||||
)
|
||||
break # One fact per sentence
|
||||
|
||||
return facts
|
||||
|
||||
def _extract_from_lesson(
|
||||
self,
|
||||
lesson: str,
|
||||
episode: Episode,
|
||||
) -> list[ExtractedFact]:
|
||||
"""Extract facts from a lesson learned."""
|
||||
facts: list[ExtractedFact] = []
|
||||
|
||||
if not lesson or len(lesson) < 10:
|
||||
return facts
|
||||
|
||||
# Lessons are typically in the form "Always do X" or "Never do Y"
|
||||
# or "When X, do Y"
|
||||
|
||||
# Direct lesson fact
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="lesson_learned",
|
||||
object=lesson,
|
||||
confidence=0.8, # High confidence for explicit lessons
|
||||
source_text=lesson,
|
||||
metadata={"outcome": episode.outcome.value},
|
||||
)
|
||||
)
|
||||
|
||||
# Extract conditional patterns
|
||||
conditional_match = re.match(
|
||||
r"(?:when|if)\s+(.+?),\s*(.+)",
|
||||
lesson,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if conditional_match:
|
||||
condition, action = conditional_match.groups()
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=condition.strip(),
|
||||
predicate="requires_action",
|
||||
object=action.strip(),
|
||||
confidence=0.7,
|
||||
source_text=lesson,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract "always/never" patterns
|
||||
always_match = re.match(
|
||||
r"(?:always)\s+(.+)",
|
||||
lesson,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if always_match:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="best_practice",
|
||||
object=always_match.group(1).strip(),
|
||||
confidence=0.85,
|
||||
source_text=lesson,
|
||||
)
|
||||
)
|
||||
|
||||
never_match = re.match(
|
||||
r"(?:never|avoid)\s+(.+)",
|
||||
lesson,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if never_match:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="anti_pattern",
|
||||
object=never_match.group(1).strip(),
|
||||
confidence=0.85,
|
||||
source_text=lesson,
|
||||
)
|
||||
)
|
||||
|
||||
return facts
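Illustrative sketch, not part of the diff: the three lesson shapes the method above recognises, shown on plain strings with the same regexes.

def _demo_lesson_patterns() -> None:
    conditional = re.match(
        r"(?:when|if)\s+(.+?),\s*(.+)",
        "When the migration fails, roll back before retrying",
        re.IGNORECASE,
    )
    assert conditional is not None
    condition, action = conditional.groups()
    # -> becomes a (condition, "requires_action", action) fact with confidence 0.7
    assert re.match(r"(?:always)\s+(.+)", "Always pin dependency versions", re.IGNORECASE)
    assert re.match(r"(?:never|avoid)\s+(.+)", "Never commit secrets", re.IGNORECASE)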
|
||||
|
||||
def _extract_outcome_facts(
|
||||
self,
|
||||
episode: Episode,
|
||||
) -> list[ExtractedFact]:
|
||||
"""Extract facts based on episode outcome."""
|
||||
facts: list[ExtractedFact] = []
|
||||
|
||||
# Create fact based on outcome
|
||||
if episode.outcome == Outcome.SUCCESS:
|
||||
if episode.outcome_details:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="successful_approach",
|
||||
object=episode.outcome_details[:500],
|
||||
confidence=0.75,
|
||||
source_text=episode.outcome_details,
|
||||
)
|
||||
)
|
||||
elif episode.outcome == Outcome.FAILURE:
|
||||
if episode.outcome_details:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="known_failure_mode",
|
||||
object=episode.outcome_details[:500],
|
||||
confidence=0.8, # High confidence for failures
|
||||
source_text=episode.outcome_details,
|
||||
)
|
||||
)
|
||||
|
||||
return facts
|
||||
|
||||
def extract_from_text(
|
||||
self,
|
||||
text: str,
|
||||
context: ExtractionContext | None = None,
|
||||
) -> list[ExtractedFact]:
|
||||
"""
|
||||
Extract facts from arbitrary text.
|
||||
|
||||
Args:
|
||||
text: Text to extract from
|
||||
context: Optional extraction context
|
||||
|
||||
Returns:
|
||||
List of extracted facts
|
||||
"""
|
||||
ctx = context or ExtractionContext()
|
||||
|
||||
facts = self._extract_from_text(text)
|
||||
|
||||
# Filter by confidence
|
||||
facts = [f for f in facts if f.confidence >= ctx.min_confidence]
|
||||
|
||||
return facts[: ctx.max_facts_per_source]
|
||||
|
||||
|
||||
# Singleton extractor instance
|
||||
_extractor: FactExtractor | None = None
|
||||
|
||||
|
||||
def get_fact_extractor() -> FactExtractor:
|
||||
"""Get the singleton fact extractor instance."""
|
||||
global _extractor
|
||||
if _extractor is None:
|
||||
_extractor = FactExtractor()
|
||||
return _extractor
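Illustrative usage sketch, not part of the diff: rule-based extraction from free text, then conversion to the FactCreate payload that SemanticMemory.store_fact accepts; the project id is hypothetical.

def _demo_extraction() -> None:
    from uuid import uuid4

    extractor = get_fact_extractor()
    facts = extractor.extract_from_text(
        "Deployment requires Docker and a configured registry."
    )
    project_id = uuid4()
    payloads = [f.to_fact_create(project_id=project_id) for f in facts]
    for payload in payloads:
        # e.g. subject="Deployment", predicate="requires",
        #      object="Docker and a configured registry", confidence=0.6
        print(payload.subject, payload.predicate, payload.object)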
|
||||
767
backend/app/services/memory/semantic/memory.py
Normal file
767
backend/app/services/memory/semantic/memory.py
Normal file
@@ -0,0 +1,767 @@
|
||||
# app/services/memory/semantic/memory.py
|
||||
"""
|
||||
Semantic Memory Implementation.
|
||||
|
||||
Provides fact storage and retrieval using subject-predicate-object triples.
|
||||
Supports semantic search, confidence scoring, and fact reinforcement.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, desc, or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.fact import Fact as FactModel
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.types import Episode, Fact, FactCreate, RetrievalResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _escape_like_pattern(pattern: str) -> str:
|
||||
"""
|
||||
Escape SQL LIKE/ILIKE special characters to prevent pattern injection.
|
||||
|
||||
Characters escaped:
|
||||
- % (matches zero or more characters)
|
||||
- _ (matches exactly one character)
|
||||
- \\ (escape character itself)
|
||||
|
||||
Args:
|
||||
pattern: Raw search pattern from user input
|
||||
|
||||
Returns:
|
||||
Escaped pattern safe for use in LIKE/ILIKE queries
|
||||
"""
|
||||
# Escape backslash first, then the wildcards
|
||||
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
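A quick sketch, not part of the diff, of what the escaping produces.

def _demo_like_escaping() -> None:
    # Without escaping, "100%_done" would act as a wildcard pattern in ILIKE.
    assert _escape_like_pattern("100%_done") == "100\\%\\_done"
    assert _escape_like_pattern("a\\b") == "a\\\\b"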
|
||||
|
||||
|
||||
def _model_to_fact(model: FactModel) -> Fact:
|
||||
"""Convert SQLAlchemy model to Fact dataclass."""
|
||||
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
|
||||
# they return actual values. We use type: ignore to handle this mismatch.
|
||||
return Fact(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
subject=model.subject, # type: ignore[arg-type]
|
||||
predicate=model.predicate, # type: ignore[arg-type]
|
||||
object=model.object, # type: ignore[arg-type]
|
||||
confidence=model.confidence, # type: ignore[arg-type]
|
||||
source_episode_ids=model.source_episode_ids or [], # type: ignore[arg-type]
|
||||
first_learned=model.first_learned, # type: ignore[arg-type]
|
||||
last_reinforced=model.last_reinforced, # type: ignore[arg-type]
|
||||
reinforcement_count=model.reinforcement_count, # type: ignore[arg-type]
|
||||
embedding=None, # Don't expose raw embedding
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class SemanticMemory:
|
||||
"""
|
||||
Semantic Memory Service.
|
||||
|
||||
Provides fact storage and retrieval:
|
||||
- Store facts as subject-predicate-object triples
|
||||
- Semantic search over facts
|
||||
- Entity-based retrieval
|
||||
- Confidence scoring and decay
|
||||
- Fact reinforcement on repeated learning
|
||||
- Conflict resolution
|
||||
|
||||
Performance target: <100ms P95 for retrieval
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize semantic memory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic search
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._settings = get_memory_settings()
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> "SemanticMemory":
|
||||
"""
|
||||
Factory method to create SemanticMemory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
|
||||
Returns:
|
||||
Configured SemanticMemory instance
|
||||
"""
|
||||
return cls(session=session, embedding_generator=embedding_generator)
|
||||
|
||||
# =========================================================================
|
||||
# Fact Storage
|
||||
# =========================================================================
|
||||
|
||||
async def store_fact(self, fact: FactCreate) -> Fact:
|
||||
"""
|
||||
Store a new fact or reinforce an existing one.
|
||||
|
||||
If a fact with the same triple (subject, predicate, object) exists
|
||||
in the same scope, it will be reinforced instead of duplicated.
|
||||
|
||||
Args:
|
||||
fact: Fact data to store
|
||||
|
||||
Returns:
|
||||
The created or reinforced fact
|
||||
"""
|
||||
# Check for existing fact with same triple
|
||||
existing = await self._find_existing_fact(
|
||||
project_id=fact.project_id,
|
||||
subject=fact.subject,
|
||||
predicate=fact.predicate,
|
||||
object=fact.object,
|
||||
)
|
||||
|
||||
if existing is not None:
|
||||
# Reinforce existing fact
|
||||
return await self.reinforce_fact(
|
||||
existing.id, # type: ignore[arg-type]
|
||||
source_episode_ids=fact.source_episode_ids,
|
||||
)
|
||||
|
||||
# Create new fact
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Generate embedding if possible
|
||||
embedding = None
|
||||
if self._embedding_generator is not None:
|
||||
embedding_text = self._create_embedding_text(fact)
|
||||
embedding = await self._embedding_generator.generate(embedding_text)
|
||||
|
||||
model = FactModel(
|
||||
project_id=fact.project_id,
|
||||
subject=fact.subject,
|
||||
predicate=fact.predicate,
|
||||
object=fact.object,
|
||||
confidence=fact.confidence,
|
||||
source_episode_ids=fact.source_episode_ids,
|
||||
first_learned=now,
|
||||
last_reinforced=now,
|
||||
reinforcement_count=1,
|
||||
embedding=embedding,
|
||||
)
|
||||
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
|
||||
logger.info(
|
||||
f"Stored new fact: {fact.subject} - {fact.predicate} - {fact.object[:50]}..."
|
||||
)
|
||||
|
||||
return _model_to_fact(model)
|
||||
|
||||
async def _find_existing_fact(
|
||||
self,
|
||||
project_id: UUID | None,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
object: str,
|
||||
) -> FactModel | None:
|
||||
"""Find an existing fact with the same triple in the same scope."""
|
||||
query = select(FactModel).where(
|
||||
and_(
|
||||
FactModel.subject == subject,
|
||||
FactModel.predicate == predicate,
|
||||
FactModel.object == object,
|
||||
)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(FactModel.project_id == project_id)
|
||||
else:
|
||||
query = query.where(FactModel.project_id.is_(None))
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def _create_embedding_text(self, fact: FactCreate) -> str:
|
||||
"""Create text for embedding from fact data."""
|
||||
return f"{fact.subject} {fact.predicate} {fact.object}"
|
||||
|
||||
# =========================================================================
|
||||
# Fact Retrieval
|
||||
# =========================================================================
|
||||
|
||||
async def search_facts(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 10,
|
||||
min_confidence: float | None = None,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Search for facts semantically similar to the query.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
project_id: Optional project to search within
|
||||
limit: Maximum results
|
||||
min_confidence: Optional minimum confidence filter
|
||||
|
||||
Returns:
|
||||
List of matching facts
|
||||
"""
|
||||
result = await self._search_facts_with_metadata(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
min_confidence=min_confidence,
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def _search_facts_with_metadata(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 10,
|
||||
min_confidence: float | None = None,
|
||||
) -> RetrievalResult[Fact]:
|
||||
"""Search facts with full result metadata."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
min_conf = min_confidence or self._settings.semantic_min_confidence
|
||||
|
||||
# Build base query
|
||||
stmt = (
|
||||
select(FactModel)
|
||||
.where(FactModel.confidence >= min_conf)
|
||||
.order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
# Apply project filter
|
||||
if project_id is not None:
|
||||
# Include both project-specific and global facts
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: Implement proper vector similarity search when pgvector is integrated
|
||||
# For now, do text-based search on subject/predicate/object
|
||||
search_terms = query.lower().split()
|
||||
if search_terms:
|
||||
conditions = []
|
||||
for term in search_terms[:5]: # Limit to 5 terms
|
||||
# Escape SQL wildcards to prevent pattern injection
|
||||
escaped_term = _escape_like_pattern(term)
|
||||
term_pattern = f"%{escaped_term}%"
|
||||
conditions.append(
|
||||
or_(
|
||||
FactModel.subject.ilike(term_pattern),
|
||||
FactModel.predicate.ilike(term_pattern),
|
||||
FactModel.object.ilike(term_pattern),
|
||||
)
|
||||
)
|
||||
if conditions:
|
||||
stmt = stmt.where(or_(*conditions))
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_fact(m) for m in models],
|
||||
total_count=len(models),
|
||||
query=query,
|
||||
retrieval_type="semantic",
|
||||
latency_ms=latency_ms,
|
||||
metadata={"min_confidence": min_conf},
|
||||
)
|
||||
|
||||
async def get_by_entity(
|
||||
self,
|
||||
entity: str,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Get facts related to an entity (as subject or object).
|
||||
|
||||
Args:
|
||||
entity: Entity to search for
|
||||
project_id: Optional project to search within
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of facts mentioning the entity
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Escape SQL wildcards to prevent pattern injection
|
||||
escaped_entity = _escape_like_pattern(entity)
|
||||
entity_pattern = f"%{escaped_entity}%"
|
||||
|
||||
stmt = (
|
||||
select(FactModel)
|
||||
.where(
|
||||
or_(
|
||||
FactModel.subject.ilike(entity_pattern),
|
||||
FactModel.object.ilike(entity_pattern),
|
||||
)
|
||||
)
|
||||
.order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
logger.debug(
|
||||
f"get_by_entity({entity}) returned {len(models)} facts in {latency_ms:.1f}ms"
|
||||
)
|
||||
|
||||
return [_model_to_fact(m) for m in models]
|
||||
|
||||
async def get_by_subject(
|
||||
self,
|
||||
subject: str,
|
||||
project_id: UUID | None = None,
|
||||
predicate: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Get facts with a specific subject.
|
||||
|
||||
Args:
|
||||
subject: Subject to search for
|
||||
project_id: Optional project to search within
|
||||
predicate: Optional predicate filter
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of facts with matching subject
|
||||
"""
|
||||
stmt = (
|
||||
select(FactModel)
|
||||
.where(FactModel.subject == subject)
|
||||
.order_by(desc(FactModel.confidence))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if predicate is not None:
|
||||
stmt = stmt.where(FactModel.predicate == predicate)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
return [_model_to_fact(m) for m in models]
|
||||
|
||||
async def get_by_id(self, fact_id: UUID) -> Fact | None:
|
||||
"""Get a fact by ID."""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
return _model_to_fact(model) if model else None
|
||||
|
||||
# =========================================================================
|
||||
# Fact Reinforcement
|
||||
# =========================================================================
|
||||
|
||||
async def reinforce_fact(
|
||||
self,
|
||||
fact_id: UUID,
|
||||
confidence_boost: float = 0.1,
|
||||
source_episode_ids: list[UUID] | None = None,
|
||||
) -> Fact:
|
||||
"""
|
||||
Reinforce a fact, increasing its confidence.
|
||||
|
||||
Args:
|
||||
fact_id: Fact to reinforce
|
||||
confidence_boost: Amount to increase confidence (default 0.1)
|
||||
source_episode_ids: Additional source episodes
|
||||
|
||||
Returns:
|
||||
Updated fact
|
||||
|
||||
Raises:
|
||||
ValueError: If fact not found
|
||||
"""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"Fact not found: {fact_id}")
|
||||
|
||||
# Calculate new confidence (max 1.0)
|
||||
current_confidence: float = model.confidence # type: ignore[assignment]
|
||||
new_confidence = min(1.0, current_confidence + confidence_boost)
|
||||
|
||||
# Merge source episode IDs
|
||||
current_sources: list[UUID] = model.source_episode_ids or [] # type: ignore[assignment]
|
||||
if source_episode_ids:
|
||||
# Add new sources, avoiding duplicates
|
||||
new_sources = list(set(current_sources + source_episode_ids))
|
||||
else:
|
||||
new_sources = current_sources
|
||||
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(FactModel)
|
||||
.where(FactModel.id == fact_id)
|
||||
.values(
|
||||
confidence=new_confidence,
|
||||
source_episode_ids=new_sources,
|
||||
last_reinforced=now,
|
||||
reinforcement_count=FactModel.reinforcement_count + 1,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(FactModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(
|
||||
f"Reinforced fact {fact_id}: confidence {current_confidence:.2f} -> {new_confidence:.2f}"
|
||||
)
|
||||
|
||||
return _model_to_fact(updated_model)
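Illustrative usage sketch, not part of the diff: storing the same triple twice reinforces rather than duplicates. The AsyncSession argument is a hypothetical fixture from the application's database setup.

async def _demo_reinforcement(session: AsyncSession) -> None:
    memory = await SemanticMemory.create(session)
    payload = FactCreate(
        subject="backend",
        predicate="uses",
        object="FastAPI",
        confidence=0.6,
        project_id=None,
        source_episode_ids=[],
    )
    first = await memory.store_fact(payload)
    second = await memory.store_fact(payload)  # same triple -> reinforce_fact
    assert second.id == first.id
    assert second.reinforcement_count == first.reinforcement_count + 1
    assert abs(second.confidence - 0.7) < 1e-6  # 0.6 + default 0.1 boost, capped at 1.0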
|
||||
|
||||
async def deprecate_fact(
|
||||
self,
|
||||
fact_id: UUID,
|
||||
reason: str,
|
||||
new_confidence: float = 0.0,
|
||||
) -> Fact | None:
|
||||
"""
|
||||
Deprecate a fact by lowering its confidence.
|
||||
|
||||
Args:
|
||||
fact_id: Fact to deprecate
|
||||
reason: Reason for deprecation
|
||||
new_confidence: New confidence level (default 0.0)
|
||||
|
||||
Returns:
|
||||
Updated fact or None if not found
|
||||
"""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(FactModel)
|
||||
.where(FactModel.id == fact_id)
|
||||
.values(
|
||||
confidence=max(0.0, new_confidence),
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(FactModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Deprecated fact {fact_id}: {reason}")
|
||||
|
||||
return _model_to_fact(updated_model) if updated_model else None
|
||||
|
||||
# =========================================================================
|
||||
# Fact Extraction from Episodes
|
||||
# =========================================================================
|
||||
|
||||
async def extract_facts_from_episode(
|
||||
self,
|
||||
episode: Episode,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Extract facts from an episode.
|
||||
|
||||
This is a placeholder for LLM-based fact extraction.
|
||||
In production, this would call an LLM to analyze the episode
|
||||
and extract subject-predicate-object triples.
|
||||
|
||||
Args:
|
||||
episode: Episode to extract facts from
|
||||
|
||||
Returns:
|
||||
List of extracted facts
|
||||
"""
|
||||
# For now, extract basic facts from lessons learned
|
||||
extracted_facts: list[Fact] = []
|
||||
|
||||
for lesson in episode.lessons_learned:
|
||||
if len(lesson) > 10: # Skip very short lessons
|
||||
fact_create = FactCreate(
|
||||
subject=episode.task_type,
|
||||
predicate="lesson_learned",
|
||||
object=lesson,
|
||||
confidence=0.7, # Lessons start with moderate confidence
|
||||
project_id=episode.project_id,
|
||||
source_episode_ids=[episode.id],
|
||||
)
|
||||
fact = await self.store_fact(fact_create)
|
||||
extracted_facts.append(fact)
|
||||
|
||||
logger.debug(
|
||||
f"Extracted {len(extracted_facts)} facts from episode {episode.id}"
|
||||
)
|
||||
|
||||
return extracted_facts
|
||||
|
||||
# =========================================================================
|
||||
# Conflict Resolution
|
||||
# =========================================================================
|
||||
|
||||
async def resolve_conflict(
|
||||
self,
|
||||
fact_ids: list[UUID],
|
||||
keep_fact_id: UUID | None = None,
|
||||
) -> Fact | None:
|
||||
"""
|
||||
Resolve a conflict between multiple facts.
|
||||
|
||||
If keep_fact_id is specified, that fact is kept and others are deprecated.
|
||||
Otherwise, the fact with highest confidence is kept.
|
||||
|
||||
Args:
|
||||
fact_ids: IDs of conflicting facts
|
||||
keep_fact_id: Optional ID of fact to keep
|
||||
|
||||
Returns:
|
||||
The winning fact, or None if no facts found
|
||||
"""
|
||||
if not fact_ids:
|
||||
return None
|
||||
|
||||
# Load all facts
|
||||
query = select(FactModel).where(FactModel.id.in_(fact_ids))
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
if not models:
|
||||
return None
|
||||
|
||||
# Determine winner
|
||||
if keep_fact_id is not None:
|
||||
winner = next((m for m in models if m.id == keep_fact_id), None)
|
||||
if winner is None:
|
||||
# Fallback to highest confidence
|
||||
winner = max(models, key=lambda m: m.confidence)
|
||||
else:
|
||||
# Keep the fact with highest confidence
|
||||
winner = max(models, key=lambda m: m.confidence)
|
||||
|
||||
# Deprecate losers
|
||||
for model in models:
|
||||
if model.id != winner.id:
|
||||
await self.deprecate_fact(
|
||||
model.id, # type: ignore[arg-type]
|
||||
reason=f"Conflict resolution: superseded by {winner.id}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Resolved conflict between {len(fact_ids)} facts, keeping {winner.id}"
|
||||
)
|
||||
|
||||
return _model_to_fact(winner)
|
||||
|
||||
# =========================================================================
|
||||
# Confidence Decay
|
||||
# =========================================================================
|
||||
|
||||
async def apply_confidence_decay(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
decay_factor: float = 0.01,
|
||||
) -> int:
|
||||
"""
|
||||
Apply confidence decay to facts that haven't been reinforced recently.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to apply decay to
|
||||
decay_factor: Decay factor per day (default 0.01)
|
||||
|
||||
Returns:
|
||||
Number of facts affected
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
decay_days = self._settings.semantic_confidence_decay_days
|
||||
min_conf = self._settings.semantic_min_confidence
|
||||
|
||||
# Calculate cutoff date
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = now - timedelta(days=decay_days)
|
||||
|
||||
# Find facts needing decay
|
||||
query = select(FactModel).where(
|
||||
and_(
|
||||
FactModel.last_reinforced < cutoff,
|
||||
FactModel.confidence > min_conf,
|
||||
)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(FactModel.project_id == project_id)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Apply decay
|
||||
updated_count = 0
|
||||
for model in models:
|
||||
# Calculate days since last reinforcement
|
||||
days_since: float = (now - model.last_reinforced).days
|
||||
|
||||
# Calculate decay: linear in the number of days past the decay threshold
|
||||
decay = decay_factor * (days_since - decay_days)
|
||||
new_confidence = max(min_conf, model.confidence - decay)
|
||||
|
||||
if new_confidence != model.confidence:
|
||||
await self._session.execute(
|
||||
update(FactModel)
|
||||
.where(FactModel.id == model.id)
|
||||
.values(confidence=new_confidence, updated_at=now)
|
||||
)
|
||||
updated_count += 1
|
||||
|
||||
await self._session.flush()
|
||||
logger.info(f"Applied confidence decay to {updated_count} facts")
|
||||
|
||||
return updated_count
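A worked example, not part of the diff, of the decay formula above; the threshold and minimum values are hypothetical stand-ins for the settings.

def _demo_decay_math() -> None:
    # decay_factor=0.01, 30-day threshold, last reinforced 45 days ago:
    decay = 0.01 * (45 - 30)
    assert abs(decay - 0.15) < 1e-9
    new_confidence = max(0.4, 0.7 - decay)  # 0.4 stands in for semantic_min_confidence
    assert abs(new_confidence - 0.55) < 1e-9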
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
async def get_stats(self, project_id: UUID | None = None) -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about semantic memory.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to get stats for
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics
|
||||
"""
|
||||
# Get all facts for this scope
|
||||
query = select(FactModel)
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
if not models:
|
||||
return {
|
||||
"total_facts": 0,
|
||||
"avg_confidence": 0.0,
|
||||
"avg_reinforcement_count": 0.0,
|
||||
"high_confidence_count": 0,
|
||||
"low_confidence_count": 0,
|
||||
}
|
||||
|
||||
confidences = [m.confidence for m in models]
|
||||
reinforcements = [m.reinforcement_count for m in models]
|
||||
|
||||
return {
|
||||
"total_facts": len(models),
|
||||
"avg_confidence": sum(confidences) / len(confidences),
|
||||
"avg_reinforcement_count": sum(reinforcements) / len(reinforcements),
|
||||
"high_confidence_count": sum(1 for c in confidences if c >= 0.8),
|
||||
"low_confidence_count": sum(1 for c in confidences if c < 0.5),
|
||||
}
|
||||
|
||||
async def count(self, project_id: UUID | None = None) -> int:
|
||||
"""
|
||||
Count facts in scope.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to count for
|
||||
|
||||
Returns:
|
||||
Number of facts
|
||||
"""
|
||||
query = select(FactModel)
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return len(list(result.scalars().all()))
|
||||
|
||||
async def delete(self, fact_id: UUID) -> bool:
|
||||
"""
|
||||
Delete a fact.
|
||||
|
||||
Args:
|
||||
fact_id: Fact to delete
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return False
|
||||
|
||||
await self._session.delete(model)
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Deleted fact {fact_id}")
|
||||
return True
|
||||
363
backend/app/services/memory/semantic/verification.py
Normal file
363
backend/app/services/memory/semantic/verification.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# app/services/memory/semantic/verification.py
|
||||
"""
|
||||
Fact Verification.
|
||||
|
||||
Provides utilities for verifying facts, detecting conflicts,
|
||||
and managing fact consistency.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.fact import Fact as FactModel
|
||||
from app.services.memory.types import Fact
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VerificationResult:
|
||||
"""Result of fact verification."""
|
||||
|
||||
is_valid: bool
|
||||
confidence_adjustment: float = 0.0
|
||||
conflicts: list["FactConflict"] = field(default_factory=list)
|
||||
supporting_facts: list[Fact] = field(default_factory=list)
|
||||
messages: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FactConflict:
|
||||
"""Represents a conflict between two facts."""
|
||||
|
||||
fact_a_id: UUID
|
||||
fact_b_id: UUID
|
||||
conflict_type: str # "contradiction", "superseded", "partial_overlap"
|
||||
description: str
|
||||
suggested_resolution: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"fact_a_id": str(self.fact_a_id),
|
||||
"fact_b_id": str(self.fact_b_id),
|
||||
"conflict_type": self.conflict_type,
|
||||
"description": self.description,
|
||||
"suggested_resolution": self.suggested_resolution,
|
||||
}
|
||||
|
||||
|
||||
class FactVerifier:
|
||||
"""
|
||||
Verifies facts and detects conflicts.
|
||||
|
||||
Provides methods to:
|
||||
- Check if a fact conflicts with existing facts
|
||||
- Find supporting evidence for a fact
|
||||
- Detect contradictions in the fact base
|
||||
"""
|
||||
|
||||
# Predicates that are opposites/contradictions
|
||||
CONTRADICTORY_PREDICATES: ClassVar[set[tuple[str, str]]] = {
|
||||
("uses", "does_not_use"),
|
||||
("requires", "does_not_require"),
|
||||
("is_a", "is_not_a"),
|
||||
("causes", "prevents"),
|
||||
("allows", "prevents"),
|
||||
("supports", "does_not_support"),
|
||||
("best_practice", "anti_pattern"),
|
||||
}
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize verifier with database session."""
|
||||
self._session = session
|
||||
|
||||
async def verify_fact(
|
||||
self,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
obj: str,
|
||||
project_id: UUID | None = None,
|
||||
) -> VerificationResult:
|
||||
"""
|
||||
Verify a fact against existing facts.
|
||||
|
||||
Args:
|
||||
subject: Fact subject
|
||||
predicate: Fact predicate
|
||||
obj: Fact object
|
||||
project_id: Optional project scope
|
||||
|
||||
Returns:
|
||||
VerificationResult with verification details
|
||||
"""
|
||||
result = VerificationResult(is_valid=True)
|
||||
|
||||
# Check for direct contradictions
|
||||
conflicts = await self._find_contradictions(
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
obj=obj,
|
||||
project_id=project_id,
|
||||
)
|
||||
result.conflicts = conflicts
|
||||
|
||||
if conflicts:
|
||||
result.is_valid = False
|
||||
result.messages.append(f"Found {len(conflicts)} conflicting fact(s)")
|
||||
# Reduce confidence based on conflicts
|
||||
result.confidence_adjustment = -0.1 * len(conflicts)
|
||||
|
||||
# Find supporting facts
|
||||
supporting = await self._find_supporting_facts(
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
project_id=project_id,
|
||||
)
|
||||
result.supporting_facts = supporting
|
||||
|
||||
if supporting:
|
||||
result.messages.append(f"Found {len(supporting)} supporting fact(s)")
|
||||
# Boost confidence based on support
|
||||
result.confidence_adjustment += 0.05 * min(len(supporting), 3)
|
||||
|
||||
return result
|
||||
|
||||
async def _find_contradictions(
|
||||
self,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
obj: str,
|
||||
project_id: UUID | None = None,
|
||||
) -> list[FactConflict]:
|
||||
"""Find facts that contradict the given fact."""
|
||||
conflicts: list[FactConflict] = []
|
||||
|
||||
# Find opposite predicates
|
||||
opposite_predicates = self._get_opposite_predicates(predicate)
|
||||
|
||||
if not opposite_predicates:
|
||||
return conflicts
|
||||
|
||||
# Search for contradicting facts
|
||||
query = select(FactModel).where(
|
||||
and_(
|
||||
FactModel.subject == subject,
|
||||
FactModel.predicate.in_(opposite_predicates),
|
||||
)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
for model in models:
|
||||
conflicts.append(
|
||||
FactConflict(
|
||||
fact_a_id=model.id, # type: ignore[arg-type]
|
||||
fact_b_id=UUID(
|
||||
"00000000-0000-0000-0000-000000000000"
|
||||
), # Placeholder for new fact
|
||||
conflict_type="contradiction",
|
||||
description=(
|
||||
f"'{subject} {predicate} {obj}' contradicts "
|
||||
f"'{model.subject} {model.predicate} {model.object}'"
|
||||
),
|
||||
suggested_resolution="Keep fact with higher confidence",
|
||||
)
|
||||
)
|
||||
|
||||
return conflicts
|
||||
|
||||
def _get_opposite_predicates(self, predicate: str) -> list[str]:
|
||||
"""Get predicates that are opposite to the given predicate."""
|
||||
opposites: list[str] = []
|
||||
|
||||
for pair in self.CONTRADICTORY_PREDICATES:
|
||||
if predicate in pair:
|
||||
opposites.extend(p for p in pair if p != predicate)
|
||||
|
||||
return opposites
|
||||
|
||||
async def _find_supporting_facts(
|
||||
self,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
project_id: UUID | None = None,
|
||||
) -> list[Fact]:
|
||||
"""Find facts that support the given fact."""
|
||||
# Find facts with same subject and predicate
|
||||
query = (
|
||||
select(FactModel)
|
||||
.where(
|
||||
and_(
|
||||
FactModel.subject == subject,
|
||||
FactModel.predicate == predicate,
|
||||
FactModel.confidence >= 0.5,
|
||||
)
|
||||
)
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
return [self._model_to_fact(m) for m in models]
|
||||
|
||||
async def find_all_conflicts(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
) -> list[FactConflict]:
|
||||
"""
|
||||
Find all conflicts in the fact base.
|
||||
|
||||
Args:
|
||||
project_id: Optional project scope
|
||||
|
||||
Returns:
|
||||
List of all detected conflicts
|
||||
"""
|
||||
conflicts: list[FactConflict] = []
|
||||
|
||||
# Get all facts
|
||||
query = select(FactModel)
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Check each pair for conflicts
|
||||
for i, fact_a in enumerate(models):
|
||||
for fact_b in models[i + 1 :]:
|
||||
conflict = self._check_pair_conflict(fact_a, fact_b)
|
||||
if conflict:
|
||||
conflicts.append(conflict)
|
||||
|
||||
logger.info(f"Found {len(conflicts)} conflicts in fact base")
|
||||
|
||||
return conflicts
|
||||
|
||||
def _check_pair_conflict(
|
||||
self,
|
||||
fact_a: FactModel,
|
||||
fact_b: FactModel,
|
||||
) -> FactConflict | None:
|
||||
"""Check if two facts conflict."""
|
||||
# Same subject?
|
||||
if fact_a.subject != fact_b.subject:
|
||||
return None
|
||||
|
||||
# Contradictory predicates?
|
||||
opposite = self._get_opposite_predicates(fact_a.predicate) # type: ignore[arg-type]
|
||||
if fact_b.predicate not in opposite:
|
||||
return None
|
||||
|
||||
return FactConflict(
|
||||
fact_a_id=fact_a.id, # type: ignore[arg-type]
|
||||
fact_b_id=fact_b.id, # type: ignore[arg-type]
|
||||
conflict_type="contradiction",
|
||||
description=(
|
||||
f"'{fact_a.subject} {fact_a.predicate} {fact_a.object}' "
|
||||
f"contradicts '{fact_b.subject} {fact_b.predicate} {fact_b.object}'"
|
||||
),
|
||||
suggested_resolution="Deprecate fact with lower confidence",
|
||||
)
|
||||
|
||||
async def get_fact_reliability_score(
|
||||
self,
|
||||
fact_id: UUID,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate a reliability score for a fact.
|
||||
|
||||
Based on:
|
||||
- Confidence score
|
||||
- Number of reinforcements
|
||||
- Number of supporting facts
|
||||
- Absence of conflicts
|
||||
|
||||
Args:
|
||||
fact_id: Fact to score
|
||||
|
||||
Returns:
|
||||
Reliability score (0.0 to 1.0)
|
||||
"""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return 0.0
|
||||
|
||||
# Base score from confidence - explicitly typed to avoid Column type issues
|
||||
score: float = float(model.confidence)
|
||||
|
||||
# Boost for reinforcements (diminishing returns)
|
||||
reinforcement_boost = min(0.2, float(model.reinforcement_count) * 0.02)
|
||||
score += reinforcement_boost
|
||||
|
||||
# Find supporting facts
|
||||
supporting = await self._find_supporting_facts(
|
||||
subject=model.subject, # type: ignore[arg-type]
|
||||
predicate=model.predicate, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
)
|
||||
support_boost = min(0.1, len(supporting) * 0.02)
|
||||
score += support_boost
|
||||
|
||||
# Check for conflicts
|
||||
conflicts = await self._find_contradictions(
|
||||
subject=model.subject, # type: ignore[arg-type]
|
||||
predicate=model.predicate, # type: ignore[arg-type]
|
||||
obj=model.object, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
)
|
||||
conflict_penalty = min(0.3, len(conflicts) * 0.1)
|
||||
score -= conflict_penalty
|
||||
|
||||
# Clamp to valid range
|
||||
return max(0.0, min(1.0, score))
|
||||
|
||||
def _model_to_fact(self, model: FactModel) -> Fact:
|
||||
"""Convert SQLAlchemy model to Fact dataclass."""
|
||||
return Fact(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
subject=model.subject, # type: ignore[arg-type]
|
||||
predicate=model.predicate, # type: ignore[arg-type]
|
||||
object=model.object, # type: ignore[arg-type]
|
||||
confidence=model.confidence, # type: ignore[arg-type]
|
||||
source_episode_ids=model.source_episode_ids or [], # type: ignore[arg-type]
|
||||
first_learned=model.first_learned, # type: ignore[arg-type]
|
||||
last_reinforced=model.last_reinforced, # type: ignore[arg-type]
|
||||
reinforcement_count=model.reinforcement_count, # type: ignore[arg-type]
|
||||
embedding=None,
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
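Reviewer note: `verify_fact` folds contradictions and supporting evidence into a single `VerificationResult`. A hedged usage sketch, assuming an `AsyncSession` is already available; the fact values are invented for illustration:

from app.services.memory.semantic.verification import FactVerifier


async def check_new_fact(session) -> float:
    """Verify a candidate fact before storing it and derive an adjusted confidence."""
    verifier = FactVerifier(session)
    result = await verifier.verify_fact(
        subject="backend",
        predicate="uses",
        obj="FastAPI",
        project_id=None,  # no project filter: check against all stored facts
    )

    base_confidence = 0.8
    # confidence_adjustment is -0.1 per detected conflict and +0.05 per
    # supporting fact (capped at three), as implemented above.
    adjusted = max(0.0, min(1.0, base_confidence + result.confidence_adjustment))

    if not result.is_valid:
        for conflict in result.conflicts:
            print(conflict.to_dict())

    return adjusted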
328 backend/app/services/memory/types.py Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
Memory System Types
|
||||
|
||||
Core type definitions and interfaces for the Agent Memory System.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class MemoryType(str, Enum):
|
||||
"""Types of memory in the agent memory system."""
|
||||
|
||||
WORKING = "working"
|
||||
EPISODIC = "episodic"
|
||||
SEMANTIC = "semantic"
|
||||
PROCEDURAL = "procedural"
|
||||
|
||||
|
||||
class ScopeLevel(str, Enum):
|
||||
"""Hierarchical scoping levels for memory."""
|
||||
|
||||
GLOBAL = "global"
|
||||
PROJECT = "project"
|
||||
AGENT_TYPE = "agent_type"
|
||||
AGENT_INSTANCE = "agent_instance"
|
||||
SESSION = "session"
|
||||
|
||||
|
||||
class Outcome(str, Enum):
|
||||
"""Outcome of a task or episode."""
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
PARTIAL = "partial"
|
||||
ABANDONED = "abandoned"
|
||||
|
||||
|
||||
class ConsolidationStatus(str, Enum):
|
||||
"""Status of a memory consolidation job."""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class ConsolidationType(str, Enum):
|
||||
"""Types of memory consolidation."""
|
||||
|
||||
WORKING_TO_EPISODIC = "working_to_episodic"
|
||||
EPISODIC_TO_SEMANTIC = "episodic_to_semantic"
|
||||
EPISODIC_TO_PROCEDURAL = "episodic_to_procedural"
|
||||
PRUNING = "pruning"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScopeContext:
|
||||
"""Represents a memory scope with its hierarchy."""
|
||||
|
||||
scope_type: ScopeLevel
|
||||
scope_id: str
|
||||
parent: "ScopeContext | None" = None
|
||||
|
||||
def get_hierarchy(self) -> list["ScopeContext"]:
|
||||
"""Get the full scope hierarchy from root to this scope."""
|
||||
hierarchy: list[ScopeContext] = []
|
||||
current: ScopeContext | None = self
|
||||
while current is not None:
|
||||
hierarchy.insert(0, current)
|
||||
current = current.parent
|
||||
return hierarchy
|
||||
|
||||
def to_key_prefix(self) -> str:
|
||||
"""Convert scope to a key prefix for storage."""
|
||||
return f"{self.scope_type.value}:{self.scope_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryItem:
|
||||
"""Base class for all memory items."""
|
||||
|
||||
id: UUID
|
||||
memory_type: MemoryType
|
||||
scope_type: ScopeLevel
|
||||
scope_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def get_age_seconds(self) -> float:
|
||||
"""Get the age of this memory item in seconds."""
|
||||
return (_utcnow() - self.created_at).total_seconds()
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkingMemoryItem:
|
||||
"""A key-value item in working memory."""
|
||||
|
||||
id: UUID
|
||||
scope_type: ScopeLevel
|
||||
scope_id: str
|
||||
key: str
|
||||
value: Any
|
||||
expires_at: datetime | None = None
|
||||
created_at: datetime = field(default_factory=_utcnow)
|
||||
updated_at: datetime = field(default_factory=_utcnow)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this item has expired."""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return _utcnow() > self.expires_at
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskState:
|
||||
"""Current state of a task in working memory."""
|
||||
|
||||
task_id: str
|
||||
task_type: str
|
||||
description: str
|
||||
status: str = "in_progress"
|
||||
current_step: int = 0
|
||||
total_steps: int = 0
|
||||
progress_percent: float = 0.0
|
||||
context: dict[str, Any] = field(default_factory=dict)
|
||||
started_at: datetime = field(default_factory=_utcnow)
|
||||
updated_at: datetime = field(default_factory=_utcnow)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Episode:
|
||||
"""An episodic memory - a recorded experience."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
agent_instance_id: UUID | None
|
||||
agent_type_id: UUID | None
|
||||
session_id: str
|
||||
task_type: str
|
||||
task_description: str
|
||||
actions: list[dict[str, Any]]
|
||||
context_summary: str
|
||||
outcome: Outcome
|
||||
outcome_details: str
|
||||
duration_seconds: float
|
||||
tokens_used: int
|
||||
lessons_learned: list[str]
|
||||
importance_score: float
|
||||
embedding: list[float] | None
|
||||
occurred_at: datetime
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeCreate:
|
||||
"""Data required to create a new episode."""
|
||||
|
||||
project_id: UUID
|
||||
session_id: str
|
||||
task_type: str
|
||||
task_description: str
|
||||
actions: list[dict[str, Any]]
|
||||
context_summary: str
|
||||
outcome: Outcome
|
||||
outcome_details: str
|
||||
duration_seconds: float
|
||||
tokens_used: int
|
||||
lessons_learned: list[str] = field(default_factory=list)
|
||||
importance_score: float = 0.5
|
||||
agent_instance_id: UUID | None = None
|
||||
agent_type_id: UUID | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fact:
|
||||
"""A semantic memory fact - a piece of knowledge."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID | None # None for global facts
|
||||
subject: str
|
||||
predicate: str
|
||||
object: str
|
||||
confidence: float
|
||||
source_episode_ids: list[UUID]
|
||||
first_learned: datetime
|
||||
last_reinforced: datetime
|
||||
reinforcement_count: int
|
||||
embedding: list[float] | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class FactCreate:
|
||||
"""Data required to create a new fact."""
|
||||
|
||||
subject: str
|
||||
predicate: str
|
||||
object: str
|
||||
confidence: float = 0.8
|
||||
project_id: UUID | None = None
|
||||
source_episode_ids: list[UUID] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Procedure:
|
||||
"""A procedural memory - a learned skill or procedure."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID | None
|
||||
agent_type_id: UUID | None
|
||||
name: str
|
||||
trigger_pattern: str
|
||||
steps: list[dict[str, Any]]
|
||||
success_count: int
|
||||
failure_count: int
|
||||
last_used: datetime | None
|
||||
embedding: list[float] | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""Calculate the success rate of this procedure."""
|
||||
total = self.success_count + self.failure_count
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return self.success_count / total
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcedureCreate:
|
||||
"""Data required to create a new procedure."""
|
||||
|
||||
name: str
|
||||
trigger_pattern: str
|
||||
steps: list[dict[str, Any]]
|
||||
project_id: UUID | None = None
|
||||
agent_type_id: UUID | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Step:
|
||||
"""A single step in a procedure."""
|
||||
|
||||
order: int
|
||||
action: str
|
||||
parameters: dict[str, Any] = field(default_factory=dict)
|
||||
expected_outcome: str = ""
|
||||
fallback_action: str | None = None
|
||||
|
||||
|
||||
class MemoryStore[T: MemoryItem](ABC):
|
||||
"""Abstract base class for memory storage backends."""
|
||||
|
||||
@abstractmethod
|
||||
async def store(self, item: T) -> T:
|
||||
"""Store a memory item."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, item_id: UUID) -> T | None:
|
||||
"""Get a memory item by ID."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, item_id: UUID) -> bool:
|
||||
"""Delete a memory item."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def list(
|
||||
self,
|
||||
scope_type: ScopeLevel | None = None,
|
||||
scope_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[T]:
|
||||
"""List memory items with optional scope filtering."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def count(
|
||||
self,
|
||||
scope_type: ScopeLevel | None = None,
|
||||
scope_id: str | None = None,
|
||||
) -> int:
|
||||
"""Count memory items with optional scope filtering."""
|
||||
...
|
||||
|
||||
|
||||
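Reviewer note: the generic `MemoryStore[T: MemoryItem]` interface above uses PEP 695 syntax (Python 3.12+) and leaves persistence to concrete backends. A minimal dict-backed sketch, purely illustrative and not part of this diff, to show what a conforming subclass looks like:

from uuid import UUID

from app.services.memory.types import MemoryItem, MemoryStore, ScopeLevel


class DictMemoryStore(MemoryStore[MemoryItem]):
    """Toy backend that keeps items in a dict keyed by id."""

    def __init__(self) -> None:
        self._items: dict[UUID, MemoryItem] = {}

    async def store(self, item: MemoryItem) -> MemoryItem:
        self._items[item.id] = item
        return item

    async def get(self, item_id: UUID) -> MemoryItem | None:
        return self._items.get(item_id)

    async def delete(self, item_id: UUID) -> bool:
        return self._items.pop(item_id, None) is not None

    async def list(
        self,
        scope_type: ScopeLevel | None = None,
        scope_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[MemoryItem]:
        # Filter by scope when requested, then apply simple pagination.
        items = [
            i
            for i in self._items.values()
            if (scope_type is None or i.scope_type == scope_type)
            and (scope_id is None or i.scope_id == scope_id)
        ]
        return items[offset : offset + limit]

    async def count(
        self,
        scope_type: ScopeLevel | None = None,
        scope_id: str | None = None,
    ) -> int:
        return len(await self.list(scope_type, scope_id, limit=len(self._items)))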
@dataclass
|
||||
class RetrievalResult[T]:
|
||||
"""Result of a memory retrieval operation."""
|
||||
|
||||
items: list[T]
|
||||
total_count: int
|
||||
query: str
|
||||
retrieval_type: str
|
||||
latency_ms: float
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryStats:
|
||||
"""Statistics about memory usage."""
|
||||
|
||||
memory_type: MemoryType
|
||||
scope_type: ScopeLevel | None
|
||||
scope_id: str | None
|
||||
item_count: int
|
||||
total_size_bytes: int
|
||||
oldest_item_age_seconds: float
|
||||
newest_item_age_seconds: float
|
||||
avg_item_size_bytes: float
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
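End of `types.py`. For orientation, a small hedged sketch of how the scope hierarchy composes (the identifiers are invented):

from app.services.memory.types import ScopeContext, ScopeLevel

project = ScopeContext(scope_type=ScopeLevel.PROJECT, scope_id="proj-123")
session = ScopeContext(
    scope_type=ScopeLevel.SESSION,
    scope_id="sess-abc",
    parent=project,
)

# get_hierarchy() walks parents and returns root-to-leaf order.
assert [s.scope_type for s in session.get_hierarchy()] == [
    ScopeLevel.PROJECT,
    ScopeLevel.SESSION,
]

# to_key_prefix() is what WorkingMemory.create() prepends with "wm:".
assert session.to_key_prefix() == "session:sess-abc"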
16 backend/app/services/memory/working/__init__.py Normal file
@@ -0,0 +1,16 @@
|
||||
# app/services/memory/working/__init__.py
|
||||
"""
|
||||
Working Memory Implementation.
|
||||
|
||||
Provides short-term memory storage with Redis as the primary backend and an in-memory fallback.
|
||||
"""
|
||||
|
||||
from .memory import WorkingMemory
|
||||
from .storage import InMemoryStorage, RedisStorage, WorkingMemoryStorage
|
||||
|
||||
__all__ = [
|
||||
"InMemoryStorage",
|
||||
"RedisStorage",
|
||||
"WorkingMemory",
|
||||
"WorkingMemoryStorage",
|
||||
]
|
||||
544 backend/app/services/memory/working/memory.py Normal file
@@ -0,0 +1,544 @@
|
||||
# app/services/memory/working/memory.py
|
||||
"""
|
||||
Working Memory Implementation.
|
||||
|
||||
Provides session-scoped ephemeral memory with:
|
||||
- Key-value storage with TTL
|
||||
- Task state tracking
|
||||
- Scratchpad for reasoning steps
|
||||
- Checkpoint/snapshot support
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.exceptions import (
|
||||
MemoryConnectionError,
|
||||
MemoryNotFoundError,
|
||||
)
|
||||
from app.services.memory.types import ScopeContext, ScopeLevel, TaskState
|
||||
|
||||
from .storage import InMemoryStorage, RedisStorage, WorkingMemoryStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Reserved key prefixes for internal use
|
||||
_TASK_STATE_KEY = "_task_state"
|
||||
_SCRATCHPAD_KEY = "_scratchpad"
|
||||
_CHECKPOINT_PREFIX = "_checkpoint:"
|
||||
_METADATA_KEY = "_metadata"
|
||||
|
||||
|
||||
class WorkingMemory:
|
||||
"""
|
||||
Session-scoped working memory.
|
||||
|
||||
Provides ephemeral storage for an agent's current task context:
|
||||
- Variables and intermediate data
|
||||
- Task state (current step, status, progress)
|
||||
- Scratchpad for reasoning steps
|
||||
- Checkpoints for recovery
|
||||
|
||||
Uses Redis as primary storage with in-memory fallback.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
storage: WorkingMemoryStorage,
|
||||
default_ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize working memory for a scope.
|
||||
|
||||
Args:
|
||||
scope: The scope context (session, agent instance, etc.)
|
||||
storage: Storage backend (use create() factory for auto-configuration)
|
||||
default_ttl_seconds: Default TTL for keys (None = no expiration)
|
||||
"""
|
||||
self._scope = scope
|
||||
self._storage: WorkingMemoryStorage = storage
|
||||
self._default_ttl = default_ttl_seconds
|
||||
self._using_fallback = False
|
||||
self._initialized = False
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
scope: ScopeContext,
|
||||
default_ttl_seconds: int | None = None,
|
||||
) -> "WorkingMemory":
|
||||
"""
|
||||
Factory method to create WorkingMemory with auto-configured storage.
|
||||
|
||||
Attempts Redis first, falls back to in-memory if unavailable.
|
||||
"""
|
||||
settings = get_memory_settings()
|
||||
key_prefix = f"wm:{scope.to_key_prefix()}:"
|
||||
storage: WorkingMemoryStorage
|
||||
|
||||
# Try Redis first
|
||||
if settings.working_memory_backend == "redis":
|
||||
redis_storage = RedisStorage(key_prefix=key_prefix)
|
||||
try:
|
||||
if await redis_storage.is_healthy():
|
||||
logger.debug(f"Using Redis storage for scope {scope.scope_id}")
|
||||
instance = cls(
|
||||
scope=scope,
|
||||
storage=redis_storage,
|
||||
default_ttl_seconds=default_ttl_seconds
|
||||
or settings.working_memory_default_ttl_seconds,
|
||||
)
|
||||
await instance._initialize()
|
||||
return instance
|
||||
except MemoryConnectionError:
|
||||
logger.warning("Redis unavailable, falling back to in-memory storage")
|
||||
await redis_storage.close()
|
||||
|
||||
# Fall back to in-memory
|
||||
storage = InMemoryStorage(
|
||||
max_keys=settings.working_memory_max_items_per_session
|
||||
)
|
||||
instance = cls(
|
||||
scope=scope,
|
||||
storage=storage,
|
||||
default_ttl_seconds=default_ttl_seconds
|
||||
or settings.working_memory_default_ttl_seconds,
|
||||
)
|
||||
instance._using_fallback = True
|
||||
await instance._initialize()
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
async def for_session(
|
||||
cls,
|
||||
session_id: str,
|
||||
project_id: str | None = None,
|
||||
agent_instance_id: str | None = None,
|
||||
) -> "WorkingMemory":
|
||||
"""
|
||||
Convenience factory for session-scoped working memory.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier
|
||||
project_id: Optional project context
|
||||
agent_instance_id: Optional agent instance context
|
||||
"""
|
||||
# Build scope hierarchy
|
||||
parent = None
|
||||
if project_id:
|
||||
parent = ScopeContext(
|
||||
scope_type=ScopeLevel.PROJECT,
|
||||
scope_id=project_id,
|
||||
)
|
||||
if agent_instance_id:
|
||||
parent = ScopeContext(
|
||||
scope_type=ScopeLevel.AGENT_INSTANCE,
|
||||
scope_id=agent_instance_id,
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
scope = ScopeContext(
|
||||
scope_type=ScopeLevel.SESSION,
|
||||
scope_id=session_id,
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
return await cls.create(scope=scope)
|
||||
|
||||
async def _initialize(self) -> None:
|
||||
"""Initialize working memory metadata."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"scope_type": self._scope.scope_type.value,
|
||||
"scope_id": self._scope.scope_id,
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"using_fallback": self._using_fallback,
|
||||
}
|
||||
await self._storage.set(_METADATA_KEY, metadata)
|
||||
self._initialized = True
|
||||
|
||||
@property
|
||||
def scope(self) -> ScopeContext:
|
||||
"""Get the scope context."""
|
||||
return self._scope
|
||||
|
||||
@property
|
||||
def is_using_fallback(self) -> bool:
|
||||
"""Check if using fallback in-memory storage."""
|
||||
return self._using_fallback
|
||||
|
||||
# =========================================================================
|
||||
# Basic Key-Value Operations
|
||||
# =========================================================================
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Store a value.
|
||||
|
||||
Args:
|
||||
key: The key to store under
|
||||
value: The value to store (must be JSON-serializable)
|
||||
ttl_seconds: Optional TTL (uses default if not specified)
|
||||
"""
|
||||
if key.startswith("_"):
|
||||
raise ValueError("Keys starting with '_' are reserved for internal use")
|
||||
|
||||
ttl = ttl_seconds if ttl_seconds is not None else self._default_ttl
|
||||
await self._storage.set(key, value, ttl)
|
||||
|
||||
async def get(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get a value.
|
||||
|
||||
Args:
|
||||
key: The key to retrieve
|
||||
default: Default value if key not found
|
||||
|
||||
Returns:
|
||||
The stored value or default
|
||||
"""
|
||||
result = await self._storage.get(key)
|
||||
return result if result is not None else default
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a key.
|
||||
|
||||
Args:
|
||||
key: The key to delete
|
||||
|
||||
Returns:
|
||||
True if the key existed
|
||||
"""
|
||||
if key.startswith("_"):
|
||||
raise ValueError("Cannot delete internal keys directly")
|
||||
return await self._storage.delete(key)
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if a key exists.
|
||||
|
||||
Args:
|
||||
key: The key to check
|
||||
|
||||
Returns:
|
||||
True if the key exists
|
||||
"""
|
||||
return await self._storage.exists(key)
|
||||
|
||||
async def list_keys(self, pattern: str = "*") -> list[str]:
|
||||
"""
|
||||
List keys matching a pattern.
|
||||
|
||||
Args:
|
||||
pattern: Glob-style pattern (default "*" for all)
|
||||
|
||||
Returns:
|
||||
List of matching keys (excludes internal keys)
|
||||
"""
|
||||
all_keys = await self._storage.list_keys(pattern)
|
||||
return [k for k in all_keys if not k.startswith("_")]
|
||||
|
||||
async def get_all(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get all user key-value pairs.
|
||||
|
||||
Returns:
|
||||
Dictionary of all key-value pairs (excludes internal keys)
|
||||
"""
|
||||
all_data = await self._storage.get_all()
|
||||
return {k: v for k, v in all_data.items() if not k.startswith("_")}
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""
|
||||
Clear all user keys (preserves internal state).
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
"""
|
||||
# Save internal state
|
||||
task_state = await self._storage.get(_TASK_STATE_KEY)
|
||||
scratchpad = await self._storage.get(_SCRATCHPAD_KEY)
|
||||
metadata = await self._storage.get(_METADATA_KEY)
|
||||
|
||||
count = await self._storage.clear()
|
||||
|
||||
# Restore internal state
|
||||
if metadata is not None:
|
||||
await self._storage.set(_METADATA_KEY, metadata)
|
||||
if task_state is not None:
|
||||
await self._storage.set(_TASK_STATE_KEY, task_state)
|
||||
if scratchpad is not None:
|
||||
await self._storage.set(_SCRATCHPAD_KEY, scratchpad)
|
||||
|
||||
# Adjust count for preserved keys
|
||||
preserved = sum(1 for x in [task_state, scratchpad, metadata] if x is not None)
|
||||
return max(0, count - preserved)
|
||||
|
||||
# =========================================================================
|
||||
# Task State Operations
|
||||
# =========================================================================
|
||||
|
||||
async def set_task_state(self, state: TaskState) -> None:
|
||||
"""
|
||||
Set the current task state.
|
||||
|
||||
Args:
|
||||
state: The task state to store
|
||||
"""
|
||||
state.updated_at = datetime.now(UTC)
|
||||
await self._storage.set(_TASK_STATE_KEY, asdict(state))
|
||||
|
||||
async def get_task_state(self) -> TaskState | None:
|
||||
"""
|
||||
Get the current task state.
|
||||
|
||||
Returns:
|
||||
The current TaskState or None if not set
|
||||
"""
|
||||
data = await self._storage.get(_TASK_STATE_KEY)
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
# Convert datetime strings back to datetime objects
|
||||
if isinstance(data.get("started_at"), str):
|
||||
data["started_at"] = datetime.fromisoformat(data["started_at"])
|
||||
if isinstance(data.get("updated_at"), str):
|
||||
data["updated_at"] = datetime.fromisoformat(data["updated_at"])
|
||||
|
||||
return TaskState(**data)
|
||||
|
||||
async def update_task_progress(
|
||||
self,
|
||||
current_step: int | None = None,
|
||||
progress_percent: float | None = None,
|
||||
status: str | None = None,
|
||||
) -> TaskState | None:
|
||||
"""
|
||||
Update task progress fields.
|
||||
|
||||
Args:
|
||||
current_step: New current step number
|
||||
progress_percent: New progress percentage (0.0 to 100.0)
|
||||
status: New status string
|
||||
|
||||
Returns:
|
||||
Updated TaskState or None if no task state exists
|
||||
"""
|
||||
state = await self.get_task_state()
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
if current_step is not None:
|
||||
state.current_step = current_step
|
||||
if progress_percent is not None:
|
||||
state.progress_percent = min(100.0, max(0.0, progress_percent))
|
||||
if status is not None:
|
||||
state.status = status
|
||||
|
||||
await self.set_task_state(state)
|
||||
return state
|
||||
|
||||
# =========================================================================
|
||||
# Scratchpad Operations
|
||||
# =========================================================================
|
||||
|
||||
async def append_scratchpad(self, content: str) -> None:
|
||||
"""
|
||||
Append content to the scratchpad.
|
||||
|
||||
Args:
|
||||
content: Text to append
|
||||
"""
|
||||
settings = get_memory_settings()
|
||||
entries = await self._storage.get(_SCRATCHPAD_KEY) or []
|
||||
|
||||
# Check capacity
|
||||
if len(entries) >= settings.working_memory_max_items_per_session:
|
||||
# Remove oldest entries
|
||||
entries = entries[-(settings.working_memory_max_items_per_session - 1) :]
|
||||
|
||||
entry = {
|
||||
"content": content,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
entries.append(entry)
|
||||
await self._storage.set(_SCRATCHPAD_KEY, entries)
|
||||
|
||||
async def get_scratchpad(self) -> list[str]:
|
||||
"""
|
||||
Get all scratchpad entries.
|
||||
|
||||
Returns:
|
||||
List of scratchpad content strings (ordered by time)
|
||||
"""
|
||||
entries = await self._storage.get(_SCRATCHPAD_KEY) or []
|
||||
return [e["content"] for e in entries]
|
||||
|
||||
async def get_scratchpad_with_timestamps(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get all scratchpad entries with timestamps.
|
||||
|
||||
Returns:
|
||||
List of dicts with 'content' and 'timestamp' keys
|
||||
"""
|
||||
return await self._storage.get(_SCRATCHPAD_KEY) or []
|
||||
|
||||
async def clear_scratchpad(self) -> int:
|
||||
"""
|
||||
Clear the scratchpad.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
entries = await self._storage.get(_SCRATCHPAD_KEY) or []
|
||||
count = len(entries)
|
||||
await self._storage.set(_SCRATCHPAD_KEY, [])
|
||||
return count
|
||||
|
||||
# =========================================================================
|
||||
# Checkpoint Operations
|
||||
# =========================================================================
|
||||
|
||||
async def create_checkpoint(self, description: str = "") -> str:
|
||||
"""
|
||||
Create a checkpoint of current state.
|
||||
|
||||
Args:
|
||||
description: Optional description of the checkpoint
|
||||
|
||||
Returns:
|
||||
Checkpoint ID for later restoration
|
||||
"""
|
||||
# Use a full UUID: a truncated 8-character id reaches ~50% collision probability after only tens of thousands of checkpoints (birthday paradox)
|
||||
checkpoint_id = str(uuid.uuid4())
|
||||
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
|
||||
# Capture all current state
|
||||
all_data = await self._storage.get_all()
|
||||
|
||||
checkpoint = {
|
||||
"id": checkpoint_id,
|
||||
"description": description,
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"data": all_data,
|
||||
}
|
||||
|
||||
await self._storage.set(checkpoint_key, checkpoint)
|
||||
logger.debug(f"Created checkpoint {checkpoint_id}")
|
||||
return checkpoint_id
|
||||
|
||||
async def restore_checkpoint(self, checkpoint_id: str) -> None:
|
||||
"""
|
||||
Restore state from a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint to restore
|
||||
|
||||
Raises:
|
||||
MemoryNotFoundError: If checkpoint not found
|
||||
"""
|
||||
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
checkpoint = await self._storage.get(checkpoint_key)
|
||||
|
||||
if checkpoint is None:
|
||||
raise MemoryNotFoundError(f"Checkpoint {checkpoint_id} not found")
|
||||
|
||||
# Clear current state
|
||||
await self._storage.clear()
|
||||
|
||||
# Restore all data from checkpoint
|
||||
for key, value in checkpoint["data"].items():
|
||||
await self._storage.set(key, value)
|
||||
|
||||
# Keep the checkpoint itself
|
||||
await self._storage.set(checkpoint_key, checkpoint)
|
||||
|
||||
logger.debug(f"Restored checkpoint {checkpoint_id}")
|
||||
|
||||
async def list_checkpoints(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
List all available checkpoints.
|
||||
|
||||
Returns:
|
||||
List of checkpoint metadata (id, description, created_at)
|
||||
"""
|
||||
checkpoint_keys = await self._storage.list_keys(f"{_CHECKPOINT_PREFIX}*")
|
||||
checkpoints = []
|
||||
|
||||
for key in checkpoint_keys:
|
||||
data = await self._storage.get(key)
|
||||
if data:
|
||||
checkpoints.append(
|
||||
{
|
||||
"id": data["id"],
|
||||
"description": data["description"],
|
||||
"created_at": data["created_at"],
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by creation time
|
||||
checkpoints.sort(key=lambda x: x["created_at"])
|
||||
return checkpoints
|
||||
|
||||
async def delete_checkpoint(self, checkpoint_id: str) -> bool:
|
||||
"""
|
||||
Delete a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint to delete
|
||||
|
||||
Returns:
|
||||
True if checkpoint existed
|
||||
"""
|
||||
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
return await self._storage.delete(checkpoint_key)
|
||||
|
||||
# =========================================================================
|
||||
# Health and Lifecycle
|
||||
# =========================================================================
|
||||
|
||||
async def is_healthy(self) -> bool:
|
||||
"""Check if the working memory storage is healthy."""
|
||||
return await self._storage.is_healthy()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the working memory storage."""
|
||||
if self._storage:
|
||||
await self._storage.close()
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get working memory statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with stats about current state
|
||||
"""
|
||||
all_keys = await self._storage.list_keys("*")
|
||||
user_keys = [k for k in all_keys if not k.startswith("_")]
|
||||
checkpoint_keys = [k for k in all_keys if k.startswith(_CHECKPOINT_PREFIX)]
|
||||
scratchpad = await self._storage.get(_SCRATCHPAD_KEY) or []
|
||||
|
||||
return {
|
||||
"scope_type": self._scope.scope_type.value,
|
||||
"scope_id": self._scope.scope_id,
|
||||
"using_fallback": self._using_fallback,
|
||||
"total_keys": len(all_keys),
|
||||
"user_keys": len(user_keys),
|
||||
"checkpoint_count": len(checkpoint_keys),
|
||||
"scratchpad_entries": len(scratchpad),
|
||||
"has_task_state": await self._storage.exists(_TASK_STATE_KEY),
|
||||
}
|
||||
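End of `memory.py`. A hedged end-to-end sketch of the intended usage, assuming Redis is reachable (otherwise `create()` silently falls back to in-memory storage); the session id, project id, and keys are placeholders:

from app.services.memory.types import TaskState
from app.services.memory.working import WorkingMemory


async def demo_session() -> None:
    wm = await WorkingMemory.for_session(session_id="sess-abc", project_id="proj-123")
    try:
        # Plain key-value data (keys starting with "_" are reserved).
        await wm.set("target_branch", "main", ttl_seconds=3600)

        # Task progress tracking.
        await wm.set_task_state(
            TaskState(task_id="t1", task_type="refactor", description="Split module")
        )
        await wm.update_task_progress(current_step=2, progress_percent=40.0)

        # Reasoning trace plus a recovery point before a risky step.
        await wm.append_scratchpad("Chose strategy A over B")
        checkpoint_id = await wm.create_checkpoint("before risky step")

        # ...and roll back if that step failed.
        await wm.restore_checkpoint(checkpoint_id)
    finally:
        await wm.close()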
406 backend/app/services/memory/working/storage.py Normal file
@@ -0,0 +1,406 @@
|
||||
# app/services/memory/working/storage.py
|
||||
"""
|
||||
Working Memory Storage Backends.
|
||||
|
||||
Provides abstract storage interface and implementations:
|
||||
- RedisStorage: Primary storage using Redis with connection pooling
|
||||
- InMemoryStorage: Fallback storage when Redis is unavailable
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import fnmatch
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.exceptions import (
|
||||
MemoryConnectionError,
|
||||
MemoryStorageError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkingMemoryStorage(ABC):
|
||||
"""Abstract base class for working memory storage backends."""
|
||||
|
||||
@abstractmethod
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""Store a value with optional TTL."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> Any | None:
|
||||
"""Get a value by key, returns None if not found or expired."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete a key, returns True if existed."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""Check if a key exists and is not expired."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def list_keys(self, pattern: str = "*") -> list[str]:
|
||||
"""List all keys matching a pattern."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_all(self) -> dict[str, Any]:
|
||||
"""Get all key-value pairs."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> int:
|
||||
"""Clear all keys, returns count of deleted keys."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def is_healthy(self) -> bool:
|
||||
"""Check if the storage backend is healthy."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Close the storage connection."""
|
||||
...
|
||||
|
||||
|
||||
class InMemoryStorage(WorkingMemoryStorage):
|
||||
"""
|
||||
In-memory storage backend for working memory.
|
||||
|
||||
Used as fallback when Redis is unavailable. Data is not persisted
|
||||
across restarts and is not shared between processes.
|
||||
"""
|
||||
|
||||
def __init__(self, max_keys: int = 10000) -> None:
|
||||
"""Initialize in-memory storage."""
|
||||
self._data: dict[str, Any] = {}
|
||||
self._expirations: dict[str, datetime] = {}
|
||||
self._max_keys = max_keys
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def _is_expired(self, key: str) -> bool:
|
||||
"""Check if a key has expired."""
|
||||
if key not in self._expirations:
|
||||
return False
|
||||
return datetime.now(UTC) > self._expirations[key]
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove all expired keys."""
|
||||
now = datetime.now(UTC)
|
||||
expired_keys = [
|
||||
key for key, exp_time in self._expirations.items() if now > exp_time
|
||||
]
|
||||
for key in expired_keys:
|
||||
self._data.pop(key, None)
|
||||
self._expirations.pop(key, None)
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""Store a value with optional TTL."""
|
||||
async with self._lock:
|
||||
# Cleanup expired keys periodically
|
||||
if len(self._data) % 100 == 0:
|
||||
self._cleanup_expired()
|
||||
|
||||
# Check capacity
|
||||
if key not in self._data and len(self._data) >= self._max_keys:
|
||||
# Evict expired keys first
|
||||
self._cleanup_expired()
|
||||
if len(self._data) >= self._max_keys:
|
||||
raise MemoryStorageError(
|
||||
f"Working memory capacity exceeded: {self._max_keys} keys"
|
||||
)
|
||||
|
||||
self._data[key] = value
|
||||
if ttl_seconds is not None:
|
||||
self._expirations[key] = datetime.now(UTC) + timedelta(
|
||||
seconds=ttl_seconds
|
||||
)
|
||||
elif key in self._expirations:
|
||||
# Remove existing expiration if no TTL specified
|
||||
del self._expirations[key]
|
||||
|
||||
async def get(self, key: str) -> Any | None:
|
||||
"""Get a value by key."""
|
||||
async with self._lock:
|
||||
if key not in self._data:
|
||||
return None
|
||||
if self._is_expired(key):
|
||||
del self._data[key]
|
||||
del self._expirations[key]
|
||||
return None
|
||||
return self._data[key]
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete a key."""
|
||||
async with self._lock:
|
||||
existed = key in self._data
|
||||
self._data.pop(key, None)
|
||||
self._expirations.pop(key, None)
|
||||
return existed
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""Check if a key exists and is not expired."""
|
||||
async with self._lock:
|
||||
if key not in self._data:
|
||||
return False
|
||||
if self._is_expired(key):
|
||||
del self._data[key]
|
||||
del self._expirations[key]
|
||||
return False
|
||||
return True
|
||||
|
||||
async def list_keys(self, pattern: str = "*") -> list[str]:
|
||||
"""List all keys matching a pattern."""
|
||||
async with self._lock:
|
||||
self._cleanup_expired()
|
||||
if pattern == "*":
|
||||
return list(self._data.keys())
|
||||
return [key for key in self._data.keys() if fnmatch.fnmatch(key, pattern)]
|
||||
|
||||
async def get_all(self) -> dict[str, Any]:
|
||||
"""Get all key-value pairs."""
|
||||
async with self._lock:
|
||||
self._cleanup_expired()
|
||||
return dict(self._data)
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all keys."""
|
||||
async with self._lock:
|
||||
count = len(self._data)
|
||||
self._data.clear()
|
||||
self._expirations.clear()
|
||||
return count
|
||||
|
||||
async def is_healthy(self) -> bool:
|
||||
"""In-memory storage is always healthy."""
|
||||
return True
|
||||
|
||||
async def close(self) -> None:
|
||||
"""No cleanup needed for in-memory storage."""
|
||||
|
||||
|
||||
class RedisStorage(WorkingMemoryStorage):
|
||||
"""
|
||||
Redis storage backend for working memory.
|
||||
|
||||
Primary storage with connection pooling, automatic reconnection,
|
||||
and proper serialization of Python objects.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key_prefix: str = "",
|
||||
connection_timeout: float = 5.0,
|
||||
socket_timeout: float = 5.0,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize Redis storage.
|
||||
|
||||
Args:
|
||||
key_prefix: Prefix for all keys (e.g., "session:abc123:")
|
||||
connection_timeout: Timeout for establishing connections
|
||||
socket_timeout: Timeout for socket operations
|
||||
"""
|
||||
self._key_prefix = key_prefix
|
||||
self._connection_timeout = connection_timeout
|
||||
self._socket_timeout = socket_timeout
|
||||
self._redis: Any = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def _make_key(self, key: str) -> str:
|
||||
"""Add prefix to key."""
|
||||
return f"{self._key_prefix}{key}"
|
||||
|
||||
def _strip_prefix(self, key: str) -> str:
|
||||
"""Remove prefix from key."""
|
||||
if key.startswith(self._key_prefix):
|
||||
return key[len(self._key_prefix) :]
|
||||
return key
|
||||
|
||||
def _serialize(self, value: Any) -> str:
|
||||
"""Serialize a Python value to JSON string."""
|
||||
return json.dumps(value, default=str)
|
||||
|
||||
def _deserialize(self, data: str | bytes | None) -> Any | None:
|
||||
"""Deserialize a JSON string to Python value."""
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode("utf-8")
|
||||
return json.loads(data)
|
||||
|
||||
async def _get_client(self) -> Any:
|
||||
"""Get or create Redis client."""
|
||||
if self._redis is not None:
|
||||
return self._redis
|
||||
|
||||
async with self._lock:
|
||||
if self._redis is not None:
|
||||
return self._redis
|
||||
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
except ImportError as e:
|
||||
raise MemoryConnectionError(
|
||||
"redis package not installed. Install with: pip install redis"
|
||||
) from e
|
||||
|
||||
settings = get_memory_settings()
|
||||
redis_url = settings.redis_url
|
||||
|
||||
try:
|
||||
self._redis = await aioredis.from_url(
|
||||
redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=self._connection_timeout,
|
||||
socket_timeout=self._socket_timeout,
|
||||
)
|
||||
# Test connection
|
||||
await self._redis.ping()
|
||||
logger.info("Connected to Redis for working memory")
|
||||
return self._redis
|
||||
except Exception as e:
|
||||
self._redis = None
|
||||
raise MemoryConnectionError(f"Failed to connect to Redis: {e}") from e
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""Store a value with optional TTL."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_key = self._make_key(key)
|
||||
serialized = self._serialize(value)
|
||||
|
||||
if ttl_seconds is not None:
|
||||
await client.setex(full_key, ttl_seconds, serialized)
|
||||
else:
|
||||
await client.set(full_key, serialized)
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to set key {key}: {e}") from e
|
||||
|
||||
async def get(self, key: str) -> Any | None:
|
||||
"""Get a value by key."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_key = self._make_key(key)
|
||||
data = await client.get(full_key)
|
||||
return self._deserialize(data)
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to get key {key}: {e}") from e
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete a key."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_key = self._make_key(key)
|
||||
result = await client.delete(full_key)
|
||||
return bool(result)
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to delete key {key}: {e}") from e
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""Check if a key exists."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_key = self._make_key(key)
|
||||
result = await client.exists(full_key)
|
||||
return bool(result)
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to check key {key}: {e}") from e
|
||||
|
||||
async def list_keys(self, pattern: str = "*") -> list[str]:
|
||||
"""List all keys matching a pattern."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_pattern = self._make_key(pattern)
|
||||
keys = await client.keys(full_pattern)
|
||||
return [self._strip_prefix(key) for key in keys]
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to list keys: {e}") from e
|
||||
|
||||
async def get_all(self) -> dict[str, Any]:
|
||||
"""Get all key-value pairs."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_pattern = self._make_key("*")
|
||||
keys = await client.keys(full_pattern)
|
||||
|
||||
if not keys:
|
||||
return {}
|
||||
|
||||
values = await client.mget(*keys)
|
||||
result = {}
|
||||
for key, value in zip(keys, values, strict=False):
|
||||
stripped_key = self._strip_prefix(key)
|
||||
result[stripped_key] = self._deserialize(value)
|
||||
return result
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to get all keys: {e}") from e
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all keys with this prefix."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_pattern = self._make_key("*")
|
||||
keys = await client.keys(full_pattern)
|
||||
|
||||
if not keys:
|
||||
return 0
|
||||
|
||||
return await client.delete(*keys)
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to clear keys: {e}") from e
|
||||
|
||||
async def is_healthy(self) -> bool:
|
||||
"""Check if Redis connection is healthy."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
await client.ping()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the Redis connection."""
|
||||
if self._redis is not None:
|
||||
await self._redis.close()
|
||||
self._redis = None
|
||||
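End of `storage.py`. A small sketch exercising the fallback backend directly, mainly to show the TTL semantics (in memory the values round-trip untouched, whereas `RedisStorage` JSON-serializes them with `default=str`):

import asyncio

from app.services.memory.working.storage import InMemoryStorage


async def ttl_demo() -> None:
    storage = InMemoryStorage(max_keys=100)

    await storage.set("short_lived", {"attempt": 1}, ttl_seconds=1)
    assert await storage.exists("short_lived")

    await asyncio.sleep(1.1)
    # Expired entries are treated as missing and cleaned up lazily on access.
    assert await storage.get("short_lived") is None

    await storage.set("persistent", [1, 2, 3])  # no TTL: never expires
    assert await storage.list_keys("persist*") == ["persistent"]


asyncio.run(ttl_demo())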
backend/app/tasks/__init__.py
@@ -10,14 +10,16 @@ Modules:
|
||||
sync: Issue synchronization tasks (incremental/full sync, webhooks)
|
||||
workflow: Workflow state management tasks
|
||||
cost: Cost tracking and budget monitoring tasks
|
||||
memory_consolidation: Memory consolidation tasks
|
||||
"""
|
||||
|
||||
from app.tasks import agent, cost, git, sync, workflow
|
||||
from app.tasks import agent, cost, git, memory_consolidation, sync, workflow
|
||||
|
||||
__all__ = [
|
||||
"agent",
|
||||
"cost",
|
||||
"git",
|
||||
"memory_consolidation",
|
||||
"sync",
|
||||
"workflow",
|
||||
]
|
||||
|
||||
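Importing `memory_consolidation` in the package above helps ensure the Celery worker registers its tasks (the `@celery_app.task` decorators run at import time). A hedged sketch of dispatching one of them from application code; the IDs are placeholders:

from app.tasks.memory_consolidation import consolidate_session

# Fire-and-forget: a worker picks this up and retries up to 3 times with backoff.
async_result = consolidate_session.delay(
    project_id="6f1a2d3e-0000-0000-0000-000000000000",
    session_id="sess-abc",
    task_type="code_review",
)
print(async_result.id)  # Celery task id, useful for later status lookups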
234 backend/app/tasks/memory_consolidation.py Normal file
@@ -0,0 +1,234 @@
|
||||
# app/tasks/memory_consolidation.py
|
||||
"""
|
||||
Memory consolidation Celery tasks.
|
||||
|
||||
Handles scheduled and on-demand memory consolidation:
|
||||
- Session consolidation (on session end)
|
||||
- Nightly consolidation (scheduled)
|
||||
- On-demand project consolidation
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.consolidate_session",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def consolidate_session(
|
||||
self,
|
||||
project_id: str,
|
||||
session_id: str,
|
||||
task_type: str = "session_task",
|
||||
agent_instance_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Consolidate a session's working memory to episodic memory.
|
||||
|
||||
This task is triggered when an agent session ends to transfer
|
||||
relevant session data into persistent episodic memory.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
session_id: Session identifier
|
||||
task_type: Type of task performed
|
||||
agent_instance_id: Optional agent instance UUID
|
||||
agent_type_id: Optional agent type UUID
|
||||
|
||||
Returns:
|
||||
dict with consolidation results
|
||||
"""
|
||||
logger.info(f"Consolidating session {session_id} for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# This will involve:
|
||||
# 1. Getting database session from async context
|
||||
# 2. Loading working memory for session
|
||||
# 3. Calling consolidation service
|
||||
# 4. Returning results
|
||||
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"session_id": session_id,
|
||||
"episode_created": False,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.run_nightly_consolidation",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def run_nightly_consolidation(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_type_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run nightly memory consolidation for a project.
|
||||
|
||||
This task performs the full consolidation workflow:
|
||||
1. Extract facts from recent episodes to semantic memory
|
||||
2. Learn procedures from successful episode patterns
|
||||
3. Prune old, low-value memories
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project to consolidate
|
||||
agent_type_id: Optional agent type to filter by
|
||||
|
||||
Returns:
|
||||
dict with consolidation results
|
||||
"""
|
||||
logger.info(f"Running nightly consolidation for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# This will involve:
|
||||
# 1. Getting database session from async context
|
||||
# 2. Creating consolidation service instance
|
||||
# 3. Running run_nightly_consolidation
|
||||
# 4. Returning results
|
||||
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"total_facts_created": 0,
|
||||
"total_procedures_created": 0,
|
||||
"total_pruned": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.consolidate_episodes_to_facts",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def consolidate_episodes_to_facts(
|
||||
self,
|
||||
project_id: str,
|
||||
since_hours: int = 24,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Extract facts from episodic memories.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
since_hours: Process episodes from last N hours
|
||||
limit: Maximum episodes to process
|
||||
|
||||
Returns:
|
||||
dict with extraction results
|
||||
"""
|
||||
logger.info(f"Consolidating episodes to facts for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"items_processed": 0,
|
||||
"items_created": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.consolidate_episodes_to_procedures",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def consolidate_episodes_to_procedures(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_type_id: str | None = None,
|
||||
since_days: int = 7,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Learn procedures from episodic patterns.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
agent_type_id: Optional agent type filter
|
||||
since_days: Process episodes from last N days
|
||||
|
||||
Returns:
|
||||
dict with procedure learning results
|
||||
"""
|
||||
logger.info(f"Consolidating episodes to procedures for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"items_processed": 0,
|
||||
"items_created": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.prune_old_memories",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def prune_old_memories(
|
||||
self,
|
||||
project_id: str,
|
||||
max_age_days: int = 90,
|
||||
min_importance: float = 0.2,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Prune old, low-value memories.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
max_age_days: Maximum age in days
|
||||
min_importance: Minimum importance to keep
|
||||
|
||||
Returns:
|
||||
dict with pruning results
|
||||
"""
|
||||
logger.info(f"Pruning old memories for project {project_id}")
|
||||
|
||||
# TODO: Implement actual pruning
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"items_pruned": 0,
|
||||
}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Celery Beat Schedule Configuration
|
||||
# =========================================================================
|
||||
|
||||
# This would typically be configured in celery_app.py or a separate config file
|
||||
# Example schedule for nightly consolidation:
|
||||
#
|
||||
# app.conf.beat_schedule = {
|
||||
# 'nightly-memory-consolidation': {
|
||||
# 'task': 'app.tasks.memory_consolidation.run_nightly_consolidation',
|
||||
# 'schedule': crontab(hour=2, minute=0), # 2 AM daily
|
||||
# 'args': (None,), # Will process all projects
|
||||
# },
|
||||
# }
|
||||
1118
backend/data/default_agent_types.json
Normal file
File diff suppressed because it is too large
879
backend/data/demo_data.json
Normal file
@@ -0,0 +1,879 @@
|
||||
{
|
||||
"organizations": [
|
||||
{
|
||||
"name": "Acme Corp",
|
||||
"slug": "acme-corp",
|
||||
"description": "A leading provider of coyote-catching equipment."
|
||||
},
|
||||
{
|
||||
"name": "Globex Corporation",
|
||||
"slug": "globex",
|
||||
"description": "We own the East Coast."
|
||||
},
|
||||
{
|
||||
"name": "Soylent Corp",
|
||||
"slug": "soylent",
|
||||
"description": "Making food for the future."
|
||||
},
|
||||
{
|
||||
"name": "Initech",
|
||||
"slug": "initech",
|
||||
"description": "Software for the soul."
|
||||
},
|
||||
{
|
||||
"name": "Umbrella Corporation",
|
||||
"slug": "umbrella",
|
||||
"description": "Our business is life itself."
|
||||
},
|
||||
{
|
||||
"name": "Massive Dynamic",
|
||||
"slug": "massive-dynamic",
|
||||
"description": "What don't we do?"
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"email": "demo@example.com",
|
||||
"password": "DemoPass1234!",
|
||||
"first_name": "Demo",
|
||||
"last_name": "User",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "alice@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Alice",
|
||||
"last_name": "Smith",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "bob@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Bob",
|
||||
"last_name": "Jones",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "charlie@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Charlie",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "diana@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Diana",
|
||||
"last_name": "Prince",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "carol@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Carol",
|
||||
"last_name": "Williams",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dan@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dan",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ellen@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ellen",
|
||||
"last_name": "Ripley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "fred@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Fred",
|
||||
"last_name": "Flintstone",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dave@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dave",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "gina@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Gina",
|
||||
"last_name": "Torres",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "harry@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Harry",
|
||||
"last_name": "Potter",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "eve@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Eve",
|
||||
"last_name": "Davis",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "iris@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Iris",
|
||||
"last_name": "West",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "jack@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Jack",
|
||||
"last_name": "Sparrow",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "frank@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Frank",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "george@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "George",
|
||||
"last_name": "Costanza",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "kate@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Kate",
|
||||
"last_name": "Bishop",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "leo@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Leo",
|
||||
"last_name": "Messi",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "mary@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Mary",
|
||||
"last_name": "Jane",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "nathan@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Nathan",
|
||||
"last_name": "Drake",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "olivia@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Olivia",
|
||||
"last_name": "Dunham",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "peter@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Peter",
|
||||
"last_name": "Parker",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "quinn@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Quinn",
|
||||
"last_name": "Mallory",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "grace@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Grace",
|
||||
"last_name": "Hopper",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "heidi@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Heidi",
|
||||
"last_name": "Klum",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ivan@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ivan",
|
||||
"last_name": "Drago",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "rachel@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Rachel",
|
||||
"last_name": "Green",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "sam@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Sam",
|
||||
"last_name": "Wilson",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "tony@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Tony",
|
||||
"last_name": "Stark",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "una@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Una",
|
||||
"last_name": "Chin-Riley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "victor@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Victor",
|
||||
"last_name": "Von Doom",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "wanda@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Wanda",
|
||||
"last_name": "Maximoff",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
}
|
||||
],
|
||||
"projects": [
|
||||
{
|
||||
"name": "E-Commerce Platform Redesign",
|
||||
"slug": "ecommerce-redesign",
|
||||
"description": "Complete redesign of the e-commerce platform with modern UX, improved checkout flow, and mobile-first approach.",
|
||||
"owner_email": "__admin__",
|
||||
"autonomy_level": "milestone",
|
||||
"status": "active",
|
||||
"complexity": "complex",
|
||||
"client_mode": "technical",
|
||||
"settings": {
|
||||
"mcp_servers": ["gitea", "knowledge-base"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Mobile Banking App",
|
||||
"slug": "mobile-banking",
|
||||
"description": "Secure mobile banking application with biometric authentication, transaction history, and real-time notifications.",
|
||||
"owner_email": "__admin__",
|
||||
"autonomy_level": "full_control",
|
||||
"status": "active",
|
||||
"complexity": "complex",
|
||||
"client_mode": "technical",
|
||||
"settings": {
|
||||
"mcp_servers": ["gitea", "knowledge-base"],
|
||||
"security_level": "high"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Internal HR Portal",
|
||||
"slug": "hr-portal",
|
||||
"description": "Employee self-service portal for leave requests, performance reviews, and document management.",
|
||||
"owner_email": "__admin__",
|
||||
"autonomy_level": "autonomous",
|
||||
"status": "active",
|
||||
"complexity": "medium",
|
||||
"client_mode": "auto",
|
||||
"settings": {
|
||||
"mcp_servers": ["gitea", "knowledge-base"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "API Gateway Modernization",
|
||||
"slug": "api-gateway",
|
||||
"description": "Migrate legacy REST API gateway to modern GraphQL-based architecture with improved caching and rate limiting.",
|
||||
"owner_email": "__admin__",
|
||||
"autonomy_level": "milestone",
|
||||
"status": "active",
|
||||
"complexity": "complex",
|
||||
"client_mode": "technical",
|
||||
"settings": {
|
||||
"mcp_servers": ["gitea", "knowledge-base"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Customer Analytics Dashboard",
|
||||
"slug": "analytics-dashboard",
|
||||
"description": "Real-time analytics dashboard for customer behavior insights, cohort analysis, and predictive modeling.",
|
||||
"owner_email": "__admin__",
|
||||
"autonomy_level": "autonomous",
|
||||
"status": "completed",
|
||||
"complexity": "medium",
|
||||
"client_mode": "auto",
|
||||
"settings": {
|
||||
"mcp_servers": ["gitea", "knowledge-base"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "DevOps Pipeline Automation",
|
||||
"slug": "devops-automation",
|
||||
"description": "Automate CI/CD pipelines with AI-assisted deployments, rollback detection, and infrastructure as code.",
|
||||
"owner_email": "__admin__",
|
||||
"autonomy_level": "full_control",
|
||||
"status": "active",
|
||||
"complexity": "complex",
|
||||
"client_mode": "technical",
|
||||
"settings": {
|
||||
"mcp_servers": ["gitea", "knowledge-base"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"sprints": [
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"name": "Sprint 1: Foundation",
|
||||
"number": 1,
|
||||
"goal": "Set up project infrastructure, design system, and core navigation components.",
|
||||
"start_date": "2026-01-06",
|
||||
"end_date": "2026-01-20",
|
||||
"status": "active",
|
||||
"planned_points": 21
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"name": "Sprint 2: Product Catalog",
|
||||
"number": 2,
|
||||
"goal": "Implement product listing, filtering, search, and detail pages.",
|
||||
"start_date": "2026-01-20",
|
||||
"end_date": "2026-02-03",
|
||||
"status": "planned",
|
||||
"planned_points": 34
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"name": "Sprint 1: Authentication",
|
||||
"number": 1,
|
||||
"goal": "Implement secure login, biometric authentication, and session management.",
|
||||
"start_date": "2026-01-06",
|
||||
"end_date": "2026-01-20",
|
||||
"status": "active",
|
||||
"planned_points": 26
|
||||
},
|
||||
{
|
||||
"project_slug": "hr-portal",
|
||||
"name": "Sprint 1: Core Features",
|
||||
"number": 1,
|
||||
"goal": "Build employee dashboard, leave request system, and basic document management.",
|
||||
"start_date": "2026-01-06",
|
||||
"end_date": "2026-01-20",
|
||||
"status": "active",
|
||||
"planned_points": 18
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"name": "Sprint 1: GraphQL Schema",
|
||||
"number": 1,
|
||||
"goal": "Define GraphQL schema and implement core resolvers for existing REST endpoints.",
|
||||
"start_date": "2025-12-23",
|
||||
"end_date": "2026-01-06",
|
||||
"status": "completed",
|
||||
"planned_points": 21
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"name": "Sprint 2: Caching Layer",
|
||||
"number": 2,
|
||||
"goal": "Implement Redis-based caching layer and query batching.",
|
||||
"start_date": "2026-01-06",
|
||||
"end_date": "2026-01-20",
|
||||
"status": "active",
|
||||
"planned_points": 26
|
||||
},
|
||||
{
|
||||
"project_slug": "analytics-dashboard",
|
||||
"name": "Sprint 1: Data Pipeline",
|
||||
"number": 1,
|
||||
"goal": "Set up data ingestion pipeline and real-time event processing.",
|
||||
"start_date": "2025-11-15",
|
||||
"end_date": "2025-11-29",
|
||||
"status": "completed",
|
||||
"planned_points": 18
|
||||
},
|
||||
{
|
||||
"project_slug": "analytics-dashboard",
|
||||
"name": "Sprint 2: Dashboard UI",
|
||||
"number": 2,
|
||||
"goal": "Build interactive dashboard with charts and filtering capabilities.",
|
||||
"start_date": "2025-11-29",
|
||||
"end_date": "2025-12-13",
|
||||
"status": "completed",
|
||||
"planned_points": 21
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"name": "Sprint 1: Pipeline Templates",
|
||||
"number": 1,
|
||||
"goal": "Create reusable CI/CD pipeline templates for common deployment patterns.",
|
||||
"start_date": "2026-01-06",
|
||||
"end_date": "2026-01-20",
|
||||
"status": "active",
|
||||
"planned_points": 24
|
||||
}
|
||||
],
|
||||
"agent_instances": [
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"agent_type_slug": "product-owner",
|
||||
"name": "Aria",
|
||||
"status": "idle"
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"agent_type_slug": "solutions-architect",
|
||||
"name": "Marcus",
|
||||
"status": "idle"
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"agent_type_slug": "senior-engineer",
|
||||
"name": "Zara",
|
||||
"status": "working",
|
||||
"current_task": "Implementing responsive navigation component"
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"agent_type_slug": "product-owner",
|
||||
"name": "Felix",
|
||||
"status": "waiting",
|
||||
"current_task": "Awaiting security requirements clarification"
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"agent_type_slug": "senior-engineer",
|
||||
"name": "Luna",
|
||||
"status": "working",
|
||||
"current_task": "Implementing biometric authentication flow"
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"agent_type_slug": "qa-engineer",
|
||||
"name": "Rex",
|
||||
"status": "idle"
|
||||
},
|
||||
{
|
||||
"project_slug": "hr-portal",
|
||||
"agent_type_slug": "business-analyst",
|
||||
"name": "Nova",
|
||||
"status": "working",
|
||||
"current_task": "Documenting leave request workflow"
|
||||
},
|
||||
{
|
||||
"project_slug": "hr-portal",
|
||||
"agent_type_slug": "senior-engineer",
|
||||
"name": "Atlas",
|
||||
"status": "working",
|
||||
"current_task": "Building employee dashboard API"
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"agent_type_slug": "solutions-architect",
|
||||
"name": "Orion",
|
||||
"status": "working",
|
||||
"current_task": "Designing caching strategy for GraphQL queries"
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"agent_type_slug": "senior-engineer",
|
||||
"name": "Cleo",
|
||||
"status": "working",
|
||||
"current_task": "Implementing Redis cache invalidation"
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"agent_type_slug": "devops-engineer",
|
||||
"name": "Volt",
|
||||
"status": "working",
|
||||
"current_task": "Creating Terraform modules for AWS ECS"
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"agent_type_slug": "senior-engineer",
|
||||
"name": "Sage",
|
||||
"status": "idle"
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"agent_type_slug": "qa-engineer",
|
||||
"name": "Echo",
|
||||
"status": "waiting",
|
||||
"current_task": "Waiting for pipeline templates to test"
|
||||
}
|
||||
],
|
||||
"issues": [
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"sprint_number": 1,
|
||||
"type": "story",
|
||||
"title": "Design responsive navigation component",
|
||||
"body": "As a user, I want a navigation menu that works seamlessly on both desktop and mobile devices.\n\n## Acceptance Criteria\n- Hamburger menu on mobile viewports\n- Sticky header on scroll\n- Keyboard accessible\n- Screen reader compatible",
|
||||
"status": "in_progress",
|
||||
"priority": "high",
|
||||
"labels": ["frontend", "design-system"],
|
||||
"story_points": 5,
|
||||
"assigned_agent_name": "Zara"
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Set up Tailwind CSS configuration",
|
||||
"body": "Configure Tailwind CSS with custom design tokens for the e-commerce platform.\n\n- Define color palette\n- Set up typography scale\n- Configure breakpoints\n- Add custom utilities",
|
||||
"status": "closed",
|
||||
"priority": "high",
|
||||
"labels": ["frontend", "infrastructure"],
|
||||
"story_points": 3
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Create base component library structure",
|
||||
"body": "Set up the foundational component library with:\n- Button variants\n- Form inputs\n- Card component\n- Modal system",
|
||||
"status": "open",
|
||||
"priority": "medium",
|
||||
"labels": ["frontend", "design-system"],
|
||||
"story_points": 8
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"sprint_number": 1,
|
||||
"type": "story",
|
||||
"title": "Implement user authentication flow",
|
||||
"body": "As a user, I want to sign up, log in, and manage my account.\n\n## Features\n- Email/password registration\n- Social login (Google, GitHub)\n- Password reset flow\n- Email verification",
|
||||
"status": "open",
|
||||
"priority": "critical",
|
||||
"labels": ["auth", "backend", "frontend"],
|
||||
"story_points": 13
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"sprint_number": 2,
|
||||
"type": "epic",
|
||||
"title": "Product Catalog System",
|
||||
"body": "Complete product catalog implementation including:\n- Product listing with pagination\n- Advanced filtering and search\n- Product detail pages\n- Category navigation",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["catalog", "backend", "frontend"],
|
||||
"story_points": null
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"sprint_number": 1,
|
||||
"type": "story",
|
||||
"title": "Implement biometric authentication",
|
||||
"body": "As a user, I want to log in using Face ID or Touch ID for quick and secure access.\n\n## Requirements\n- Support Face ID on iOS\n- Support fingerprint on Android\n- Fallback to PIN/password\n- Secure keychain storage",
|
||||
"status": "in_progress",
|
||||
"priority": "critical",
|
||||
"labels": ["auth", "security", "mobile"],
|
||||
"story_points": 8,
|
||||
"assigned_agent_name": "Luna"
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Set up secure session management",
|
||||
"body": "Implement secure session handling with:\n- JWT tokens with short expiry\n- Refresh token rotation\n- Session timeout handling\n- Multi-device session management",
|
||||
"status": "open",
|
||||
"priority": "critical",
|
||||
"labels": ["auth", "security", "backend"],
|
||||
"story_points": 5
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"sprint_number": 1,
|
||||
"type": "bug",
|
||||
"title": "Fix token refresh race condition",
|
||||
"body": "When multiple API calls happen simultaneously after token expiry, multiple refresh requests are made causing 401 errors.\n\n## Steps to Reproduce\n1. Wait for token to expire\n2. Trigger multiple API calls at once\n3. Observe multiple 401 errors",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["bug", "auth", "backend"],
|
||||
"story_points": 3
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Implement PIN entry screen",
|
||||
"body": "Create secure PIN entry component with:\n- Masked input display\n- Haptic feedback\n- Brute force protection (lockout after 5 attempts)\n- Secure PIN storage",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["auth", "mobile", "frontend"],
|
||||
"story_points": 5
|
||||
},
|
||||
{
|
||||
"project_slug": "hr-portal",
|
||||
"sprint_number": 1,
|
||||
"type": "story",
|
||||
"title": "Build employee dashboard",
|
||||
"body": "As an employee, I want a dashboard showing my key information at a glance.\n\n## Dashboard Widgets\n- Leave balance\n- Pending approvals\n- Upcoming holidays\n- Recent announcements",
|
||||
"status": "in_progress",
|
||||
"priority": "high",
|
||||
"labels": ["frontend", "dashboard"],
|
||||
"story_points": 5,
|
||||
"assigned_agent_name": "Atlas"
|
||||
},
|
||||
{
|
||||
"project_slug": "hr-portal",
|
||||
"sprint_number": 1,
|
||||
"type": "story",
|
||||
"title": "Implement leave request system",
|
||||
"body": "As an employee, I want to submit and track leave requests.\n\n## Features\n- Submit leave request with date range\n- View leave balance by type\n- Track request status\n- Manager approval workflow",
|
||||
"status": "in_progress",
|
||||
"priority": "high",
|
||||
"labels": ["backend", "frontend", "workflow"],
|
||||
"story_points": 8,
|
||||
"assigned_agent_name": "Nova"
|
||||
},
|
||||
{
|
||||
"project_slug": "hr-portal",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Set up document storage integration",
|
||||
"body": "Integrate with S3-compatible storage for employee documents:\n- Secure upload/download\n- File type validation\n- Size limits\n- Virus scanning",
|
||||
"status": "open",
|
||||
"priority": "medium",
|
||||
"labels": ["backend", "infrastructure", "storage"],
|
||||
"story_points": 5
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"sprint_number": 2,
|
||||
"type": "story",
|
||||
"title": "Implement Redis caching layer",
|
||||
"body": "As an API consumer, I want responses to be cached for improved performance.\n\n## Requirements\n- Cache GraphQL query results\n- Configurable TTL per query type\n- Cache invalidation on mutations\n- Cache hit/miss metrics",
|
||||
"status": "in_progress",
|
||||
"priority": "critical",
|
||||
"labels": ["backend", "performance", "redis"],
|
||||
"story_points": 8,
|
||||
"assigned_agent_name": "Cleo"
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"sprint_number": 2,
|
||||
"type": "task",
|
||||
"title": "Set up query batching and deduplication",
|
||||
"body": "Implement DataLoader pattern for:\n- Batching multiple queries into single database calls\n- Deduplicating identical queries within request scope\n- N+1 query prevention",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["backend", "performance", "graphql"],
|
||||
"story_points": 5
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"sprint_number": 2,
|
||||
"type": "task",
|
||||
"title": "Implement rate limiting middleware",
|
||||
"body": "Add rate limiting to prevent API abuse:\n- Per-user rate limits\n- Per-IP fallback for anonymous requests\n- Sliding window algorithm\n- Custom limits per operation type",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["backend", "security", "middleware"],
|
||||
"story_points": 5,
|
||||
"assigned_agent_name": "Orion"
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"sprint_number": 2,
|
||||
"type": "bug",
|
||||
"title": "Fix N+1 query in user resolver",
|
||||
"body": "The user resolver is making separate database calls for each user's organization.\n\n## Steps to Reproduce\n1. Query users with organization field\n2. Check database logs\n3. Observe N+1 queries",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["bug", "performance", "graphql"],
|
||||
"story_points": 3
|
||||
},
|
||||
{
|
||||
"project_slug": "analytics-dashboard",
|
||||
"sprint_number": 2,
|
||||
"type": "story",
|
||||
"title": "Build cohort analysis charts",
|
||||
"body": "As a product manager, I want to analyze user cohorts over time.\n\n## Features\n- Weekly/monthly cohort grouping\n- Retention curve visualization\n- Cohort comparison view",
|
||||
"status": "closed",
|
||||
"priority": "high",
|
||||
"labels": ["frontend", "charts", "analytics"],
|
||||
"story_points": 8
|
||||
},
|
||||
{
|
||||
"project_slug": "analytics-dashboard",
|
||||
"sprint_number": 2,
|
||||
"type": "task",
|
||||
"title": "Implement real-time event streaming",
|
||||
"body": "Set up WebSocket connection for live event updates:\n- Event type filtering\n- Buffering for high-volume periods\n- Reconnection handling",
|
||||
"status": "closed",
|
||||
"priority": "high",
|
||||
"labels": ["backend", "websocket", "realtime"],
|
||||
"story_points": 5
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"sprint_number": 1,
|
||||
"type": "epic",
|
||||
"title": "CI/CD Pipeline Templates",
|
||||
"body": "Create reusable pipeline templates for common deployment patterns.\n\n## Templates Needed\n- Node.js applications\n- Python applications\n- Docker-based deployments\n- Kubernetes deployments",
|
||||
"status": "in_progress",
|
||||
"priority": "critical",
|
||||
"labels": ["infrastructure", "cicd", "templates"],
|
||||
"story_points": null
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"sprint_number": 1,
|
||||
"type": "story",
|
||||
"title": "Create Terraform modules for AWS ECS",
|
||||
"body": "As a DevOps engineer, I want Terraform modules for ECS deployments.\n\n## Modules\n- ECS cluster configuration\n- Service and task definitions\n- Load balancer integration\n- Auto-scaling policies",
|
||||
"status": "in_progress",
|
||||
"priority": "high",
|
||||
"labels": ["terraform", "aws", "ecs"],
|
||||
"story_points": 8,
|
||||
"assigned_agent_name": "Volt"
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Set up Gitea Actions runners",
|
||||
"body": "Configure self-hosted Gitea Actions runners:\n- Docker-in-Docker support\n- Caching for npm/pip\n- Secrets management\n- Resource limits",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["infrastructure", "gitea", "cicd"],
|
||||
"story_points": 5
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Implement rollback detection system",
|
||||
"body": "AI-assisted rollback detection:\n- Monitor deployment health metrics\n- Automatic rollback triggers\n- Notification system\n- Post-rollback analysis",
|
||||
"status": "open",
|
||||
"priority": "medium",
|
||||
"labels": ["ai", "monitoring", "automation"],
|
||||
"story_points": 8
|
||||
}
|
||||
]
|
||||
}
|
||||
507
backend/docs/MEMORY_SYSTEM.md
Normal file
@@ -0,0 +1,507 @@
|
||||
# Agent Memory System
|
||||
|
||||
Comprehensive multi-tier cognitive memory for AI agents, enabling state persistence, experiential learning, and context continuity across sessions.
|
||||
|
||||
## Overview
|
||||
|
||||
The Agent Memory System implements a cognitive architecture inspired by human memory:
|
||||
|
||||
```
|
||||
+------------------------------------------------------------------+
|
||||
| Agent Memory System |
|
||||
+------------------------------------------------------------------+
|
||||
| |
|
||||
| +------------------+ +------------------+ |
|
||||
| | Working Memory |----consolidate---->| Episodic Memory | |
|
||||
| | (Redis/In-Mem) | | (PostgreSQL) | |
|
||||
| | | | | |
|
||||
| | - Current task | | - Past sessions | |
|
||||
| | - Variables | | - Experiences | |
|
||||
| | - Scratchpad | | - Outcomes | |
|
||||
| +------------------+ +--------+---------+ |
|
||||
| | |
|
||||
| extract | |
|
||||
| v |
|
||||
| +------------------+ +------------------+ |
|
||||
| |Procedural Memory |<-----learn from----| Semantic Memory | |
|
||||
| | (PostgreSQL) | | (PostgreSQL + | |
|
||||
| | | | pgvector) | |
|
||||
| | - Procedures | | | |
|
||||
| | - Skills | | - Facts | |
|
||||
| | - Patterns | | - Entities | |
|
||||
| +------------------+ | - Relationships | |
|
||||
| +------------------+ |
|
||||
+------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
## Memory Types
|
||||
|
||||
### Working Memory
|
||||
Short-term, session-scoped memory for current task state.
|
||||
|
||||
**Features:**
|
||||
- Key-value storage with TTL
|
||||
- Task state tracking
|
||||
- Scratchpad for reasoning
|
||||
- Checkpoint/restore support
|
||||
- Redis primary with in-memory fallback
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from app.services.memory.working import WorkingMemory
|
||||
|
||||
memory = WorkingMemory(scope_context)
|
||||
await memory.set("key", {"data": "value"}, ttl_seconds=3600)
|
||||
value = await memory.get("key")
|
||||
|
||||
# Task state
|
||||
await memory.set_task_state(TaskState(task_id="t1", status="running"))
|
||||
state = await memory.get_task_state()
|
||||
|
||||
# Checkpoints
|
||||
checkpoint_id = await memory.create_checkpoint()
|
||||
await memory.restore_checkpoint(checkpoint_id)
|
||||
```
|
||||
|
||||
### Episodic Memory
|
||||
Experiential records of past agent actions and outcomes.
|
||||
|
||||
**Features:**
|
||||
- Records task completions and failures
|
||||
- Semantic similarity search (pgvector)
|
||||
- Temporal and outcome-based retrieval
|
||||
- Importance scoring
|
||||
- Episode summarization
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from app.services.memory.episodic import EpisodicMemory
|
||||
|
||||
memory = EpisodicMemory(session, embedder)
|
||||
|
||||
# Record an episode
|
||||
episode = await memory.record_episode(
|
||||
project_id=project_id,
|
||||
episode=EpisodeCreate(
|
||||
task_type="code_review",
|
||||
task_description="Review PR #42",
|
||||
outcome=Outcome.SUCCESS,
|
||||
actions=[{"type": "analyze", "target": "src/"}],
|
||||
)
|
||||
)
|
||||
|
||||
# Search similar experiences
|
||||
similar = await memory.search_similar(
|
||||
project_id=project_id,
|
||||
query="debugging memory leak",
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Get recent episodes
|
||||
recent = await memory.get_recent(project_id, limit=10)
|
||||
```
|
||||
|
||||
### Semantic Memory
|
||||
Learned facts and knowledge with confidence scoring.
|
||||
|
||||
**Features:**
|
||||
- Triple format (subject, predicate, object)
|
||||
- Confidence scoring with decay
|
||||
- Fact extraction from episodes
|
||||
- Conflict resolution
|
||||
- Entity-based retrieval
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from app.services.memory.semantic import SemanticMemory
|
||||
|
||||
memory = SemanticMemory(session, embedder)
|
||||
|
||||
# Store a fact
|
||||
fact = await memory.store_fact(
|
||||
project_id=project_id,
|
||||
fact=FactCreate(
|
||||
subject="UserService",
|
||||
predicate="handles",
|
||||
object="authentication",
|
||||
confidence=0.9,
|
||||
)
|
||||
)
|
||||
|
||||
# Search facts
|
||||
facts = await memory.search_facts(project_id, "authentication flow")
|
||||
|
||||
# Reinforce on repeated learning
|
||||
await memory.reinforce_fact(fact.id)
|
||||
```
|
||||
|
||||
### Procedural Memory
|
||||
Learned skills and procedures from successful patterns.
|
||||
|
||||
**Features:**
|
||||
- Procedure recording from task patterns
|
||||
- Trigger-based matching
|
||||
- Success rate tracking
|
||||
- Procedure suggestions
|
||||
- Step-by-step storage
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from app.services.memory.procedural import ProceduralMemory
|
||||
|
||||
memory = ProceduralMemory(session, embedder)
|
||||
|
||||
# Record a procedure
|
||||
procedure = await memory.record_procedure(
|
||||
project_id=project_id,
|
||||
procedure=ProcedureCreate(
|
||||
name="PR Review Process",
|
||||
trigger_pattern="code review requested",
|
||||
steps=[
|
||||
Step(action="fetch_diff"),
|
||||
Step(action="analyze_changes"),
|
||||
Step(action="check_tests"),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Find matching procedures
|
||||
matches = await memory.find_matching(project_id, "need to review code")
|
||||
|
||||
# Record outcomes
|
||||
await memory.record_outcome(procedure.id, success=True)
|
||||
```
|
||||
|
||||
## Memory Scoping
|
||||
|
||||
Memory is organized in a hierarchical scope structure:
|
||||
|
||||
```
|
||||
Global Memory (shared by all)
|
||||
└── Project Memory (per project)
|
||||
└── Agent Type Memory (per agent type)
|
||||
└── Agent Instance Memory (per instance)
|
||||
└── Session Memory (ephemeral)
|
||||
```
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from app.services.memory.scoping import ScopeManager, ScopeLevel
|
||||
|
||||
manager = ScopeManager(session)
|
||||
|
||||
# Get scoped memories with inheritance
|
||||
memories = await manager.get_scoped_memories(
|
||||
context=ScopeContext(
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
),
|
||||
include_inherited=True, # Include parent scopes
|
||||
)
|
||||
```
|
||||
|
||||
## Memory Consolidation
|
||||
|
||||
Automatic background processes transfer and extract knowledge:
|
||||
|
||||
```
|
||||
Working Memory ──> Episodic Memory ──> Semantic Memory
|
||||
└──> Procedural Memory
|
||||
```
|
||||
|
||||
**Consolidation Types:**
|
||||
- `working_to_episodic`: Transfer session state to episodes (on session end)
|
||||
- `episodic_to_semantic`: Extract facts from experiences
|
||||
- `episodic_to_procedural`: Learn procedures from patterns
|
||||
- `prune`: Remove low-value memories
|
||||
|
||||
**Celery Tasks:**
|
||||
```python
|
||||
from app.tasks.memory_consolidation import (
|
||||
consolidate_session,
|
||||
run_nightly_consolidation,
|
||||
prune_old_memories,
|
||||
)
|
||||
|
||||
# Manual consolidation
|
||||
consolidate_session.delay(session_id)
|
||||
|
||||
# Scheduled nightly (3 AM by default)
|
||||
run_nightly_consolidation.delay()
|
||||
```
|
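The per-type consolidation tasks can also be dispatched individually. The sketch below uses the task signatures from `app/tasks/memory_consolidation.py` (these tasks are still placeholder implementations); the project UUID is purely illustrative:

```python
# Dispatch individual consolidation steps for a single project.
# Signatures follow app/tasks/memory_consolidation.py; the UUID is a placeholder.
from app.tasks.memory_consolidation import (
    consolidate_episodes_to_facts,
    consolidate_episodes_to_procedures,
    prune_old_memories,
)

project_id = "11111111-1111-1111-1111-111111111111"  # illustrative project UUID

consolidate_episodes_to_facts.delay(project_id, since_hours=24)
consolidate_episodes_to_procedures.delay(project_id, since_days=7)
prune_old_memories.delay(project_id, max_age_days=90, min_importance=0.2)
```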
||||
|
||||
## Memory Retrieval
|
||||
|
||||
### Hybrid Retrieval
|
||||
Combine multiple retrieval strategies:
|
||||
|
||||
```python
|
||||
from app.services.memory.indexing import RetrievalEngine
|
||||
|
||||
engine = RetrievalEngine(session, embedder)
|
||||
|
||||
# Hybrid search across memory types
|
||||
results = await engine.retrieve_hybrid(
|
||||
project_id=project_id,
|
||||
query="authentication error handling",
|
||||
memory_types=["episodic", "semantic", "procedural"],
|
||||
filters={"outcome": "success"},
|
||||
limit=10,
|
||||
)
|
||||
```
|
||||
|
||||
### Index Types
|
||||
- **Vector Index**: Semantic similarity (HNSW/pgvector)
|
||||
- **Temporal Index**: Time-based retrieval
|
||||
- **Entity Index**: Entity mention lookup
|
||||
- **Outcome Index**: Success/failure filtering
|
||||
|
||||
## MCP Tools
|
||||
|
||||
The memory system exposes MCP tools for agent use:
|
||||
|
||||
### `remember`
|
||||
Store information in memory.
|
||||
```json
|
||||
{
|
||||
"memory_type": "working",
|
||||
"content": {"key": "value"},
|
||||
"importance": 0.8,
|
||||
"ttl_seconds": 3600
|
||||
}
|
||||
```
|
||||
|
||||
### `recall`
|
||||
Retrieve from memory.
|
||||
```json
|
||||
{
|
||||
"query": "authentication patterns",
|
||||
"memory_types": ["episodic", "semantic"],
|
||||
"limit": 10,
|
||||
"filters": {"outcome": "success"}
|
||||
}
|
||||
```
|
||||
|
||||
### `forget`
|
||||
Remove from memory.
|
||||
```json
|
||||
{
|
||||
"memory_type": "working",
|
||||
"key": "temp_data"
|
||||
}
|
||||
```
|
||||
|
||||
### `reflect`
|
||||
Analyze memory patterns.
|
||||
```json
|
||||
{
|
||||
"analysis_type": "success_factors",
|
||||
"task_type": "code_review",
|
||||
"time_range_days": 30
|
||||
}
|
||||
```
|
||||
|
||||
### `get_memory_stats`
|
||||
Get memory usage statistics.
|
||||
|
||||
### `record_outcome`
|
||||
Record task success/failure for learning.
|
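Neither tool's argument schema is documented here; as a rough sketch, they correspond to the service-level APIs shown elsewhere in this document (the actual MCP parameters are defined in `app/services/memory/mcp/tools.py` and may differ):

```python
# Rough correspondence of get_memory_stats / record_outcome to the
# service-level APIs documented in this file; the real MCP argument
# schemas live in app/services/memory/mcp/tools.py.
from app.services.memory.metrics import get_memory_metrics
from app.services.memory.procedural import ProceduralMemory


async def example(session, embedder, procedure_id):
    # get_memory_stats: memory usage statistics
    metrics = await get_memory_metrics()
    stats = await metrics.get_summary()

    # record_outcome: feed a task result back into procedural memory
    procedural = ProceduralMemory(session, embedder)
    await procedural.record_outcome(procedure_id, success=True)
    return stats
```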
||||
|
||||
## Memory Reflection
|
||||
|
||||
Analyze patterns and generate insights from memory:
|
||||
|
||||
```python
|
||||
from app.services.memory.reflection import MemoryReflection, TimeRange
|
||||
|
||||
reflection = MemoryReflection(session)
|
||||
|
||||
# Detect patterns
|
||||
patterns = await reflection.analyze_patterns(
|
||||
project_id=project_id,
|
||||
time_range=TimeRange.last_days(30),
|
||||
)
|
||||
|
||||
# Identify success factors
|
||||
factors = await reflection.identify_success_factors(
|
||||
project_id=project_id,
|
||||
task_type="code_review",
|
||||
)
|
||||
|
||||
# Detect anomalies
|
||||
anomalies = await reflection.detect_anomalies(
|
||||
project_id=project_id,
|
||||
baseline_days=30,
|
||||
)
|
||||
|
||||
# Generate insights
|
||||
insights = await reflection.generate_insights(project_id)
|
||||
|
||||
# Comprehensive reflection
|
||||
result = await reflection.reflect(project_id)
|
||||
print(result.summary)
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
All settings use the `MEM_` environment variable prefix:
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `MEM_WORKING_MEMORY_BACKEND` | `redis` | Backend: `redis` or `memory` |
|
||||
| `MEM_WORKING_MEMORY_DEFAULT_TTL_SECONDS` | `3600` | Default TTL (1 hour) |
|
||||
| `MEM_REDIS_URL` | `redis://localhost:6379/0` | Redis connection URL |
|
||||
| `MEM_EPISODIC_MAX_EPISODES_PER_PROJECT` | `10000` | Max episodes per project |
|
||||
| `MEM_EPISODIC_RETENTION_DAYS` | `365` | Episode retention period |
|
||||
| `MEM_SEMANTIC_MAX_FACTS_PER_PROJECT` | `50000` | Max facts per project |
|
||||
| `MEM_SEMANTIC_CONFIDENCE_DECAY_DAYS` | `90` | Confidence half-life |
|
||||
| `MEM_EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model |
|
||||
| `MEM_EMBEDDING_DIMENSIONS` | `1536` | Vector dimensions |
|
||||
| `MEM_RETRIEVAL_MIN_SIMILARITY` | `0.5` | Minimum similarity score |
|
||||
| `MEM_CONSOLIDATION_ENABLED` | `true` | Enable auto-consolidation |
|
||||
| `MEM_CONSOLIDATION_SCHEDULE_CRON` | `0 3 * * *` | Nightly schedule |
|
||||
| `MEM_CACHE_ENABLED` | `true` | Enable retrieval caching |
|
||||
| `MEM_CACHE_TTL_SECONDS` | `300` | Cache TTL (5 minutes) |
|
||||
|
||||
See `app/services/memory/config.py` for complete configuration options.
|
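As a minimal sketch of overriding a setting (assuming `MemorySettings` loads `MEM_`-prefixed variables on instantiation, pydantic-settings style; the exact field names live in `config.py`):

```python
# Minimal sketch: overriding memory settings via the MEM_ prefix.
# Assumption: MemorySettings (app/services/memory/config.py) reads MEM_*
# environment variables when instantiated, in the usual pydantic-settings style.
import os

os.environ["MEM_WORKING_MEMORY_BACKEND"] = "memory"  # in-memory fallback backend
os.environ["MEM_CACHE_TTL_SECONDS"] = "60"           # shorter retrieval-cache TTL

from app.services.memory.config import MemorySettings

settings = MemorySettings()
```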
||||
|
||||
## Integration with Context Engine
|
||||
|
||||
Memory integrates with the Context Engine as a context source:
|
||||
|
||||
```python
|
||||
from app.services.memory.integration import MemoryContextSource
|
||||
|
||||
# Register as context source
|
||||
source = MemoryContextSource(memory_manager)
|
||||
context_engine.register_source(source)
|
||||
|
||||
# Memory is automatically included in context assembly
|
||||
context = await context_engine.assemble_context(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
current_task="Review authentication code",
|
||||
)
|
||||
```
|
||||
|
||||
## Caching
|
||||
|
||||
Multi-layer caching for performance:
|
||||
|
||||
- **Hot Cache**: Frequently accessed memories (LRU)
|
||||
- **Retrieval Cache**: Query result caching
|
||||
- **Embedding Cache**: Pre-computed embeddings
|
||||
|
||||
```python
|
||||
from app.services.memory.cache import CacheManager
|
||||
|
||||
cache = CacheManager(settings)
|
||||
await cache.warm_hot_cache(project_id) # Pre-warm common memories
|
||||
```
|
||||
|
||||
## Metrics
|
||||
|
||||
Prometheus-compatible metrics:
|
||||
|
||||
| Metric | Type | Labels |
|
||||
|--------|------|--------|
|
||||
| `memory_operations_total` | Counter | operation, memory_type, scope, success |
|
||||
| `memory_retrievals_total` | Counter | memory_type, strategy |
|
||||
| `memory_cache_hits_total` | Counter | cache_type |
|
||||
| `memory_retrieval_latency_seconds` | Histogram | - |
|
||||
| `memory_consolidation_duration_seconds` | Histogram | - |
|
||||
| `memory_items_count` | Gauge | memory_type, scope |
|
||||
|
||||
```python
|
||||
from app.services.memory.metrics import get_memory_metrics
|
||||
|
||||
metrics = await get_memory_metrics()
|
||||
summary = await metrics.get_summary()
|
||||
prometheus_output = await metrics.get_prometheus_format()
|
||||
```
|
||||
|
||||
## Performance Targets
|
||||
|
||||
| Operation | Target P95 |
|
||||
|-----------|------------|
|
||||
| Working memory get/set | < 5ms |
|
||||
| Episodic memory retrieval | < 100ms |
|
||||
| Semantic memory search | < 100ms |
|
||||
| Procedural memory matching | < 50ms |
|
||||
| Consolidation batch (1000 items) | < 30s |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Redis Connection Issues
|
||||
```bash
|
||||
# Check Redis connectivity
|
||||
redis-cli ping
|
||||
|
||||
# Verify memory settings
|
||||
MEM_REDIS_URL=redis://localhost:6379/0
|
||||
```
|
||||
|
||||
### Slow Retrieval
|
||||
1. Check if caching is enabled: `MEM_CACHE_ENABLED=true`
|
||||
2. Verify HNSW indexes exist on vector columns (see the query sketch after this list)
|
||||
3. Monitor `memory_retrieval_latency_seconds` metric
|
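For step 2, the standard PostgreSQL catalogs are enough to list HNSW indexes; the sketch below assumes the local connection URL used by `migrate.py` in this repo, so adjust it for your environment:

```python
# List HNSW indexes in the database (pgvector creates them with "USING hnsw").
# The connection URL matches the local default used by migrate.py; adjust as needed.
from sqlalchemy import create_engine, text

engine = create_engine("postgresql://postgres:postgres@localhost:5432/syndarix")
with engine.connect() as conn:
    rows = conn.execute(
        text("SELECT indexname, indexdef FROM pg_indexes WHERE indexdef ILIKE :pat"),
        {"pat": "%hnsw%"},
    )
    for indexname, indexdef in rows:
        print(indexname, indexdef)
```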
||||
|
||||
### High Memory Usage
|
||||
1. Review `MEM_EPISODIC_MAX_EPISODES_PER_PROJECT` limit
|
||||
2. Ensure pruning is enabled: `MEM_PRUNING_ENABLED=true`
|
||||
3. Check consolidation is running (cron schedule)
|
||||
|
||||
### Embedding Errors
|
||||
1. Verify LLM Gateway is accessible
|
||||
2. Check embedding model is valid
|
||||
3. Review batch size if hitting rate limits
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
app/services/memory/
|
||||
├── __init__.py # Public exports
|
||||
├── config.py # MemorySettings
|
||||
├── exceptions.py # Memory-specific errors
|
||||
├── manager.py # MemoryManager facade
|
||||
├── types.py # Core types
|
||||
├── working/ # Working memory
|
||||
│ ├── memory.py
|
||||
│ └── storage.py
|
||||
├── episodic/ # Episodic memory
|
||||
│ ├── memory.py
|
||||
│ ├── recorder.py
|
||||
│ └── retrieval.py
|
||||
├── semantic/ # Semantic memory
|
||||
│ ├── memory.py
|
||||
│ ├── extraction.py
|
||||
│ └── verification.py
|
||||
├── procedural/ # Procedural memory
|
||||
│ ├── memory.py
|
||||
│ └── matching.py
|
||||
├── scoping/ # Memory scoping
|
||||
│ ├── scope.py
|
||||
│ └── resolver.py
|
||||
├── indexing/ # Indexing & retrieval
|
||||
│ ├── index.py
|
||||
│ └── retrieval.py
|
||||
├── consolidation/ # Memory consolidation
|
||||
│ └── service.py
|
||||
├── reflection/ # Memory reflection
|
||||
│ ├── service.py
|
||||
│ └── types.py
|
||||
├── integration/ # External integrations
|
||||
│ ├── context_source.py
|
||||
│ └── lifecycle.py
|
||||
├── cache/ # Caching layer
|
||||
│ ├── cache_manager.py
|
||||
│ ├── hot_cache.py
|
||||
│ └── embedding_cache.py
|
||||
├── mcp/ # MCP tools
|
||||
│ ├── service.py
|
||||
│ └── tools.py
|
||||
└── metrics/ # Observability
|
||||
└── collector.py
|
||||
```
|
||||
@@ -26,6 +26,7 @@ Usage:
|
||||
# Inside Docker (without --local flag):
|
||||
python migrate.py auto "Add new field"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
@@ -44,13 +45,14 @@ def setup_database_url(use_local: bool) -> str:
|
||||
# Override DATABASE_URL to use localhost instead of Docker hostname
|
||||
local_url = os.environ.get(
|
||||
"LOCAL_DATABASE_URL",
|
||||
"postgresql://postgres:postgres@localhost:5432/app"
|
||||
"postgresql://postgres:postgres@localhost:5432/syndarix",
|
||||
)
|
||||
os.environ["DATABASE_URL"] = local_url
|
||||
return local_url
|
||||
|
||||
# Use the configured DATABASE_URL from environment/.env
|
||||
from app.core.config import settings
|
||||
|
||||
return settings.database_url
|
||||
|
||||
|
||||
@@ -61,6 +63,7 @@ def check_models():
|
||||
try:
|
||||
# Import all models through the models package
|
||||
from app.models import __all__ as all_models
|
||||
|
||||
print(f"Found {len(all_models)} model(s):")
|
||||
for model in all_models:
|
||||
print(f" - {model}")
|
||||
@@ -110,7 +113,9 @@ def generate_migration(message, rev_id=None, auto_rev_id=True, offline=False):
|
||||
# Look for the revision ID, which is typically 12 hex characters
|
||||
parts = line.split()
|
||||
for part in parts:
|
||||
if len(part) >= 12 and all(c in "0123456789abcdef" for c in part[:12]):
|
||||
if len(part) >= 12 and all(
|
||||
c in "0123456789abcdef" for c in part[:12]
|
||||
):
|
||||
revision = part[:12]
|
||||
break
|
||||
except Exception as e:
|
||||
@@ -185,6 +190,7 @@ def check_database_connection():
|
||||
db_url = os.environ.get("DATABASE_URL")
|
||||
if not db_url:
|
||||
from app.core.config import settings
|
||||
|
||||
db_url = settings.database_url
|
||||
|
||||
engine = create_engine(db_url)
|
||||
@@ -270,8 +276,8 @@ def generate_offline_migration(message, rev_id):
|
||||
content = f'''"""{message}
|
||||
|
||||
Revision ID: {rev_id}
|
||||
Revises: {down_revision or ''}
|
||||
Create Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}
|
||||
Revises: {down_revision or ""}
|
||||
Create Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")}
|
||||
|
||||
"""
|
||||
|
||||
@@ -320,6 +326,7 @@ def reset_alembic_version():
|
||||
db_url = os.environ.get("DATABASE_URL")
|
||||
if not db_url:
|
||||
from app.core.config import settings
|
||||
|
||||
db_url = settings.database_url
|
||||
|
||||
try:
|
||||
@@ -338,82 +345,80 @@ def reset_alembic_version():
|
||||
def main():
|
||||
"""Main function"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Database migration helper for Generative Models Arena'
|
||||
description="Database migration helper for Generative Models Arena"
|
||||
)
|
||||
|
||||
# Global options
|
||||
parser.add_argument(
|
||||
'--local', '-l',
|
||||
action='store_true',
|
||||
help='Use localhost instead of Docker hostname (for local development)'
|
||||
"--local",
|
||||
"-l",
|
||||
action="store_true",
|
||||
help="Use localhost instead of Docker hostname (for local development)",
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest='command', help='Command to run')
|
||||
subparsers = parser.add_subparsers(dest="command", help="Command to run")
|
||||
|
||||
# Generate command
|
||||
generate_parser = subparsers.add_parser('generate', help='Generate a migration')
|
||||
generate_parser.add_argument('message', help='Migration message')
|
||||
generate_parser = subparsers.add_parser("generate", help="Generate a migration")
|
||||
generate_parser.add_argument("message", help="Migration message")
|
||||
generate_parser.add_argument(
|
||||
'--rev-id',
|
||||
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
|
||||
"--rev-id", help="Custom revision ID (e.g., 0001, 0002 for sequential naming)"
|
||||
)
|
||||
generate_parser.add_argument(
|
||||
'--offline',
|
||||
action='store_true',
|
||||
help='Generate empty migration template without database connection'
|
||||
"--offline",
|
||||
action="store_true",
|
||||
help="Generate empty migration template without database connection",
|
||||
)
|
||||
|
||||
# Apply command
|
||||
apply_parser = subparsers.add_parser('apply', help='Apply migrations')
|
||||
apply_parser.add_argument('--revision', help='Specific revision to apply to')
|
||||
apply_parser = subparsers.add_parser("apply", help="Apply migrations")
|
||||
apply_parser.add_argument("--revision", help="Specific revision to apply to")
|
||||
|
||||
# List command
|
||||
subparsers.add_parser('list', help='List migrations')
|
||||
subparsers.add_parser("list", help="List migrations")
|
||||
|
||||
# Current command
|
||||
subparsers.add_parser('current', help='Show current revision')
|
||||
subparsers.add_parser("current", help="Show current revision")
|
||||
|
||||
# Check command
|
||||
subparsers.add_parser('check', help='Check database connection and models')
|
||||
subparsers.add_parser("check", help="Check database connection and models")
|
||||
|
||||
# Next command (show next revision ID)
|
||||
subparsers.add_parser('next', help='Show the next sequential revision ID')
|
||||
subparsers.add_parser("next", help="Show the next sequential revision ID")
|
||||
|
||||
# Reset command (clear alembic_version table)
|
||||
subparsers.add_parser(
|
||||
'reset',
|
||||
help='Reset alembic_version table (use after deleting all migrations)'
|
||||
"reset", help="Reset alembic_version table (use after deleting all migrations)"
|
||||
)
|
||||
|
||||
# Auto command (generate and apply)
|
||||
auto_parser = subparsers.add_parser('auto', help='Generate and apply migration')
|
||||
auto_parser.add_argument('message', help='Migration message')
|
||||
auto_parser = subparsers.add_parser("auto", help="Generate and apply migration")
|
||||
auto_parser.add_argument("message", help="Migration message")
|
||||
auto_parser.add_argument(
|
||||
'--rev-id',
|
||||
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
|
||||
"--rev-id", help="Custom revision ID (e.g., 0001, 0002 for sequential naming)"
|
||||
)
|
||||
auto_parser.add_argument(
|
||||
'--offline',
|
||||
action='store_true',
|
||||
help='Generate empty migration template without database connection'
|
||||
"--offline",
|
||||
action="store_true",
|
||||
help="Generate empty migration template without database connection",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Commands that don't need database connection
|
||||
if args.command == 'next':
|
||||
if args.command == "next":
|
||||
show_next_rev_id()
|
||||
return
|
||||
|
||||
# Check if offline mode is requested
|
||||
offline = getattr(args, 'offline', False)
|
||||
offline = getattr(args, "offline", False)
|
||||
|
||||
# Offline generate doesn't need database or model check
|
||||
if args.command == 'generate' and offline:
|
||||
if args.command == "generate" and offline:
|
||||
generate_migration(args.message, rev_id=args.rev_id, offline=True)
|
||||
return
|
||||
|
||||
if args.command == 'auto' and offline:
|
||||
if args.command == "auto" and offline:
|
||||
generate_migration(args.message, rev_id=args.rev_id, offline=True)
|
||||
print("\nOffline migration generated. Apply it later with:")
|
||||
print(" python migrate.py --local apply")
|
||||
@@ -423,27 +428,27 @@ def main():
|
||||
db_url = setup_database_url(args.local)
|
||||
print(f"Using database URL: {db_url}")
|
||||
|
||||
if args.command == 'generate':
|
||||
if args.command == "generate":
|
||||
check_models()
|
||||
generate_migration(args.message, rev_id=args.rev_id)
|
||||
|
||||
elif args.command == 'apply':
|
||||
elif args.command == "apply":
|
||||
apply_migration(args.revision)
|
||||
|
||||
elif args.command == 'list':
|
||||
elif args.command == "list":
|
||||
list_migrations()
|
||||
|
||||
elif args.command == 'current':
|
||||
elif args.command == "current":
|
||||
show_current()
|
||||
|
||||
elif args.command == 'check':
|
||||
elif args.command == "check":
|
||||
check_database_connection()
|
||||
check_models()
|
||||
|
||||
elif args.command == 'reset':
|
||||
elif args.command == "reset":
|
||||
reset_alembic_version()
|
||||
|
||||
elif args.command == 'auto':
|
||||
elif args.command == "auto":
|
||||
check_models()
|
||||
revision = generate_migration(args.message, rev_id=args.rev_id)
|
||||
if revision:
|
||||
|
||||
@@ -745,3 +745,230 @@ class TestAgentTypeInstanceCount:
|
||||
for agent_type in data["data"]:
|
||||
assert "instance_count" in agent_type
|
||||
assert isinstance(agent_type["instance_count"], int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAgentTypeCategoryFields:
|
||||
"""Tests for agent type category and display fields."""
|
||||
|
||||
async def test_create_agent_type_with_category_fields(
|
||||
self, client, superuser_token
|
||||
):
|
||||
"""Test creating agent type with all category and display fields."""
|
||||
unique_slug = f"category-type-{uuid.uuid4().hex[:8]}"
|
||||
response = await client.post(
|
||||
"/api/v1/agent-types",
|
||||
json={
|
||||
"name": "Categorized Agent Type",
|
||||
"slug": unique_slug,
|
||||
"description": "An agent type with category fields",
|
||||
"expertise": ["python"],
|
||||
"personality_prompt": "You are a helpful assistant.",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
# Category and display fields
|
||||
"category": "development",
|
||||
"icon": "code",
|
||||
"color": "#3B82F6",
|
||||
"sort_order": 10,
|
||||
"typical_tasks": ["Write code", "Review PRs"],
|
||||
"collaboration_hints": ["backend-engineer", "qa-engineer"],
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()
|
||||
|
||||
assert data["category"] == "development"
|
||||
assert data["icon"] == "code"
|
||||
assert data["color"] == "#3B82F6"
|
||||
assert data["sort_order"] == 10
|
||||
assert data["typical_tasks"] == ["Write code", "Review PRs"]
|
||||
assert data["collaboration_hints"] == ["backend-engineer", "qa-engineer"]
|
||||
|
||||
async def test_create_agent_type_with_nullable_category(
|
||||
self, client, superuser_token
|
||||
):
|
||||
"""Test creating agent type with null category."""
|
||||
unique_slug = f"null-category-{uuid.uuid4().hex[:8]}"
|
||||
response = await client.post(
|
||||
"/api/v1/agent-types",
|
||||
json={
|
||||
"name": "Uncategorized Agent",
|
||||
"slug": unique_slug,
|
||||
"expertise": ["general"],
|
||||
"personality_prompt": "You are a helpful assistant.",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
"category": None,
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()
|
||||
assert data["category"] is None
|
||||
|
||||
async def test_create_agent_type_invalid_color_format(
|
||||
self, client, superuser_token
|
||||
):
|
||||
"""Test that invalid color format is rejected."""
|
||||
unique_slug = f"invalid-color-{uuid.uuid4().hex[:8]}"
|
||||
response = await client.post(
|
||||
"/api/v1/agent-types",
|
||||
json={
|
||||
"name": "Invalid Color Agent",
|
||||
"slug": unique_slug,
|
||||
"expertise": ["python"],
|
||||
"personality_prompt": "You are a helpful assistant.",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
"color": "not-a-hex-color",
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
async def test_create_agent_type_invalid_category(self, client, superuser_token):
|
||||
"""Test that invalid category value is rejected."""
|
||||
unique_slug = f"invalid-category-{uuid.uuid4().hex[:8]}"
|
||||
response = await client.post(
|
||||
"/api/v1/agent-types",
|
||||
json={
|
||||
"name": "Invalid Category Agent",
|
||||
"slug": unique_slug,
|
||||
"expertise": ["python"],
|
||||
"personality_prompt": "You are a helpful assistant.",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
"category": "not_a_valid_category",
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
async def test_update_agent_type_category_fields(
|
||||
self, client, superuser_token, test_agent_type
|
||||
):
|
||||
"""Test updating category and display fields."""
|
||||
agent_type_id = test_agent_type["id"]
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/agent-types/{agent_type_id}",
|
||||
json={
|
||||
"category": "ai_ml",
|
||||
"icon": "brain",
|
||||
"color": "#8B5CF6",
|
||||
"sort_order": 50,
|
||||
"typical_tasks": ["Train models", "Analyze data"],
|
||||
"collaboration_hints": ["data-scientist"],
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["category"] == "ai_ml"
|
||||
assert data["icon"] == "brain"
|
||||
assert data["color"] == "#8B5CF6"
|
||||
assert data["sort_order"] == 50
|
||||
assert data["typical_tasks"] == ["Train models", "Analyze data"]
|
||||
assert data["collaboration_hints"] == ["data-scientist"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAgentTypeCategoryFilter:
|
||||
"""Tests for agent type category filtering."""
|
||||
|
||||
async def test_list_agent_types_filter_by_category(
|
||||
self, client, superuser_token, user_token
|
||||
):
|
||||
"""Test filtering agent types by category."""
|
||||
# Create agent types in different categories
|
||||
for cat in ["development", "design"]:
|
||||
unique_slug = f"filter-test-{cat}-{uuid.uuid4().hex[:8]}"
|
||||
await client.post(
|
||||
"/api/v1/agent-types",
|
||||
json={
|
||||
"name": f"Filter Test {cat.capitalize()}",
|
||||
"slug": unique_slug,
|
||||
"expertise": ["python"],
|
||||
"personality_prompt": "Test prompt",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
"category": cat,
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
# Filter by development category
|
||||
response = await client.get(
|
||||
"/api/v1/agent-types",
|
||||
params={"category": "development"},
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
# All returned types should have development category
|
||||
for agent_type in data["data"]:
|
||||
assert agent_type["category"] == "development"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAgentTypeGroupedEndpoint:
|
||||
"""Tests for the grouped by category endpoint."""
|
||||
|
||||
async def test_list_agent_types_grouped(self, client, superuser_token, user_token):
|
||||
"""Test getting agent types grouped by category."""
|
||||
# Create agent types in different categories
|
||||
categories = ["development", "design", "quality"]
|
||||
for cat in categories:
|
||||
unique_slug = f"grouped-test-{cat}-{uuid.uuid4().hex[:8]}"
|
||||
await client.post(
|
||||
"/api/v1/agent-types",
|
||||
json={
|
||||
"name": f"Grouped Test {cat.capitalize()}",
|
||||
"slug": unique_slug,
|
||||
"expertise": ["python"],
|
||||
"personality_prompt": "Test prompt",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
"category": cat,
|
||||
"sort_order": 10,
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
# Get grouped agent types
|
||||
response = await client.get(
|
||||
"/api/v1/agent-types/grouped",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
# Should be a dict with category keys
|
||||
assert isinstance(data, dict)
|
||||
|
||||
# Check that at least one of our created categories exists
|
||||
assert any(cat in data for cat in categories)
|
||||
|
||||
async def test_list_agent_types_grouped_filter_inactive(
|
||||
self, client, superuser_token, user_token
|
||||
):
|
||||
"""Test grouped endpoint with is_active filter."""
|
||||
response = await client.get(
|
||||
"/api/v1/agent-types/grouped",
|
||||
params={"is_active": False},
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, dict)
|
||||
|
||||
async def test_list_agent_types_grouped_unauthenticated(self, client):
|
||||
"""Test that unauthenticated users cannot access grouped endpoint."""
|
||||
response = await client.get("/api/v1/agent-types/grouped")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@@ -188,13 +188,14 @@ class TestPasswordResetConfirm:
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_confirm_expired_token(self, client, async_test_user):
|
||||
"""Test password reset confirmation with expired token."""
|
||||
import time as time_module
|
||||
import asyncio
|
||||
|
||||
# Create token that expires immediately
|
||||
token = create_password_reset_token(async_test_user.email, expires_in=1)
|
||||
# Create token that expires at current second (expires_in=0)
|
||||
# Token expires when exp < current_time, so we need to cross a second boundary
|
||||
token = create_password_reset_token(async_test_user.email, expires_in=0)
|
||||
|
||||
# Wait for token to expire
|
||||
time_module.sleep(2)
|
||||
# Wait for token to expire (need to cross second boundary)
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
|
||||
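The rewritten test above relies on the reset token's expiry having one-second granularity: with expires_in=0 the exp value is the current second, so the test only has to cross a second boundary instead of sleeping for two full seconds. A minimal illustration of that behaviour, assuming exp is stored as an integer number of seconds (create_password_reset_token is the only real name here; everything else is generic):

import time

exp = int(time.time())              # expires_in=0 puts exp at the current second
assert not exp < int(time.time())   # same second: the token is not expired yet
time.sleep(1.1)                     # cross the second boundary
assert exp < int(time.time())       # now the "exp < current_time" check fires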
@@ -368,3 +368,9 @@ async def e2e_org_with_members(e2e_client, e2e_superuser):
|
||||
"user_id": member_id,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# NOTE: Class-scoped fixtures for E2E tests were attempted but have fundamental
|
||||
# issues with pytest-asyncio + SQLAlchemy/asyncpg event loop management.
|
||||
# The function-scoped fixtures above provide proper test isolation.
|
||||
# Performance optimization would require significant infrastructure changes.
|
||||
|
||||
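For context on the note above, the function-scoped pattern it refers to looks roughly like the sketch below. This is an illustration under assumptions (the database URL and fixture name are invented), not the repository's actual conftest:

import pytest_asyncio
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine


@pytest_asyncio.fixture  # function scope by default: one engine per test
async def db_session():
    # A fresh engine per test keeps the asyncpg connection bound to the event
    # loop that pytest-asyncio creates for that specific test, which is what
    # makes class- or session-scoped variants problematic.
    engine = create_async_engine("postgresql+asyncpg://test:test@localhost/test_db")
    factory = async_sessionmaker(engine, expire_on_commit=False)
    async with factory() as session:
        yield session
    await engine.dispose()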
2
backend/tests/models/memory/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/models/memory/__init__.py
|
||||
"""Unit tests for memory database models."""
|
||||
71
backend/tests/models/memory/test_enums.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# tests/unit/models/memory/test_enums.py
|
||||
"""Unit tests for memory model enums."""
|
||||
|
||||
from app.models.memory.enums import (
|
||||
ConsolidationStatus,
|
||||
ConsolidationType,
|
||||
EpisodeOutcome,
|
||||
ScopeType,
|
||||
)
|
||||
|
||||
|
||||
class TestScopeType:
|
||||
"""Tests for ScopeType enum."""
|
||||
|
||||
def test_all_values_exist(self) -> None:
|
||||
"""Test all expected scope types exist."""
|
||||
assert ScopeType.GLOBAL.value == "global"
|
||||
assert ScopeType.PROJECT.value == "project"
|
||||
assert ScopeType.AGENT_TYPE.value == "agent_type"
|
||||
assert ScopeType.AGENT_INSTANCE.value == "agent_instance"
|
||||
assert ScopeType.SESSION.value == "session"
|
||||
|
||||
def test_scope_count(self) -> None:
|
||||
"""Test we have exactly 5 scope types."""
|
||||
assert len(ScopeType) == 5
|
||||
|
||||
|
||||
class TestEpisodeOutcome:
|
||||
"""Tests for EpisodeOutcome enum."""
|
||||
|
||||
def test_all_values_exist(self) -> None:
|
||||
"""Test all expected outcome values exist."""
|
||||
assert EpisodeOutcome.SUCCESS.value == "success"
|
||||
assert EpisodeOutcome.FAILURE.value == "failure"
|
||||
assert EpisodeOutcome.PARTIAL.value == "partial"
|
||||
|
||||
def test_outcome_count(self) -> None:
|
||||
"""Test we have exactly 3 outcome types."""
|
||||
assert len(EpisodeOutcome) == 3
|
||||
|
||||
|
||||
class TestConsolidationType:
|
||||
"""Tests for ConsolidationType enum."""
|
||||
|
||||
def test_all_values_exist(self) -> None:
|
||||
"""Test all expected consolidation types exist."""
|
||||
assert ConsolidationType.WORKING_TO_EPISODIC.value == "working_to_episodic"
|
||||
assert ConsolidationType.EPISODIC_TO_SEMANTIC.value == "episodic_to_semantic"
|
||||
assert (
|
||||
ConsolidationType.EPISODIC_TO_PROCEDURAL.value == "episodic_to_procedural"
|
||||
)
|
||||
assert ConsolidationType.PRUNING.value == "pruning"
|
||||
|
||||
def test_consolidation_count(self) -> None:
|
||||
"""Test we have exactly 4 consolidation types."""
|
||||
assert len(ConsolidationType) == 4
|
||||
|
||||
|
||||
class TestConsolidationStatus:
|
||||
"""Tests for ConsolidationStatus enum."""
|
||||
|
||||
def test_all_values_exist(self) -> None:
|
||||
"""Test all expected status values exist."""
|
||||
assert ConsolidationStatus.PENDING.value == "pending"
|
||||
assert ConsolidationStatus.RUNNING.value == "running"
|
||||
assert ConsolidationStatus.COMPLETED.value == "completed"
|
||||
assert ConsolidationStatus.FAILED.value == "failed"
|
||||
|
||||
def test_status_count(self) -> None:
|
||||
"""Test we have exactly 4 status types."""
|
||||
assert len(ConsolidationStatus) == 4
|
||||
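These assertions pin the enums down to string-valued members. A sketch of the shape being assumed, spelled out for ScopeType only (the real definitions live in app.models.memory.enums):

from enum import Enum


class ScopeType(str, Enum):
    GLOBAL = "global"
    PROJECT = "project"
    AGENT_TYPE = "agent_type"
    AGENT_INSTANCE = "agent_instance"
    SESSION = "session"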
249
backend/tests/models/memory/test_models.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# tests/unit/models/memory/test_models.py
|
||||
"""Unit tests for memory database models."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.memory import (
|
||||
ConsolidationStatus,
|
||||
ConsolidationType,
|
||||
Episode,
|
||||
EpisodeOutcome,
|
||||
Fact,
|
||||
MemoryConsolidationLog,
|
||||
Procedure,
|
||||
ScopeType,
|
||||
WorkingMemory,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkingMemoryModel:
|
||||
"""Tests for WorkingMemory model."""
|
||||
|
||||
def test_tablename(self) -> None:
|
||||
"""Test table name is correct."""
|
||||
assert WorkingMemory.__tablename__ == "working_memory"
|
||||
|
||||
def test_has_required_columns(self) -> None:
|
||||
"""Test all required columns exist."""
|
||||
columns = WorkingMemory.__table__.columns
|
||||
assert "id" in columns
|
||||
assert "scope_type" in columns
|
||||
assert "scope_id" in columns
|
||||
assert "key" in columns
|
||||
assert "value" in columns
|
||||
assert "expires_at" in columns
|
||||
assert "created_at" in columns
|
||||
assert "updated_at" in columns
|
||||
|
||||
def test_has_unique_constraint(self) -> None:
|
||||
"""Test unique constraint on scope+key."""
|
||||
indexes = {idx.name: idx for idx in WorkingMemory.__table__.indexes}
|
||||
assert "ix_working_memory_scope_key" in indexes
|
||||
assert indexes["ix_working_memory_scope_key"].unique
|
||||
|
||||
|
||||
class TestEpisodeModel:
|
||||
"""Tests for Episode model."""
|
||||
|
||||
def test_tablename(self) -> None:
|
||||
"""Test table name is correct."""
|
||||
assert Episode.__tablename__ == "episodes"
|
||||
|
||||
def test_has_required_columns(self) -> None:
|
||||
"""Test all required columns exist."""
|
||||
columns = Episode.__table__.columns
|
||||
required = [
|
||||
"id",
|
||||
"project_id",
|
||||
"agent_instance_id",
|
||||
"agent_type_id",
|
||||
"session_id",
|
||||
"task_type",
|
||||
"task_description",
|
||||
"actions",
|
||||
"context_summary",
|
||||
"outcome",
|
||||
"outcome_details",
|
||||
"duration_seconds",
|
||||
"tokens_used",
|
||||
"lessons_learned",
|
||||
"importance_score",
|
||||
"embedding",
|
||||
"occurred_at",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
for col in required:
|
||||
assert col in columns, f"Missing column: {col}"
|
||||
|
||||
def test_has_foreign_keys(self) -> None:
|
||||
"""Test foreign key relationships exist."""
|
||||
columns = Episode.__table__.columns
|
||||
assert columns["project_id"].foreign_keys
|
||||
assert columns["agent_instance_id"].foreign_keys
|
||||
assert columns["agent_type_id"].foreign_keys
|
||||
|
||||
def test_has_relationships(self) -> None:
|
||||
"""Test ORM relationships exist."""
|
||||
mapper = Episode.__mapper__
|
||||
assert "project" in mapper.relationships
|
||||
assert "agent_instance" in mapper.relationships
|
||||
assert "agent_type" in mapper.relationships
|
||||
|
||||
|
||||
class TestFactModel:
|
||||
"""Tests for Fact model."""
|
||||
|
||||
def test_tablename(self) -> None:
|
||||
"""Test table name is correct."""
|
||||
assert Fact.__tablename__ == "facts"
|
||||
|
||||
def test_has_required_columns(self) -> None:
|
||||
"""Test all required columns exist."""
|
||||
columns = Fact.__table__.columns
|
||||
required = [
|
||||
"id",
|
||||
"project_id",
|
||||
"subject",
|
||||
"predicate",
|
||||
"object",
|
||||
"confidence",
|
||||
"source_episode_ids",
|
||||
"first_learned",
|
||||
"last_reinforced",
|
||||
"reinforcement_count",
|
||||
"embedding",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
for col in required:
|
||||
assert col in columns, f"Missing column: {col}"
|
||||
|
||||
def test_project_id_nullable(self) -> None:
|
||||
"""Test project_id is nullable for global facts."""
|
||||
columns = Fact.__table__.columns
|
||||
assert columns["project_id"].nullable
|
||||
|
||||
|
||||
class TestProcedureModel:
|
||||
"""Tests for Procedure model."""
|
||||
|
||||
def test_tablename(self) -> None:
|
||||
"""Test table name is correct."""
|
||||
assert Procedure.__tablename__ == "procedures"
|
||||
|
||||
def test_has_required_columns(self) -> None:
|
||||
"""Test all required columns exist."""
|
||||
columns = Procedure.__table__.columns
|
||||
required = [
|
||||
"id",
|
||||
"project_id",
|
||||
"agent_type_id",
|
||||
"name",
|
||||
"trigger_pattern",
|
||||
"steps",
|
||||
"success_count",
|
||||
"failure_count",
|
||||
"last_used",
|
||||
"embedding",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
for col in required:
|
||||
assert col in columns, f"Missing column: {col}"
|
||||
|
||||
def test_success_rate_property(self) -> None:
|
||||
"""Test success_rate calculated property."""
|
||||
proc = Procedure()
|
||||
proc.success_count = 8
|
||||
proc.failure_count = 2
|
||||
assert proc.success_rate == 0.8
|
||||
|
||||
def test_success_rate_zero_total(self) -> None:
|
||||
"""Test success_rate with zero total uses."""
|
||||
proc = Procedure()
|
||||
proc.success_count = 0
|
||||
proc.failure_count = 0
|
||||
assert proc.success_rate == 0.0
|
||||
|
||||
def test_total_uses_property(self) -> None:
|
||||
"""Test total_uses calculated property."""
|
||||
proc = Procedure()
|
||||
proc.success_count = 5
|
||||
proc.failure_count = 3
|
||||
assert proc.total_uses == 8
|
||||
|
||||
|
||||
class TestMemoryConsolidationLogModel:
|
||||
"""Tests for MemoryConsolidationLog model."""
|
||||
|
||||
def test_tablename(self) -> None:
|
||||
"""Test table name is correct."""
|
||||
assert MemoryConsolidationLog.__tablename__ == "memory_consolidation_log"
|
||||
|
||||
def test_has_required_columns(self) -> None:
|
||||
"""Test all required columns exist."""
|
||||
columns = MemoryConsolidationLog.__table__.columns
|
||||
required = [
|
||||
"id",
|
||||
"consolidation_type",
|
||||
"source_count",
|
||||
"result_count",
|
||||
"started_at",
|
||||
"completed_at",
|
||||
"status",
|
||||
"error",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
for col in required:
|
||||
assert col in columns, f"Missing column: {col}"
|
||||
|
||||
def test_duration_seconds_property_completed(self) -> None:
|
||||
"""Test duration_seconds with completed job."""
|
||||
log = MemoryConsolidationLog()
|
||||
log.started_at = datetime.now(UTC)
|
||||
log.completed_at = log.started_at + timedelta(seconds=10)
|
||||
assert log.duration_seconds == pytest.approx(10.0)
|
||||
|
||||
def test_duration_seconds_property_incomplete(self) -> None:
|
||||
"""Test duration_seconds with incomplete job."""
|
||||
log = MemoryConsolidationLog()
|
||||
log.started_at = datetime.now(UTC)
|
||||
log.completed_at = None
|
||||
assert log.duration_seconds is None
|
||||
|
||||
def test_default_status(self) -> None:
|
||||
"""Test default status is PENDING."""
|
||||
columns = MemoryConsolidationLog.__table__.columns
|
||||
assert columns["status"].default.arg == ConsolidationStatus.PENDING
|
||||
|
||||
|
||||
class TestModelExports:
|
||||
"""Tests for model package exports."""
|
||||
|
||||
def test_all_models_exported(self) -> None:
|
||||
"""Test all models are exported from package."""
|
||||
from app.models.memory import (
|
||||
Episode,
|
||||
Fact,
|
||||
MemoryConsolidationLog,
|
||||
Procedure,
|
||||
WorkingMemory,
|
||||
)
|
||||
|
||||
# Verify these are the actual classes
|
||||
assert Episode.__tablename__ == "episodes"
|
||||
assert Fact.__tablename__ == "facts"
|
||||
assert Procedure.__tablename__ == "procedures"
|
||||
assert WorkingMemory.__tablename__ == "working_memory"
|
||||
assert MemoryConsolidationLog.__tablename__ == "memory_consolidation_log"
|
||||
|
||||
def test_enums_exported(self) -> None:
|
||||
"""Test all enums are exported."""
|
||||
assert ScopeType.GLOBAL.value == "global"
|
||||
assert EpisodeOutcome.SUCCESS.value == "success"
|
||||
assert ConsolidationType.WORKING_TO_EPISODIC.value == "working_to_episodic"
|
||||
assert ConsolidationStatus.PENDING.value == "pending"
|
||||
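The property tests in this file pin down simple derived values on Procedure and MemoryConsolidationLog. A sketch of what those properties presumably compute, for illustration only (the real models live in app.models.memory and may differ):

from datetime import datetime


class ProcedureSketch:
    success_count: int = 0
    failure_count: int = 0

    @property
    def total_uses(self) -> int:
        return self.success_count + self.failure_count

    @property
    def success_rate(self) -> float:
        # Guard against division by zero for never-used procedures.
        return self.success_count / self.total_uses if self.total_uses else 0.0


class ConsolidationLogSketch:
    started_at: datetime | None = None
    completed_at: datetime | None = None

    @property
    def duration_seconds(self) -> float | None:
        # None until the job has both started and completed.
        if self.started_at is None or self.completed_at is None:
            return None
        return (self.completed_at - self.started_at).total_seconds()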
@@ -316,3 +316,325 @@ class TestAgentTypeJsonFields:
|
||||
)
|
||||
|
||||
assert agent_type.fallback_models == models
|
||||
|
||||
|
||||
class TestAgentTypeCategoryFieldsValidation:
|
||||
"""Tests for AgentType category and display field validation."""
|
||||
|
||||
def test_valid_category_values(self):
|
||||
"""Test that all valid category values are accepted."""
|
||||
valid_categories = [
|
||||
"development",
|
||||
"design",
|
||||
"quality",
|
||||
"operations",
|
||||
"ai_ml",
|
||||
"data",
|
||||
"leadership",
|
||||
"domain_expert",
|
||||
]
|
||||
|
||||
for category in valid_categories:
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
category=category,
|
||||
)
|
||||
assert agent_type.category.value == category
|
||||
|
||||
def test_category_null_allowed(self):
|
||||
"""Test that null category is allowed."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
category=None,
|
||||
)
|
||||
assert agent_type.category is None
|
||||
|
||||
def test_invalid_category_rejected(self):
|
||||
"""Test that invalid category values are rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
category="invalid_category",
|
||||
)
|
||||
|
||||
def test_valid_hex_color(self):
|
||||
"""Test that valid hex colors are accepted."""
|
||||
valid_colors = ["#3B82F6", "#EC4899", "#10B981", "#ffffff", "#000000"]
|
||||
|
||||
for color in valid_colors:
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
color=color,
|
||||
)
|
||||
assert agent_type.color == color
|
||||
|
||||
def test_invalid_hex_color_rejected(self):
|
||||
"""Test that invalid hex colors are rejected."""
|
||||
invalid_colors = [
|
||||
"not-a-color",
|
||||
"3B82F6", # Missing #
|
||||
"#3B82F", # Too short
|
||||
"#3B82F6A", # Too long
|
||||
"#GGGGGG", # Invalid hex chars
|
||||
"rgb(59, 130, 246)", # RGB format not supported
|
||||
]
|
||||
|
||||
for color in invalid_colors:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
color=color,
|
||||
)
|
||||
|
||||
def test_color_null_allowed(self):
|
||||
"""Test that null color is allowed."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
color=None,
|
||||
)
|
||||
assert agent_type.color is None
|
||||
|
||||
def test_sort_order_valid_range(self):
|
||||
"""Test that valid sort_order values are accepted."""
|
||||
for sort_order in [0, 1, 500, 1000]:
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
sort_order=sort_order,
|
||||
)
|
||||
assert agent_type.sort_order == sort_order
|
||||
|
||||
def test_sort_order_default_zero(self):
|
||||
"""Test that sort_order defaults to 0."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
assert agent_type.sort_order == 0
|
||||
|
||||
def test_sort_order_negative_rejected(self):
|
||||
"""Test that negative sort_order is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
sort_order=-1,
|
||||
)
|
||||
|
||||
def test_sort_order_exceeds_max_rejected(self):
|
||||
"""Test that sort_order > 1000 is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
sort_order=1001,
|
||||
)
|
||||
|
||||
def test_icon_max_length(self):
|
||||
"""Test that icon field respects max length."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
icon="x" * 50,
|
||||
)
|
||||
assert len(agent_type.icon) == 50
|
||||
|
||||
def test_icon_exceeds_max_length_rejected(self):
|
||||
"""Test that icon exceeding max length is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
icon="x" * 51,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentTypeTypicalTasksValidation:
|
||||
"""Tests for typical_tasks field validation."""
|
||||
|
||||
def test_typical_tasks_list(self):
|
||||
"""Test typical_tasks as a list."""
|
||||
tasks = ["Write code", "Review PRs", "Debug issues"]
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
typical_tasks=tasks,
|
||||
)
|
||||
assert agent_type.typical_tasks == tasks
|
||||
|
||||
def test_typical_tasks_default_empty(self):
|
||||
"""Test typical_tasks defaults to empty list."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
assert agent_type.typical_tasks == []
|
||||
|
||||
def test_typical_tasks_strips_whitespace(self):
|
||||
"""Test that typical_tasks items are stripped."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
typical_tasks=[" Write code ", " Debug "],
|
||||
)
|
||||
assert agent_type.typical_tasks == ["Write code", "Debug"]
|
||||
|
||||
def test_typical_tasks_removes_empty_strings(self):
|
||||
"""Test that empty strings are removed from typical_tasks."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
typical_tasks=["Write code", "", " ", "Debug"],
|
||||
)
|
||||
assert agent_type.typical_tasks == ["Write code", "Debug"]
|
||||
|
||||
|
||||
class TestAgentTypeCollaborationHintsValidation:
|
||||
"""Tests for collaboration_hints field validation."""
|
||||
|
||||
def test_collaboration_hints_list(self):
|
||||
"""Test collaboration_hints as a list."""
|
||||
hints = ["backend-engineer", "qa-engineer"]
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
collaboration_hints=hints,
|
||||
)
|
||||
assert agent_type.collaboration_hints == hints
|
||||
|
||||
def test_collaboration_hints_default_empty(self):
|
||||
"""Test collaboration_hints defaults to empty list."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
assert agent_type.collaboration_hints == []
|
||||
|
||||
def test_collaboration_hints_normalized_lowercase(self):
|
||||
"""Test that collaboration_hints are normalized to lowercase."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
collaboration_hints=["Backend-Engineer", "QA-ENGINEER"],
|
||||
)
|
||||
assert agent_type.collaboration_hints == ["backend-engineer", "qa-engineer"]
|
||||
|
||||
def test_collaboration_hints_strips_whitespace(self):
|
||||
"""Test that collaboration_hints are stripped."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
collaboration_hints=[" backend-engineer ", " qa-engineer "],
|
||||
)
|
||||
assert agent_type.collaboration_hints == ["backend-engineer", "qa-engineer"]
|
||||
|
||||
def test_collaboration_hints_removes_empty_strings(self):
|
||||
"""Test that empty strings are removed from collaboration_hints."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
collaboration_hints=["backend-engineer", "", " ", "qa-engineer"],
|
||||
)
|
||||
assert agent_type.collaboration_hints == ["backend-engineer", "qa-engineer"]
|
||||
|
||||
|
||||
class TestAgentTypeUpdateCategoryFields:
|
||||
"""Tests for AgentTypeUpdate category and display fields."""
|
||||
|
||||
def test_update_category_field(self):
|
||||
"""Test updating category field."""
|
||||
update = AgentTypeUpdate(category="ai_ml")
|
||||
assert update.category.value == "ai_ml"
|
||||
|
||||
def test_update_icon_field(self):
|
||||
"""Test updating icon field."""
|
||||
update = AgentTypeUpdate(icon="brain")
|
||||
assert update.icon == "brain"
|
||||
|
||||
def test_update_color_field(self):
|
||||
"""Test updating color field."""
|
||||
update = AgentTypeUpdate(color="#8B5CF6")
|
||||
assert update.color == "#8B5CF6"
|
||||
|
||||
def test_update_sort_order_field(self):
|
||||
"""Test updating sort_order field."""
|
||||
update = AgentTypeUpdate(sort_order=50)
|
||||
assert update.sort_order == 50
|
||||
|
||||
def test_update_typical_tasks_field(self):
|
||||
"""Test updating typical_tasks field."""
|
||||
update = AgentTypeUpdate(typical_tasks=["New task"])
|
||||
assert update.typical_tasks == ["New task"]
|
||||
|
||||
def test_update_typical_tasks_strips_whitespace(self):
|
||||
"""Test that typical_tasks are stripped on update."""
|
||||
update = AgentTypeUpdate(typical_tasks=[" New task "])
|
||||
assert update.typical_tasks == ["New task"]
|
||||
|
||||
def test_update_collaboration_hints_field(self):
|
||||
"""Test updating collaboration_hints field."""
|
||||
update = AgentTypeUpdate(collaboration_hints=["new-collaborator"])
|
||||
assert update.collaboration_hints == ["new-collaborator"]
|
||||
|
||||
def test_update_collaboration_hints_normalized(self):
|
||||
"""Test that collaboration_hints are normalized on update."""
|
||||
update = AgentTypeUpdate(collaboration_hints=[" New-Collaborator "])
|
||||
assert update.collaboration_hints == ["new-collaborator"]
|
||||
|
||||
def test_update_invalid_color_rejected(self):
|
||||
"""Test that invalid color is rejected on update."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeUpdate(color="invalid")
|
||||
|
||||
def test_update_invalid_sort_order_rejected(self):
|
||||
"""Test that invalid sort_order is rejected on update."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeUpdate(sort_order=-1)
|
||||
|
||||
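Taken together, the schema tests above imply field constraints along these lines. A minimal Pydantic v2 sketch with inferred bounds and regex, not the actual AgentTypeCreate definition:

from pydantic import BaseModel, Field, field_validator


class AgentTypeCreateSketch(BaseModel):
    color: str | None = Field(default=None, pattern=r"^#[0-9A-Fa-f]{6}$")
    icon: str | None = Field(default=None, max_length=50)
    sort_order: int = Field(default=0, ge=0, le=1000)
    typical_tasks: list[str] = Field(default_factory=list)
    collaboration_hints: list[str] = Field(default_factory=list)

    @field_validator("typical_tasks", mode="before")
    @classmethod
    def _clean_tasks(cls, items: list[str]) -> list[str]:
        # Strip whitespace and drop empty entries, as the tests require.
        return [item.strip() for item in items if item and item.strip()]

    @field_validator("collaboration_hints", mode="before")
    @classmethod
    def _clean_hints(cls, items: list[str]) -> list[str]:
        # Same cleanup, plus lowercase normalization of the hinted slugs.
        return [item.strip().lower() for item in items if item and item.strip()]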
@@ -304,10 +304,18 @@ class TestTaskModuleExports:
|
||||
assert hasattr(tasks, "sync")
|
||||
assert hasattr(tasks, "workflow")
|
||||
assert hasattr(tasks, "cost")
|
||||
assert hasattr(tasks, "memory_consolidation")
|
||||
|
||||
def test_tasks_all_attribute_is_correct(self):
|
||||
"""Test that __all__ contains all expected module names."""
|
||||
from app import tasks
|
||||
|
||||
expected_modules = ["agent", "git", "sync", "workflow", "cost"]
|
||||
expected_modules = [
|
||||
"agent",
|
||||
"git",
|
||||
"sync",
|
||||
"workflow",
|
||||
"cost",
|
||||
"memory_consolidation",
|
||||
]
|
||||
assert set(tasks.__all__) == set(expected_modules)
|
||||
|
||||
@@ -42,6 +42,9 @@ class TestInitDb:
|
||||
assert user.last_name == "User"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(
|
||||
reason="SQLite doesn't support UUID type binding - requires PostgreSQL"
|
||||
)
|
||||
async def test_init_db_returns_existing_superuser(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
|
||||
2
backend/tests/unit/models/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/models/__init__.py
|
||||
"""Unit tests for database models."""
|
||||
260
backend/tests/unit/services/context/types/test_memory.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# tests/unit/services/context/types/test_memory.py
|
||||
"""Tests for MemoryContext type."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from app.services.context.types import ContextType
|
||||
from app.services.context.types.memory import MemoryContext, MemorySubtype
|
||||
|
||||
|
||||
class TestMemorySubtype:
|
||||
"""Tests for MemorySubtype enum."""
|
||||
|
||||
def test_all_types_defined(self) -> None:
|
||||
"""All memory subtypes should be defined."""
|
||||
assert MemorySubtype.WORKING == "working"
|
||||
assert MemorySubtype.EPISODIC == "episodic"
|
||||
assert MemorySubtype.SEMANTIC == "semantic"
|
||||
assert MemorySubtype.PROCEDURAL == "procedural"
|
||||
|
||||
def test_enum_values(self) -> None:
|
||||
"""Enum values should match strings."""
|
||||
assert MemorySubtype.WORKING.value == "working"
|
||||
assert MemorySubtype("episodic") == MemorySubtype.EPISODIC
|
||||
|
||||
|
||||
class TestMemoryContext:
|
||||
"""Tests for MemoryContext class."""
|
||||
|
||||
def test_get_type_returns_memory(self) -> None:
|
||||
"""get_type should return MEMORY."""
|
||||
ctx = MemoryContext(content="test", source="test_source")
|
||||
assert ctx.get_type() == ContextType.MEMORY
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Default values should be set correctly."""
|
||||
ctx = MemoryContext(content="test", source="test_source")
|
||||
assert ctx.memory_subtype == MemorySubtype.EPISODIC
|
||||
assert ctx.memory_id is None
|
||||
assert ctx.relevance_score == 0.0
|
||||
assert ctx.importance == 0.5
|
||||
|
||||
def test_to_dict_includes_memory_fields(self) -> None:
|
||||
"""to_dict should include memory-specific fields."""
|
||||
ctx = MemoryContext(
|
||||
content="test content",
|
||||
source="test_source",
|
||||
memory_subtype=MemorySubtype.SEMANTIC,
|
||||
memory_id="mem-123",
|
||||
relevance_score=0.8,
|
||||
subject="User",
|
||||
predicate="prefers",
|
||||
object_value="dark mode",
|
||||
)
|
||||
|
||||
data = ctx.to_dict()
|
||||
|
||||
assert data["memory_subtype"] == "semantic"
|
||||
assert data["memory_id"] == "mem-123"
|
||||
assert data["relevance_score"] == 0.8
|
||||
assert data["subject"] == "User"
|
||||
assert data["predicate"] == "prefers"
|
||||
assert data["object_value"] == "dark mode"
|
||||
|
||||
def test_from_dict(self) -> None:
|
||||
"""from_dict should create correct MemoryContext."""
|
||||
data = {
|
||||
"content": "test content",
|
||||
"source": "test_source",
|
||||
"timestamp": "2024-01-01T00:00:00+00:00",
|
||||
"memory_subtype": "semantic",
|
||||
"memory_id": "mem-123",
|
||||
"relevance_score": 0.8,
|
||||
"subject": "Test",
|
||||
}
|
||||
|
||||
ctx = MemoryContext.from_dict(data)
|
||||
|
||||
assert ctx.content == "test content"
|
||||
assert ctx.memory_subtype == MemorySubtype.SEMANTIC
|
||||
assert ctx.memory_id == "mem-123"
|
||||
assert ctx.subject == "Test"
|
||||
|
||||
|
||||
class TestMemoryContextFromWorkingMemory:
|
||||
"""Tests for MemoryContext.from_working_memory."""
|
||||
|
||||
def test_creates_working_memory_context(self) -> None:
|
||||
"""Should create working memory context from key/value."""
|
||||
ctx = MemoryContext.from_working_memory(
|
||||
key="user_preferences",
|
||||
value={"theme": "dark"},
|
||||
source="working:sess-123",
|
||||
query="preferences",
|
||||
)
|
||||
|
||||
assert ctx.memory_subtype == MemorySubtype.WORKING
|
||||
assert ctx.key == "user_preferences"
|
||||
assert "{'theme': 'dark'}" in ctx.content
|
||||
assert ctx.relevance_score == 1.0 # Working memory is always relevant
|
||||
assert ctx.importance == 0.8 # Higher importance
|
||||
|
||||
def test_string_value(self) -> None:
|
||||
"""Should handle string values."""
|
||||
ctx = MemoryContext.from_working_memory(
|
||||
key="current_task",
|
||||
value="Build authentication",
|
||||
)
|
||||
|
||||
assert ctx.content == "Build authentication"
|
||||
|
||||
|
||||
class TestMemoryContextFromEpisodicMemory:
|
||||
"""Tests for MemoryContext.from_episodic_memory."""
|
||||
|
||||
def test_creates_episodic_memory_context(self) -> None:
|
||||
"""Should create episodic memory context from episode."""
|
||||
episode = MagicMock()
|
||||
episode.id = uuid4()
|
||||
episode.task_description = "Implemented login feature"
|
||||
episode.task_type = "feature_implementation"
|
||||
episode.outcome = MagicMock(value="success")
|
||||
episode.importance_score = 0.9
|
||||
episode.session_id = "sess-123"
|
||||
episode.occurred_at = datetime.now(UTC)
|
||||
episode.lessons_learned = ["Use proper validation"]
|
||||
|
||||
ctx = MemoryContext.from_episodic_memory(episode, query="login")
|
||||
|
||||
assert ctx.memory_subtype == MemorySubtype.EPISODIC
|
||||
assert ctx.memory_id == str(episode.id)
|
||||
assert ctx.content == "Implemented login feature"
|
||||
assert ctx.task_type == "feature_implementation"
|
||||
assert ctx.outcome == "success"
|
||||
assert ctx.importance == 0.9
|
||||
|
||||
def test_handles_missing_outcome(self) -> None:
|
||||
"""Should handle episodes with no outcome."""
|
||||
episode = MagicMock()
|
||||
episode.id = uuid4()
|
||||
episode.task_description = "WIP task"
|
||||
episode.outcome = None
|
||||
episode.importance_score = 0.5
|
||||
episode.occurred_at = None
|
||||
|
||||
ctx = MemoryContext.from_episodic_memory(episode)
|
||||
|
||||
assert ctx.outcome is None
|
||||
|
||||
|
||||
class TestMemoryContextFromSemanticMemory:
|
||||
"""Tests for MemoryContext.from_semantic_memory."""
|
||||
|
||||
def test_creates_semantic_memory_context(self) -> None:
|
||||
"""Should create semantic memory context from fact."""
|
||||
fact = MagicMock()
|
||||
fact.id = uuid4()
|
||||
fact.subject = "User"
|
||||
fact.predicate = "prefers"
|
||||
fact.object = "dark mode"
|
||||
fact.confidence = 0.95
|
||||
|
||||
ctx = MemoryContext.from_semantic_memory(fact, query="user preferences")
|
||||
|
||||
assert ctx.memory_subtype == MemorySubtype.SEMANTIC
|
||||
assert ctx.memory_id == str(fact.id)
|
||||
assert ctx.content == "User prefers dark mode"
|
||||
assert ctx.subject == "User"
|
||||
assert ctx.predicate == "prefers"
|
||||
assert ctx.object_value == "dark mode"
|
||||
assert ctx.relevance_score == 0.95
|
||||
|
||||
|
||||
class TestMemoryContextFromProceduralMemory:
|
||||
"""Tests for MemoryContext.from_procedural_memory."""
|
||||
|
||||
def test_creates_procedural_memory_context(self) -> None:
|
||||
"""Should create procedural memory context from procedure."""
|
||||
procedure = MagicMock()
|
||||
procedure.id = uuid4()
|
||||
procedure.name = "Deploy to Production"
|
||||
procedure.trigger_pattern = "When deploying to production"
|
||||
procedure.steps = [
|
||||
{"action": "run_tests"},
|
||||
{"action": "build_docker"},
|
||||
{"action": "deploy"},
|
||||
]
|
||||
procedure.success_rate = 0.85
|
||||
procedure.success_count = 10
|
||||
procedure.failure_count = 2
|
||||
|
||||
ctx = MemoryContext.from_procedural_memory(procedure, query="deploy")
|
||||
|
||||
assert ctx.memory_subtype == MemorySubtype.PROCEDURAL
|
||||
assert ctx.memory_id == str(procedure.id)
|
||||
assert "Deploy to Production" in ctx.content
|
||||
assert "When deploying to production" in ctx.content
|
||||
assert ctx.trigger == "When deploying to production"
|
||||
assert ctx.success_rate == 0.85
|
||||
assert ctx.metadata["steps_count"] == 3
|
||||
assert ctx.metadata["execution_count"] == 12
|
||||
|
||||
|
||||
class TestMemoryContextHelpers:
|
||||
"""Tests for MemoryContext helper methods."""
|
||||
|
||||
def test_is_working_memory(self) -> None:
|
||||
"""is_working_memory should return True for working memory."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="test",
|
||||
memory_subtype=MemorySubtype.WORKING,
|
||||
)
|
||||
assert ctx.is_working_memory() is True
|
||||
assert ctx.is_episodic_memory() is False
|
||||
|
||||
def test_is_episodic_memory(self) -> None:
|
||||
"""is_episodic_memory should return True for episodic memory."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="test",
|
||||
memory_subtype=MemorySubtype.EPISODIC,
|
||||
)
|
||||
assert ctx.is_episodic_memory() is True
|
||||
assert ctx.is_semantic_memory() is False
|
||||
|
||||
def test_is_semantic_memory(self) -> None:
|
||||
"""is_semantic_memory should return True for semantic memory."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="test",
|
||||
memory_subtype=MemorySubtype.SEMANTIC,
|
||||
)
|
||||
assert ctx.is_semantic_memory() is True
|
||||
assert ctx.is_procedural_memory() is False
|
||||
|
||||
def test_is_procedural_memory(self) -> None:
|
||||
"""is_procedural_memory should return True for procedural memory."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="test",
|
||||
memory_subtype=MemorySubtype.PROCEDURAL,
|
||||
)
|
||||
assert ctx.is_procedural_memory() is True
|
||||
assert ctx.is_working_memory() is False
|
||||
|
||||
def test_get_formatted_source(self) -> None:
|
||||
"""get_formatted_source should return formatted string."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="episodic:12345678-1234-1234-1234-123456789012",
|
||||
memory_subtype=MemorySubtype.EPISODIC,
|
||||
memory_id="12345678-1234-1234-1234-123456789012",
|
||||
)
|
||||
|
||||
formatted = ctx.get_formatted_source()
|
||||
|
||||
assert "[episodic]" in formatted
|
||||
assert "12345678..." in formatted
|
||||
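The factory tests above fix how a Fact row is mapped into context fields. A hedged sketch of that mapping as a standalone helper (the real classmethod is MemoryContext.from_semantic_memory in app.services.context.types.memory; query is accepted only to mirror the tested call signature):

from app.services.context.types.memory import MemoryContext, MemorySubtype


def fact_to_context(fact, query: str | None = None) -> MemoryContext:
    # Render the subject-predicate-object triple as readable content and
    # reuse the fact's confidence as the relevance score, as the tests expect.
    return MemoryContext(
        content=f"{fact.subject} {fact.predicate} {fact.object}",
        source=f"semantic:{fact.id}",
        memory_subtype=MemorySubtype.SEMANTIC,
        memory_id=str(fact.id),
        subject=fact.subject,
        predicate=fact.predicate,
        object_value=fact.object,
        relevance_score=fact.confidence,
    )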
1
backend/tests/unit/services/memory/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for the Agent Memory System."""
|
||||
2
backend/tests/unit/services/memory/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/cache/__init__.py
|
||||
"""Tests for memory caching layer."""
|
||||
331
backend/tests/unit/services/memory/cache/test_cache_manager.py
vendored
Normal file
@@ -0,0 +1,331 @@
|
||||
# tests/unit/services/memory/cache/test_cache_manager.py
|
||||
"""Tests for CacheManager."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.cache.cache_manager import (
|
||||
CacheManager,
|
||||
CacheStats,
|
||||
get_cache_manager,
|
||||
reset_cache_manager,
|
||||
)
|
||||
from app.services.memory.cache.embedding_cache import EmbeddingCache
|
||||
from app.services.memory.cache.hot_cache import HotMemoryCache
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton() -> None:
|
||||
"""Reset singleton before each test."""
|
||||
reset_cache_manager()
|
||||
|
||||
|
||||
class TestCacheStats:
|
||||
"""Tests for CacheStats."""
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Should convert to dictionary."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
stats = CacheStats(
|
||||
hot_cache={"hits": 10},
|
||||
embedding_cache={"hits": 20},
|
||||
overall_hit_rate=0.75,
|
||||
last_cleanup=datetime.now(UTC),
|
||||
cleanup_count=5,
|
||||
)
|
||||
|
||||
result = stats.to_dict()
|
||||
|
||||
assert result["hot_cache"] == {"hits": 10}
|
||||
assert result["overall_hit_rate"] == 0.75
|
||||
assert result["cleanup_count"] == 5
|
||||
assert result["last_cleanup"] is not None
|
||||
|
||||
|
||||
class TestCacheManager:
|
||||
"""Tests for CacheManager."""
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self) -> CacheManager:
|
||||
"""Create a cache manager."""
|
||||
return CacheManager()
|
||||
|
||||
def test_is_enabled(self, manager: CacheManager) -> None:
|
||||
"""Should check if caching is enabled."""
|
||||
# Default is enabled from settings
|
||||
assert manager.is_enabled is True
|
||||
|
||||
def test_has_hot_cache(self, manager: CacheManager) -> None:
|
||||
"""Should have hot memory cache."""
|
||||
assert manager.hot_cache is not None
|
||||
assert isinstance(manager.hot_cache, HotMemoryCache)
|
||||
|
||||
def test_has_embedding_cache(self, manager: CacheManager) -> None:
|
||||
"""Should have embedding cache."""
|
||||
assert manager.embedding_cache is not None
|
||||
assert isinstance(manager.embedding_cache, EmbeddingCache)
|
||||
|
||||
def test_cache_memory(self, manager: CacheManager) -> None:
|
||||
"""Should cache memory in hot cache."""
|
||||
memory_id = uuid4()
|
||||
memory = {"task": "test", "data": "value"}
|
||||
|
||||
manager.cache_memory("episodic", memory_id, memory)
|
||||
result = manager.get_memory("episodic", memory_id)
|
||||
|
||||
assert result == memory
|
||||
|
||||
def test_cache_memory_with_scope(self, manager: CacheManager) -> None:
|
||||
"""Should cache memory with scope."""
|
||||
memory_id = uuid4()
|
||||
memory = {"task": "test"}
|
||||
|
||||
manager.cache_memory("semantic", memory_id, memory, scope="proj-123")
|
||||
result = manager.get_memory("semantic", memory_id, scope="proj-123")
|
||||
|
||||
assert result == memory
|
||||
|
||||
async def test_cache_embedding(self, manager: CacheManager) -> None:
|
||||
"""Should cache embedding."""
|
||||
content = "test content"
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
content_hash = await manager.cache_embedding(content, embedding)
|
||||
result = await manager.get_embedding(content)
|
||||
|
||||
assert result == embedding
|
||||
assert len(content_hash) == 32
|
||||
|
||||
async def test_invalidate_memory(self, manager: CacheManager) -> None:
|
||||
"""Should invalidate memory from hot cache."""
|
||||
memory_id = uuid4()
|
||||
manager.cache_memory("episodic", memory_id, {"data": "test"})
|
||||
|
||||
count = await manager.invalidate_memory("episodic", memory_id)
|
||||
|
||||
assert count >= 1
|
||||
assert manager.get_memory("episodic", memory_id) is None
|
||||
|
||||
async def test_invalidate_by_type(self, manager: CacheManager) -> None:
|
||||
"""Should invalidate all entries of a type."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "1"})
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "2"})
|
||||
manager.cache_memory("semantic", uuid4(), {"data": "3"})
|
||||
|
||||
count = await manager.invalidate_by_type("episodic")
|
||||
|
||||
assert count >= 2
|
||||
|
||||
async def test_invalidate_by_scope(self, manager: CacheManager) -> None:
|
||||
"""Should invalidate all entries in a scope."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "1"}, scope="proj-1")
|
||||
manager.cache_memory("semantic", uuid4(), {"data": "2"}, scope="proj-1")
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "3"}, scope="proj-2")
|
||||
|
||||
count = await manager.invalidate_by_scope("proj-1")
|
||||
|
||||
assert count >= 2
|
||||
|
||||
async def test_invalidate_embedding(self, manager: CacheManager) -> None:
|
||||
"""Should invalidate cached embedding."""
|
||||
content = "test content"
|
||||
await manager.cache_embedding(content, [0.1, 0.2])
|
||||
|
||||
result = await manager.invalidate_embedding(content)
|
||||
|
||||
assert result is True
|
||||
assert await manager.get_embedding(content) is None
|
||||
|
||||
async def test_clear_all(self, manager: CacheManager) -> None:
|
||||
"""Should clear all caches."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "test"})
|
||||
await manager.cache_embedding("content", [0.1])
|
||||
|
||||
count = await manager.clear_all()
|
||||
|
||||
assert count >= 2
|
||||
|
||||
async def test_cleanup_expired(self, manager: CacheManager) -> None:
|
||||
"""Should clean up expired entries."""
|
||||
count = await manager.cleanup_expired()
|
||||
|
||||
# May be 0 if no expired entries
|
||||
assert count >= 0
|
||||
assert manager._cleanup_count == 1
|
||||
assert manager._last_cleanup is not None
|
||||
|
||||
def test_get_stats(self, manager: CacheManager) -> None:
|
||||
"""Should return aggregated statistics."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "test"})
|
||||
|
||||
stats = manager.get_stats()
|
||||
|
||||
assert "hot_cache" in stats.to_dict()
|
||||
assert "embedding_cache" in stats.to_dict()
|
||||
assert "overall_hit_rate" in stats.to_dict()
|
||||
|
||||
def test_get_hot_memories(self, manager: CacheManager) -> None:
|
||||
"""Should return most accessed memories."""
|
||||
id1 = uuid4()
|
||||
id2 = uuid4()
|
||||
|
||||
manager.cache_memory("episodic", id1, {"data": "1"})
|
||||
manager.cache_memory("episodic", id2, {"data": "2"})
|
||||
|
||||
# Access first multiple times
|
||||
for _ in range(5):
|
||||
manager.get_memory("episodic", id1)
|
||||
|
||||
hot = manager.get_hot_memories(limit=2)
|
||||
|
||||
assert len(hot) == 2
|
||||
|
||||
def test_reset_stats(self, manager: CacheManager) -> None:
|
||||
"""Should reset all statistics."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "test"})
|
||||
manager.get_memory("episodic", uuid4()) # Miss
|
||||
|
||||
manager.reset_stats()
|
||||
|
||||
stats = manager.get_stats()
|
||||
assert stats.hot_cache.get("hits", 0) == 0
|
||||
|
||||
async def test_warmup(self, manager: CacheManager) -> None:
|
||||
"""Should warm up cache with memories."""
|
||||
memories = [
|
||||
("episodic", uuid4(), {"data": "1"}),
|
||||
("episodic", uuid4(), {"data": "2"}),
|
||||
("semantic", uuid4(), {"data": "3"}),
|
||||
]
|
||||
|
||||
count = await manager.warmup(memories)
|
||||
|
||||
assert count == 3
|
||||
|
||||
|
||||
class TestCacheManagerWithRetrieval:
|
||||
"""Tests for CacheManager with retrieval cache."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_retrieval_cache(self) -> MagicMock:
|
||||
"""Create mock retrieval cache."""
|
||||
cache = MagicMock()
|
||||
cache.invalidate_by_memory = MagicMock(return_value=1)
|
||||
cache.clear = MagicMock(return_value=5)
|
||||
cache.get_stats = MagicMock(return_value={"entries": 10})
|
||||
return cache
|
||||
|
||||
@pytest.fixture
|
||||
def manager_with_retrieval(
|
||||
self,
|
||||
mock_retrieval_cache: MagicMock,
|
||||
) -> CacheManager:
|
||||
"""Create manager with retrieval cache."""
|
||||
manager = CacheManager()
|
||||
manager.set_retrieval_cache(mock_retrieval_cache)
|
||||
return manager
|
||||
|
||||
async def test_invalidate_clears_retrieval(
|
||||
self,
|
||||
manager_with_retrieval: CacheManager,
|
||||
mock_retrieval_cache: MagicMock,
|
||||
) -> None:
|
||||
"""Should invalidate retrieval cache entries."""
|
||||
memory_id = uuid4()
|
||||
|
||||
await manager_with_retrieval.invalidate_memory("episodic", memory_id)
|
||||
|
||||
mock_retrieval_cache.invalidate_by_memory.assert_called_once_with(memory_id)
|
||||
|
||||
def test_stats_includes_retrieval(
|
||||
self,
|
||||
manager_with_retrieval: CacheManager,
|
||||
) -> None:
|
||||
"""Should include retrieval cache stats."""
|
||||
stats = manager_with_retrieval.get_stats()
|
||||
|
||||
assert "retrieval_cache" in stats.to_dict()
|
||||
|
||||
|
||||
class TestCacheManagerDisabled:
|
||||
"""Tests for CacheManager when disabled."""
|
||||
|
||||
@pytest.fixture
|
||||
def disabled_manager(self) -> CacheManager:
|
||||
"""Create a disabled cache manager."""
|
||||
with patch(
|
||||
"app.services.memory.cache.cache_manager.get_memory_settings"
|
||||
) as mock_settings:
|
||||
settings = MagicMock()
|
||||
settings.cache_enabled = False
|
||||
settings.cache_max_items = 1000
|
||||
settings.cache_ttl_seconds = 300
|
||||
mock_settings.return_value = settings
|
||||
|
||||
return CacheManager()
|
||||
|
||||
def test_get_memory_returns_none(self, disabled_manager: CacheManager) -> None:
|
||||
"""Should return None when disabled."""
|
||||
disabled_manager.cache_memory("episodic", uuid4(), {"data": "test"})
|
||||
result = disabled_manager.get_memory("episodic", uuid4())
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_embedding_returns_none(
|
||||
self,
|
||||
disabled_manager: CacheManager,
|
||||
) -> None:
|
||||
"""Should return None for embeddings when disabled."""
|
||||
result = await disabled_manager.get_embedding("content")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_warmup_returns_zero(self, disabled_manager: CacheManager) -> None:
|
||||
"""Should return 0 from warmup when disabled."""
|
||||
count = await disabled_manager.warmup([("episodic", uuid4(), {})])
|
||||
|
||||
assert count == 0
|
||||
|
||||
|
||||
class TestGetCacheManager:
|
||||
"""Tests for get_cache_manager factory."""
|
||||
|
||||
def test_returns_singleton(self) -> None:
|
||||
"""Should return same instance."""
|
||||
manager1 = get_cache_manager()
|
||||
manager2 = get_cache_manager()
|
||||
|
||||
assert manager1 is manager2
|
||||
|
||||
def test_reset_creates_new(self) -> None:
|
||||
"""Should create new instance after reset."""
|
||||
manager1 = get_cache_manager()
|
||||
reset_cache_manager()
|
||||
manager2 = get_cache_manager()
|
||||
|
||||
assert manager1 is not manager2
|
||||
|
||||
def test_reset_parameter(self) -> None:
|
||||
"""Should create new instance with reset=True."""
|
||||
manager1 = get_cache_manager()
|
||||
manager2 = get_cache_manager(reset=True)
|
||||
|
||||
assert manager1 is not manager2
|
||||
|
||||
|
||||
class TestResetCacheManager:
|
||||
"""Tests for reset_cache_manager."""
|
||||
|
||||
def test_resets_singleton(self) -> None:
|
||||
"""Should reset the singleton."""
|
||||
get_cache_manager()
|
||||
reset_cache_manager()
|
||||
|
||||
# Next call should create new instance
|
||||
manager = get_cache_manager()
|
||||
assert manager is not None
|
||||
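The factory tests above describe a module-level singleton with an optional reset. A minimal sketch of that pattern, covering only the behaviour these tests exercise (the real module is app.services.memory.cache.cache_manager):

from app.services.memory.cache.cache_manager import CacheManager

_manager: CacheManager | None = None


def get_cache_manager(reset: bool = False) -> CacheManager:
    # Return the shared instance, building a new one on first use or when reset=True.
    global _manager
    if reset or _manager is None:
        _manager = CacheManager()
    return _manager


def reset_cache_manager() -> None:
    # Drop the shared instance so the next get_cache_manager() builds a fresh one.
    global _manager
    _manager = None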
391
backend/tests/unit/services/memory/cache/test_embedding_cache.py
vendored
Normal file
@@ -0,0 +1,391 @@
|
||||
# tests/unit/services/memory/cache/test_embedding_cache.py
|
||||
"""Tests for EmbeddingCache."""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.cache.embedding_cache import (
|
||||
CachedEmbeddingGenerator,
|
||||
EmbeddingCache,
|
||||
EmbeddingCacheStats,
|
||||
EmbeddingEntry,
|
||||
create_embedding_cache,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
class TestEmbeddingEntry:
|
||||
"""Tests for EmbeddingEntry."""
|
||||
|
||||
def test_creates_entry(self) -> None:
|
||||
"""Should create entry with embedding."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
entry = EmbeddingEntry(
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
content_hash="abc123",
|
||||
model="text-embedding-3-small",
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
assert entry.embedding == [0.1, 0.2, 0.3]
|
||||
assert entry.content_hash == "abc123"
|
||||
assert entry.ttl_seconds == 3600.0
|
||||
|
||||
def test_is_expired(self) -> None:
|
||||
"""Should detect expired entries."""
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
old_time = datetime.now(UTC) - timedelta(seconds=4000)
|
||||
entry = EmbeddingEntry(
|
||||
embedding=[0.1],
|
||||
content_hash="abc",
|
||||
model="default",
|
||||
created_at=old_time,
|
||||
ttl_seconds=3600.0,
|
||||
)
|
||||
|
||||
assert entry.is_expired() is True
|
||||
|
||||
def test_not_expired(self) -> None:
|
||||
"""Should detect non-expired entries."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
entry = EmbeddingEntry(
|
||||
embedding=[0.1],
|
||||
content_hash="abc",
|
||||
model="default",
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
assert entry.is_expired() is False
|
||||
|
||||
|
||||
class TestEmbeddingCacheStats:
|
||||
"""Tests for EmbeddingCacheStats."""
|
||||
|
||||
def test_hit_rate_calculation(self) -> None:
|
||||
"""Should calculate hit rate correctly."""
|
||||
stats = EmbeddingCacheStats(hits=90, misses=10)
|
||||
|
||||
assert stats.hit_rate == 0.9
|
||||
|
||||
def test_hit_rate_zero_requests(self) -> None:
|
||||
"""Should return 0 for no requests."""
|
||||
stats = EmbeddingCacheStats()
|
||||
|
||||
assert stats.hit_rate == 0.0
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Should convert to dictionary."""
|
||||
stats = EmbeddingCacheStats(hits=10, misses=5, bytes_saved=1000)
|
||||
|
||||
result = stats.to_dict()
|
||||
|
||||
assert result["hits"] == 10
|
||||
assert result["bytes_saved"] == 1000
|
||||
|
||||
|
||||
class TestEmbeddingCache:
|
||||
"""Tests for EmbeddingCache."""
|
||||
|
||||
@pytest.fixture
|
||||
def cache(self) -> EmbeddingCache:
|
||||
"""Create an embedding cache."""
|
||||
return EmbeddingCache(max_size=100, default_ttl_seconds=300.0)
|
||||
|
||||
async def test_put_and_get(self, cache: EmbeddingCache) -> None:
|
||||
"""Should store and retrieve embeddings."""
|
||||
content = "Hello world"
|
||||
embedding = [0.1, 0.2, 0.3, 0.4]
|
||||
|
||||
content_hash = await cache.put(content, embedding)
|
||||
result = await cache.get(content)
|
||||
|
||||
assert result == embedding
|
||||
assert len(content_hash) == 32
|
||||
|
||||
async def test_get_missing(self, cache: EmbeddingCache) -> None:
|
||||
"""Should return None for missing content."""
|
||||
result = await cache.get("nonexistent content")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_by_hash(self, cache: EmbeddingCache) -> None:
|
||||
"""Should get by content hash."""
|
||||
content = "Test content"
|
||||
embedding = [0.1, 0.2]
|
||||
|
||||
content_hash = await cache.put(content, embedding)
|
||||
result = await cache.get_by_hash(content_hash)
|
||||
|
||||
assert result == embedding
|
||||
|
||||
async def test_model_separation(self, cache: EmbeddingCache) -> None:
|
||||
"""Should separate embeddings by model."""
|
||||
content = "Same content"
|
||||
emb1 = [0.1, 0.2]
|
||||
emb2 = [0.3, 0.4]
|
||||
|
||||
await cache.put(content, emb1, model="model-a")
|
||||
await cache.put(content, emb2, model="model-b")
|
||||
|
||||
result1 = await cache.get(content, model="model-a")
|
||||
result2 = await cache.get(content, model="model-b")
|
||||
|
||||
assert result1 == emb1
|
||||
assert result2 == emb2
|
||||
|
||||
async def test_lru_eviction(self) -> None:
|
||||
"""Should evict LRU entries when at capacity."""
|
||||
cache = EmbeddingCache(max_size=3)
|
||||
|
||||
await cache.put("content1", [0.1])
|
||||
await cache.put("content2", [0.2])
|
||||
await cache.put("content3", [0.3])
|
||||
|
||||
# Access first to make it recent
|
||||
await cache.get("content1")
|
||||
|
||||
# Add fourth, should evict second (LRU)
|
||||
await cache.put("content4", [0.4])
|
||||
|
||||
assert await cache.get("content1") is not None
|
||||
assert await cache.get("content2") is None # Evicted
|
||||
assert await cache.get("content3") is not None
|
||||
assert await cache.get("content4") is not None
|
||||
|
||||
async def test_ttl_expiration(self) -> None:
|
||||
"""Should expire entries after TTL."""
|
||||
cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.05)
|
||||
|
||||
await cache.put("content", [0.1, 0.2])
|
||||
|
||||
time.sleep(0.06)
|
||||
|
||||
result = await cache.get("content")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_put_batch(self, cache: EmbeddingCache) -> None:
|
||||
"""Should cache multiple embeddings."""
|
||||
items = [
|
||||
("content1", [0.1]),
|
||||
("content2", [0.2]),
|
||||
("content3", [0.3]),
|
||||
]
|
||||
|
||||
hashes = await cache.put_batch(items)
|
||||
|
||||
assert len(hashes) == 3
|
||||
assert await cache.get("content1") == [0.1]
|
||||
assert await cache.get("content2") == [0.2]
|
||||
|
||||
async def test_invalidate(self, cache: EmbeddingCache) -> None:
|
||||
"""Should invalidate cached embedding."""
|
||||
await cache.put("content", [0.1, 0.2])
|
||||
|
||||
result = await cache.invalidate("content")
|
||||
|
||||
assert result is True
|
||||
assert await cache.get("content") is None
|
||||
|
||||
async def test_invalidate_by_hash(self, cache: EmbeddingCache) -> None:
|
||||
"""Should invalidate by hash."""
|
||||
content_hash = await cache.put("content", [0.1, 0.2])
|
||||
|
||||
result = await cache.invalidate_by_hash(content_hash)
|
||||
|
||||
assert result is True
|
||||
assert await cache.get("content") is None
|
||||
|
||||
async def test_invalidate_by_model(self, cache: EmbeddingCache) -> None:
|
||||
"""Should invalidate all embeddings for a model."""
|
||||
await cache.put("content1", [0.1], model="model-a")
|
||||
await cache.put("content2", [0.2], model="model-a")
|
||||
await cache.put("content3", [0.3], model="model-b")
|
||||
|
||||
count = await cache.invalidate_by_model("model-a")
|
||||
|
||||
assert count == 2
|
||||
assert await cache.get("content1", model="model-a") is None
|
||||
assert await cache.get("content3", model="model-b") is not None
|
||||
|
||||
async def test_clear(self, cache: EmbeddingCache) -> None:
|
||||
"""Should clear all entries."""
|
||||
await cache.put("content1", [0.1])
|
||||
await cache.put("content2", [0.2])
|
||||
|
||||
count = await cache.clear()
|
||||
|
||||
assert count == 2
|
||||
assert cache.size == 0
|
||||
|
||||
def test_cleanup_expired(self) -> None:
|
||||
"""Should remove expired entries."""
|
||||
cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.05)
|
||||
|
||||
# Use synchronous put for setup
|
||||
cache._put_memory("hash1", "default", [0.1])
|
||||
cache._put_memory("hash2", "default", [0.2], ttl_seconds=10)
|
||||
|
||||
time.sleep(0.06)
|
||||
|
||||
count = cache.cleanup_expired()
|
||||
|
||||
assert count == 1
|
||||
|
||||
def test_get_stats(self, cache: EmbeddingCache) -> None:
|
||||
"""Should return accurate statistics."""
|
||||
# Put synchronously for setup
|
||||
cache._put_memory("hash1", "default", [0.1])
|
||||
|
||||
stats = cache.get_stats()
|
||||
|
||||
assert stats.current_size == 1
|
||||
|
||||
def test_hash_content(self) -> None:
|
||||
"""Should produce consistent hashes."""
|
||||
hash1 = EmbeddingCache.hash_content("test content")
|
||||
hash2 = EmbeddingCache.hash_content("test content")
|
||||
hash3 = EmbeddingCache.hash_content("different content")
|
||||
|
||||
assert hash1 == hash2
|
||||
assert hash1 != hash3
|
||||
assert len(hash1) == 32
|
||||
|
||||
|
||||
class TestEmbeddingCacheWithRedis:
|
||||
"""Tests for EmbeddingCache with Redis."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self) -> MagicMock:
|
||||
"""Create mock Redis."""
|
||||
redis = MagicMock()
|
||||
redis.get = AsyncMock(return_value=None)
|
||||
redis.setex = AsyncMock()
|
||||
redis.delete = AsyncMock()
|
||||
redis.scan_iter = MagicMock(return_value=iter([]))
|
||||
return redis
|
||||
|
||||
@pytest.fixture
|
||||
def cache_with_redis(self, mock_redis: MagicMock) -> EmbeddingCache:
|
||||
"""Create cache with mock Redis."""
|
||||
return EmbeddingCache(
|
||||
max_size=100,
|
||||
default_ttl_seconds=300.0,
|
||||
redis=mock_redis,
|
||||
)
|
||||
|
||||
async def test_put_stores_in_redis(
|
||||
self,
|
||||
cache_with_redis: EmbeddingCache,
|
||||
mock_redis: MagicMock,
|
||||
) -> None:
|
||||
"""Should store in Redis when available."""
|
||||
await cache_with_redis.put("content", [0.1, 0.2])
|
||||
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
async def test_get_checks_redis_on_miss(
|
||||
self,
|
||||
cache_with_redis: EmbeddingCache,
|
||||
mock_redis: MagicMock,
|
||||
) -> None:
|
||||
"""Should check Redis when memory cache misses."""
|
||||
import json
|
||||
|
||||
mock_redis.get.return_value = json.dumps([0.1, 0.2])
|
||||
|
||||
result = await cache_with_redis.get("content")
|
||||
|
||||
assert result == [0.1, 0.2]
|
||||
mock_redis.get.assert_called_once()
|
||||
|
||||
|
||||
class TestCachedEmbeddingGenerator:
|
||||
"""Tests for CachedEmbeddingGenerator."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_generator(self) -> MagicMock:
|
||||
"""Create mock embedding generator."""
|
||||
gen = MagicMock()
|
||||
gen.generate = AsyncMock(return_value=[0.1, 0.2, 0.3])
|
||||
gen.generate_batch = AsyncMock(return_value=[[0.1], [0.2], [0.3]])
|
||||
return gen
|
||||
|
||||
@pytest.fixture
|
||||
def cache(self) -> EmbeddingCache:
|
||||
"""Create embedding cache."""
|
||||
return EmbeddingCache(max_size=100)
|
||||
|
||||
@pytest.fixture
|
||||
def cached_gen(
|
||||
self,
|
||||
mock_generator: MagicMock,
|
||||
cache: EmbeddingCache,
|
||||
) -> CachedEmbeddingGenerator:
|
||||
"""Create cached generator."""
|
||||
return CachedEmbeddingGenerator(mock_generator, cache)
|
||||
|
||||
async def test_generate_caches_result(
|
||||
self,
|
||||
cached_gen: CachedEmbeddingGenerator,
|
||||
mock_generator: MagicMock,
|
||||
) -> None:
|
||||
"""Should cache generated embedding."""
|
||||
result1 = await cached_gen.generate("test text")
|
||||
result2 = await cached_gen.generate("test text")
|
||||
|
||||
assert result1 == [0.1, 0.2, 0.3]
|
||||
assert result2 == [0.1, 0.2, 0.3]
|
||||
mock_generator.generate.assert_called_once() # Only called once
|
||||
|
||||
async def test_generate_batch_uses_cache(
|
||||
self,
|
||||
cached_gen: CachedEmbeddingGenerator,
|
||||
mock_generator: MagicMock,
|
||||
cache: EmbeddingCache,
|
||||
) -> None:
|
||||
"""Should use cache for batch generation."""
|
||||
# Pre-cache one embedding
|
||||
await cache.put("text1", [0.5])
|
||||
|
||||
# Mock returns 2 embeddings for the 2 uncached texts
|
||||
mock_generator.generate_batch = AsyncMock(return_value=[[0.2], [0.3]])
|
||||
|
||||
results = await cached_gen.generate_batch(["text1", "text2", "text3"])
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[0] == [0.5] # From cache
|
||||
assert results[1] == [0.2] # Generated
|
||||
assert results[2] == [0.3] # Generated
|
||||
|
||||
async def test_get_stats(self, cached_gen: CachedEmbeddingGenerator) -> None:
|
||||
"""Should return generator statistics."""
|
||||
await cached_gen.generate("text1")
|
||||
await cached_gen.generate("text1") # Cache hit
|
||||
|
||||
stats = cached_gen.get_stats()
|
||||
|
||||
assert stats["call_count"] == 2
|
||||
assert stats["cache_hit_count"] == 1
|
||||
|
||||
|
||||
class TestCreateEmbeddingCache:
|
||||
"""Tests for factory function."""
|
||||
|
||||
def test_creates_cache(self) -> None:
|
||||
"""Should create cache with defaults."""
|
||||
cache = create_embedding_cache()
|
||||
|
||||
assert cache.max_size == 50000
|
||||
|
||||
def test_creates_cache_with_options(self) -> None:
|
||||
"""Should create cache with custom options."""
|
||||
cache = create_embedding_cache(max_size=1000, default_ttl_seconds=600.0)
|
||||
|
||||
assert cache.max_size == 1000
|
||||
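
The tests above pin down the behaviour of EmbeddingCache and CachedEmbeddingGenerator (LRU plus TTL, per-model keys, optional Redis tier). As a minimal usage sketch built only from the constructor and method signatures exercised in these tests, with a stand-in generator class (FakeGenerator is hypothetical, not part of the repository):

import asyncio

from app.services.memory.cache.embedding_cache import (
    CachedEmbeddingGenerator,
    create_embedding_cache,
)


class FakeGenerator:
    """Stand-in: any object with async generate()/generate_batch() should fit here."""

    async def generate(self, text: str) -> list[float]:
        return [float(len(text)), 0.0, 1.0]

    async def generate_batch(self, texts: list[str]) -> list[list[float]]:
        return [await self.generate(t) for t in texts]


async def main() -> None:
    cache = create_embedding_cache(max_size=1000, default_ttl_seconds=600.0)
    cached = CachedEmbeddingGenerator(FakeGenerator(), cache)

    vec1 = await cached.generate("hello world")
    vec2 = await cached.generate("hello world")  # second call is served from the cache
    assert vec1 == vec2
    print(cached.get_stats())  # per the tests above: call_count=2, cache_hit_count=1


asyncio.run(main())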
355 backend/tests/unit/services/memory/cache/test_hot_cache.py vendored Normal file
@@ -0,0 +1,355 @@
# tests/unit/services/memory/cache/test_hot_cache.py
"""Tests for HotMemoryCache."""

import time
from uuid import uuid4

import pytest

from app.services.memory.cache.hot_cache import (
    CacheEntry,
    CacheKey,
    HotCacheStats,
    HotMemoryCache,
    create_hot_cache,
)


class TestCacheKey:
    """Tests for CacheKey."""

    def test_creates_key(self) -> None:
        """Should create key with required fields."""
        key = CacheKey(memory_type="episodic", memory_id="123")

        assert key.memory_type == "episodic"
        assert key.memory_id == "123"
        assert key.scope is None

    def test_creates_key_with_scope(self) -> None:
        """Should create key with scope."""
        key = CacheKey(memory_type="semantic", memory_id="456", scope="proj-123")

        assert key.scope == "proj-123"

    def test_hash_and_equality(self) -> None:
        """Keys with same values should be equal and have same hash."""
        key1 = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")
        key2 = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")

        assert key1 == key2
        assert hash(key1) == hash(key2)

    def test_str_representation(self) -> None:
        """Should produce readable string."""
        key = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")

        assert str(key) == "episodic:proj-1:123"

    def test_str_without_scope(self) -> None:
        """Should produce string without scope."""
        key = CacheKey(memory_type="episodic", memory_id="123")

        assert str(key) == "episodic:123"


class TestCacheEntry:
    """Tests for CacheEntry."""

    def test_creates_entry(self) -> None:
        """Should create entry with value."""
        from datetime import UTC, datetime

        now = datetime.now(UTC)
        entry = CacheEntry(
            value={"data": "test"},
            created_at=now,
            last_accessed_at=now,
        )

        assert entry.value == {"data": "test"}
        assert entry.access_count == 1
        assert entry.ttl_seconds == 300.0

    def test_is_expired(self) -> None:
        """Should detect expired entries."""
        from datetime import UTC, datetime, timedelta

        old_time = datetime.now(UTC) - timedelta(seconds=400)
        entry = CacheEntry(
            value="test",
            created_at=old_time,
            last_accessed_at=old_time,
            ttl_seconds=300.0,
        )

        assert entry.is_expired() is True

    def test_not_expired(self) -> None:
        """Should detect non-expired entries."""
        from datetime import UTC, datetime

        entry = CacheEntry(
            value="test",
            created_at=datetime.now(UTC),
            last_accessed_at=datetime.now(UTC),
            ttl_seconds=300.0,
        )

        assert entry.is_expired() is False

    def test_touch_updates_access(self) -> None:
        """Touch should update access time and count."""
        from datetime import UTC, datetime, timedelta

        old_time = datetime.now(UTC) - timedelta(seconds=10)
        entry = CacheEntry(
            value="test",
            created_at=old_time,
            last_accessed_at=old_time,
            access_count=5,
        )

        entry.touch()

        assert entry.access_count == 6
        assert entry.last_accessed_at > old_time


class TestHotCacheStats:
    """Tests for HotCacheStats."""

    def test_hit_rate_calculation(self) -> None:
        """Should calculate hit rate correctly."""
        stats = HotCacheStats(hits=80, misses=20)

        assert stats.hit_rate == 0.8

    def test_hit_rate_zero_requests(self) -> None:
        """Should return 0 for no requests."""
        stats = HotCacheStats()

        assert stats.hit_rate == 0.0

    def test_to_dict(self) -> None:
        """Should convert to dictionary."""
        stats = HotCacheStats(hits=10, misses=5, evictions=2)

        result = stats.to_dict()

        assert result["hits"] == 10
        assert result["misses"] == 5
        assert result["evictions"] == 2
        assert "hit_rate" in result


class TestHotMemoryCache:
    """Tests for HotMemoryCache."""

    @pytest.fixture
    def cache(self) -> HotMemoryCache[dict]:
        """Create a hot memory cache."""
        return HotMemoryCache[dict](max_size=100, default_ttl_seconds=300.0)

    def test_put_and_get(self, cache: HotMemoryCache[dict]) -> None:
        """Should store and retrieve values."""
        key = CacheKey(memory_type="episodic", memory_id="123")
        value = {"data": "test"}

        cache.put(key, value)
        result = cache.get(key)

        assert result == value

    def test_get_missing_key(self, cache: HotMemoryCache[dict]) -> None:
        """Should return None for missing keys."""
        key = CacheKey(memory_type="episodic", memory_id="nonexistent")

        result = cache.get(key)

        assert result is None

    def test_put_by_id(self, cache: HotMemoryCache[dict]) -> None:
        """Should store by type and ID."""
        memory_id = uuid4()
        value = {"data": "test"}

        cache.put_by_id("episodic", memory_id, value)
        result = cache.get_by_id("episodic", memory_id)

        assert result == value

    def test_put_by_id_with_scope(self, cache: HotMemoryCache[dict]) -> None:
        """Should store with scope."""
        memory_id = uuid4()
        value = {"data": "test"}

        cache.put_by_id("semantic", memory_id, value, scope="proj-123")
        result = cache.get_by_id("semantic", memory_id, scope="proj-123")

        assert result == value

    def test_lru_eviction(self) -> None:
        """Should evict LRU entries when at capacity."""
        cache = HotMemoryCache[str](max_size=3)

        # Fill cache
        cache.put_by_id("test", "1", "first")
        cache.put_by_id("test", "2", "second")
        cache.put_by_id("test", "3", "third")

        # Access first to make it recent
        cache.get_by_id("test", "1")

        # Add fourth, should evict second (LRU)
        cache.put_by_id("test", "4", "fourth")

        assert cache.get_by_id("test", "1") is not None  # Accessed, kept
        assert cache.get_by_id("test", "2") is None  # Evicted (LRU)
        assert cache.get_by_id("test", "3") is not None
        assert cache.get_by_id("test", "4") is not None

    def test_ttl_expiration(self) -> None:
        """Should expire entries after TTL."""
        cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.05)

        cache.put_by_id("test", "1", "value")

        # Wait for expiration
        time.sleep(0.06)

        result = cache.get_by_id("test", "1")

        assert result is None

    def test_invalidate(self, cache: HotMemoryCache[dict]) -> None:
        """Should invalidate specific entry."""
        key = CacheKey(memory_type="episodic", memory_id="123")
        cache.put(key, {"data": "test"})

        result = cache.invalidate(key)

        assert result is True
        assert cache.get(key) is None

    def test_invalidate_by_id(self, cache: HotMemoryCache[dict]) -> None:
        """Should invalidate by ID."""
        memory_id = uuid4()
        cache.put_by_id("episodic", memory_id, {"data": "test"})

        result = cache.invalidate_by_id("episodic", memory_id)

        assert result is True
        assert cache.get_by_id("episodic", memory_id) is None

    def test_invalidate_by_type(self, cache: HotMemoryCache[dict]) -> None:
        """Should invalidate all entries of a type."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.put_by_id("episodic", "2", {"data": "2"})
        cache.put_by_id("semantic", "3", {"data": "3"})

        count = cache.invalidate_by_type("episodic")

        assert count == 2
        assert cache.get_by_id("episodic", "1") is None
        assert cache.get_by_id("episodic", "2") is None
        assert cache.get_by_id("semantic", "3") is not None

    def test_invalidate_by_scope(self, cache: HotMemoryCache[dict]) -> None:
        """Should invalidate all entries in a scope."""
        cache.put_by_id("episodic", "1", {"data": "1"}, scope="proj-1")
        cache.put_by_id("semantic", "2", {"data": "2"}, scope="proj-1")
        cache.put_by_id("episodic", "3", {"data": "3"}, scope="proj-2")

        count = cache.invalidate_by_scope("proj-1")

        assert count == 2
        assert cache.get_by_id("episodic", "3", scope="proj-2") is not None

    def test_invalidate_pattern(self, cache: HotMemoryCache[dict]) -> None:
        """Should invalidate entries matching pattern."""
        cache.put_by_id("episodic", "123", {"data": "1"})
        cache.put_by_id("episodic", "124", {"data": "2"})
        cache.put_by_id("semantic", "125", {"data": "3"})

        count = cache.invalidate_pattern("episodic:*")

        assert count == 2

    def test_clear(self, cache: HotMemoryCache[dict]) -> None:
        """Should clear all entries."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.put_by_id("semantic", "2", {"data": "2"})

        count = cache.clear()

        assert count == 2
        assert cache.size == 0

    def test_cleanup_expired(self) -> None:
        """Should remove expired entries."""
        cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.05)

        cache.put_by_id("test", "1", "value1")
        cache.put_by_id("test", "2", "value2", ttl_seconds=10)

        time.sleep(0.06)

        count = cache.cleanup_expired()

        assert count == 1  # Only the first one expired
        assert cache.size == 1

    def test_get_hot_memories(self, cache: HotMemoryCache[dict]) -> None:
        """Should return most accessed memories."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.put_by_id("episodic", "2", {"data": "2"})

        # Access first one multiple times
        for _ in range(5):
            cache.get_by_id("episodic", "1")

        hot = cache.get_hot_memories(limit=2)

        assert len(hot) == 2
        assert hot[0][1] >= hot[1][1]  # Sorted by access count

    def test_get_stats(self, cache: HotMemoryCache[dict]) -> None:
        """Should return accurate statistics."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.get_by_id("episodic", "1")  # Hit
        cache.get_by_id("episodic", "2")  # Miss

        stats = cache.get_stats()

        assert stats.hits == 1
        assert stats.misses == 1
        assert stats.current_size == 1

    def test_reset_stats(self, cache: HotMemoryCache[dict]) -> None:
        """Should reset statistics."""
        cache.put_by_id("episodic", "1", {"data": "1"})
        cache.get_by_id("episodic", "1")

        cache.reset_stats()
        stats = cache.get_stats()

        assert stats.hits == 0
        assert stats.misses == 0


class TestCreateHotCache:
    """Tests for factory function."""

    def test_creates_cache(self) -> None:
        """Should create cache with defaults."""
        cache = create_hot_cache()

        assert cache.max_size == 10000

    def test_creates_cache_with_options(self) -> None:
        """Should create cache with custom options."""
        cache = create_hot_cache(max_size=500, default_ttl_seconds=60.0)

        assert cache.max_size == 500
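
These tests treat HotMemoryCache as a generic in-process LRU/TTL cache keyed internally by CacheKey (memory type, ID, optional scope). A small sketch using only the methods exercised above:

from uuid import uuid4

from app.services.memory.cache.hot_cache import create_hot_cache

cache = create_hot_cache(max_size=500, default_ttl_seconds=60.0)
memory_id = uuid4()

# Store and read back an episodic memory scoped to a project.
cache.put_by_id("episodic", memory_id, {"summary": "deployed v1.0"}, scope="proj-1")
print(cache.get_by_id("episodic", memory_id, scope="proj-1"))

# Drop everything belonging to that project in one call, then inspect hit/miss counters.
cache.invalidate_by_scope("proj-1")
print(cache.get_stats().to_dict())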
2 backend/tests/unit/services/memory/consolidation/__init__.py Normal file
@@ -0,0 +1,2 @@
# tests/unit/services/memory/consolidation/__init__.py
"""Tests for memory consolidation."""
736 backend/tests/unit/services/memory/consolidation/test_service.py Normal file
@@ -0,0 +1,736 @@
# tests/unit/services/memory/consolidation/test_service.py
"""Unit tests for memory consolidation service."""

from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import pytest

from app.services.memory.consolidation.service import (
    ConsolidationConfig,
    ConsolidationResult,
    MemoryConsolidationService,
    NightlyConsolidationResult,
    SessionConsolidationResult,
)
from app.services.memory.types import Episode, Outcome, TaskState


def _utcnow() -> datetime:
    """Get current UTC time."""
    return datetime.now(UTC)


def make_episode(
    outcome: Outcome = Outcome.SUCCESS,
    occurred_at: datetime | None = None,
    task_type: str = "test_task",
    lessons_learned: list[str] | None = None,
    importance_score: float = 0.5,
    actions: list[dict] | None = None,
) -> Episode:
    """Create a test episode."""
    return Episode(
        id=uuid4(),
        project_id=uuid4(),
        agent_instance_id=uuid4(),
        agent_type_id=uuid4(),
        session_id="test-session",
        task_type=task_type,
        task_description="Test task description",
        actions=actions or [{"action": "test"}],
        context_summary="Test context",
        outcome=outcome,
        outcome_details="Test outcome",
        duration_seconds=10.0,
        tokens_used=100,
        lessons_learned=lessons_learned or [],
        importance_score=importance_score,
        embedding=None,
        occurred_at=occurred_at or _utcnow(),
        created_at=_utcnow(),
        updated_at=_utcnow(),
    )


def make_task_state(
    current_step: int = 5,
    total_steps: int = 10,
    progress_percent: float = 50.0,
    status: str = "in_progress",
    description: str = "Test Task",
) -> TaskState:
    """Create a test task state."""
    now = _utcnow()
    return TaskState(
        task_id="test-task-id",
        task_type="test_task",
        description=description,
        current_step=current_step,
        total_steps=total_steps,
        status=status,
        progress_percent=progress_percent,
        started_at=now - timedelta(hours=1),
        updated_at=now,
    )


class TestConsolidationConfig:
    """Tests for ConsolidationConfig."""

    def test_default_values(self) -> None:
        """Test default configuration values."""
        config = ConsolidationConfig()

        assert config.min_steps_for_episode == 2
        assert config.min_duration_seconds == 5.0
        assert config.min_confidence_for_fact == 0.6
        assert config.max_facts_per_episode == 10
        assert config.min_episodes_for_procedure == 3
        assert config.max_episode_age_days == 90
        assert config.batch_size == 100

    def test_custom_values(self) -> None:
        """Test custom configuration values."""
        config = ConsolidationConfig(
            min_steps_for_episode=5,
            batch_size=50,
        )

        assert config.min_steps_for_episode == 5
        assert config.batch_size == 50


class TestConsolidationResult:
    """Tests for ConsolidationResult."""

    def test_creation(self) -> None:
        """Test creating a consolidation result."""
        result = ConsolidationResult(
            source_type="episodic",
            target_type="semantic",
            items_processed=10,
            items_created=5,
        )

        assert result.source_type == "episodic"
        assert result.target_type == "semantic"
        assert result.items_processed == 10
        assert result.items_created == 5
        assert result.items_skipped == 0
        assert result.errors == []

    def test_to_dict(self) -> None:
        """Test converting to dictionary."""
        result = ConsolidationResult(
            source_type="episodic",
            target_type="semantic",
            items_processed=10,
            items_created=5,
            errors=["test error"],
        )

        d = result.to_dict()

        assert d["source_type"] == "episodic"
        assert d["target_type"] == "semantic"
        assert d["items_processed"] == 10
        assert d["items_created"] == 5
        assert "test error" in d["errors"]


class TestSessionConsolidationResult:
    """Tests for SessionConsolidationResult."""

    def test_creation(self) -> None:
        """Test creating a session consolidation result."""
        result = SessionConsolidationResult(
            session_id="test-session",
            episode_created=True,
            episode_id=uuid4(),
            scratchpad_entries=5,
        )

        assert result.session_id == "test-session"
        assert result.episode_created is True
        assert result.episode_id is not None


class TestNightlyConsolidationResult:
    """Tests for NightlyConsolidationResult."""

    def test_creation(self) -> None:
        """Test creating a nightly consolidation result."""
        result = NightlyConsolidationResult(
            started_at=_utcnow(),
        )

        assert result.started_at is not None
        assert result.completed_at is None
        assert result.total_episodes_processed == 0

    def test_to_dict(self) -> None:
        """Test converting to dictionary."""
        result = NightlyConsolidationResult(
            started_at=_utcnow(),
            completed_at=_utcnow(),
            total_facts_created=5,
            total_procedures_created=2,
        )

        d = result.to_dict()

        assert "started_at" in d
        assert "completed_at" in d
        assert d["total_facts_created"] == 5
        assert d["total_procedures_created"] == 2


class TestMemoryConsolidationService:
    """Tests for MemoryConsolidationService."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Create a mock database session."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session: AsyncMock) -> MemoryConsolidationService:
        """Create a consolidation service with mocked dependencies."""
        return MemoryConsolidationService(
            session=mock_session,
            config=ConsolidationConfig(),
        )

    # =========================================================================
    # Session Consolidation Tests
    # =========================================================================

    @pytest.mark.asyncio
    async def test_consolidate_session_insufficient_steps(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test session not consolidated when insufficient steps."""
        mock_working_memory = AsyncMock()
        task_state = make_task_state(current_step=1)  # Less than min_steps_for_episode
        mock_working_memory.get_task_state.return_value = task_state

        result = await service.consolidate_session(
            working_memory=mock_working_memory,
            project_id=uuid4(),
            session_id="test-session",
        )

        assert result.episode_created is False
        assert result.episode_id is None

    @pytest.mark.asyncio
    async def test_consolidate_session_no_task_state(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test session not consolidated when no task state."""
        mock_working_memory = AsyncMock()
        mock_working_memory.get_task_state.return_value = None

        result = await service.consolidate_session(
            working_memory=mock_working_memory,
            project_id=uuid4(),
            session_id="test-session",
        )

        assert result.episode_created is False

    @pytest.mark.asyncio
    async def test_consolidate_session_success(
        self, service: MemoryConsolidationService, mock_session: AsyncMock
    ) -> None:
        """Test successful session consolidation."""
        mock_working_memory = AsyncMock()
        task_state = make_task_state(
            current_step=5,
            progress_percent=100.0,
            status="complete",
        )
        mock_working_memory.get_task_state.return_value = task_state
        mock_working_memory.get_scratchpad.return_value = ["step1", "step2"]
        mock_working_memory.get_all.return_value = {"key1": "value1"}

        # Mock episodic memory
        mock_episode = make_episode()
        with patch.object(
            service, "_get_episodic", new_callable=AsyncMock
        ) as mock_get_episodic:
            mock_episodic = AsyncMock()
            mock_episodic.record_episode.return_value = mock_episode
            mock_get_episodic.return_value = mock_episodic

            result = await service.consolidate_session(
                working_memory=mock_working_memory,
                project_id=uuid4(),
                session_id="test-session",
            )

            assert result.episode_created is True
            assert result.episode_id == mock_episode.id
            assert result.scratchpad_entries == 2

    # =========================================================================
    # Outcome Determination Tests
    # =========================================================================

    def test_determine_session_outcome_success(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test outcome determination for successful session."""
        task_state = make_task_state(status="complete", progress_percent=100.0)
        outcome = service._determine_session_outcome(task_state)
        assert outcome == Outcome.SUCCESS

    def test_determine_session_outcome_failure(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test outcome determination for failed session."""
        task_state = make_task_state(status="error", progress_percent=25.0)
        outcome = service._determine_session_outcome(task_state)
        assert outcome == Outcome.FAILURE

    def test_determine_session_outcome_partial(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test outcome determination for partial session."""
        task_state = make_task_state(status="stopped", progress_percent=60.0)
        outcome = service._determine_session_outcome(task_state)
        assert outcome == Outcome.PARTIAL

    def test_determine_session_outcome_none(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test outcome determination with no task state."""
        outcome = service._determine_session_outcome(None)
        assert outcome == Outcome.PARTIAL

    # =========================================================================
    # Action Building Tests
    # =========================================================================

    def test_build_actions_from_session(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test building actions from session data."""
        scratchpad = ["thought 1", "thought 2"]
        variables = {"var1": "value1"}
        task_state = make_task_state()

        actions = service._build_actions_from_session(scratchpad, variables, task_state)

        assert len(actions) == 3  # 2 scratchpad + 1 final state
        assert actions[0]["type"] == "reasoning"
        assert actions[2]["type"] == "final_state"

    def test_build_context_summary(self, service: MemoryConsolidationService) -> None:
        """Test building context summary."""
        task_state = make_task_state(
            description="Test Task",
            progress_percent=75.0,
        )
        variables = {"key": "value"}

        summary = service._build_context_summary(task_state, variables)

        assert "Test Task" in summary
        assert "75.0%" in summary

    # =========================================================================
    # Importance Calculation Tests
    # =========================================================================

    def test_calculate_session_importance_base(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test base importance calculation."""
        task_state = make_task_state(total_steps=3)  # Below threshold
        importance = service._calculate_session_importance(
            task_state, Outcome.SUCCESS, []
        )

        assert importance == 0.5  # Base score

    def test_calculate_session_importance_failure(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test importance boost for failures."""
        task_state = make_task_state(total_steps=3)  # Below threshold
        importance = service._calculate_session_importance(
            task_state, Outcome.FAILURE, []
        )

        assert importance == 0.8  # Base (0.5) + failure boost (0.3)

    def test_calculate_session_importance_complex(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test importance for complex session."""
        task_state = make_task_state(total_steps=10)
        actions = [{"step": i} for i in range(6)]
        importance = service._calculate_session_importance(
            task_state, Outcome.SUCCESS, actions
        )

        # Base (0.5) + many steps (0.1) + many actions (0.1)
        assert importance == 0.7

    # =========================================================================
    # Episode to Fact Consolidation Tests
    # =========================================================================

    @pytest.mark.asyncio
    async def test_consolidate_episodes_to_facts_empty(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test consolidation with no episodes."""
        with patch.object(
            service, "_get_episodic", new_callable=AsyncMock
        ) as mock_get_episodic:
            mock_episodic = AsyncMock()
            mock_episodic.get_recent.return_value = []
            mock_get_episodic.return_value = mock_episodic

            result = await service.consolidate_episodes_to_facts(
                project_id=uuid4(),
            )

            assert result.items_processed == 0
            assert result.items_created == 0

    @pytest.mark.asyncio
    async def test_consolidate_episodes_to_facts_success(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test successful fact extraction."""
        episode = make_episode(
            lessons_learned=["Always check return values"],
        )

        mock_fact = MagicMock()
        mock_fact.reinforcement_count = 1  # New fact

        with (
            patch.object(
                service, "_get_episodic", new_callable=AsyncMock
            ) as mock_get_episodic,
            patch.object(
                service, "_get_semantic", new_callable=AsyncMock
            ) as mock_get_semantic,
        ):
            mock_episodic = AsyncMock()
            mock_episodic.get_recent.return_value = [episode]
            mock_get_episodic.return_value = mock_episodic

            mock_semantic = AsyncMock()
            mock_semantic.store_fact.return_value = mock_fact
            mock_get_semantic.return_value = mock_semantic

            result = await service.consolidate_episodes_to_facts(
                project_id=uuid4(),
            )

            assert result.items_processed == 1
            # At least one fact should be created from lesson
            assert result.items_created >= 0

    # =========================================================================
    # Episode to Procedure Consolidation Tests
    # =========================================================================

    @pytest.mark.asyncio
    async def test_consolidate_episodes_to_procedures_insufficient(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test consolidation with insufficient episodes."""
        # Only 1 episode - less than min_episodes_for_procedure (3)
        episode = make_episode()

        with patch.object(
            service, "_get_episodic", new_callable=AsyncMock
        ) as mock_get_episodic:
            mock_episodic = AsyncMock()
            mock_episodic.get_by_outcome.return_value = [episode]
            mock_get_episodic.return_value = mock_episodic

            result = await service.consolidate_episodes_to_procedures(
                project_id=uuid4(),
            )

            assert result.items_processed == 1
            assert result.items_created == 0
            assert result.items_skipped == 1

    @pytest.mark.asyncio
    async def test_consolidate_episodes_to_procedures_success(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test successful procedure creation."""
        # Create enough episodes for a procedure
        episodes = [
            make_episode(
                task_type="deploy",
                actions=[{"type": "step1"}, {"type": "step2"}, {"type": "step3"}],
            )
            for _ in range(5)
        ]

        mock_procedure = MagicMock()

        with (
            patch.object(
                service, "_get_episodic", new_callable=AsyncMock
            ) as mock_get_episodic,
            patch.object(
                service, "_get_procedural", new_callable=AsyncMock
            ) as mock_get_procedural,
        ):
            mock_episodic = AsyncMock()
            mock_episodic.get_by_outcome.return_value = episodes
            mock_get_episodic.return_value = mock_episodic

            mock_procedural = AsyncMock()
            mock_procedural.find_matching.return_value = []  # No existing procedure
            mock_procedural.record_procedure.return_value = mock_procedure
            mock_get_procedural.return_value = mock_procedural

            result = await service.consolidate_episodes_to_procedures(
                project_id=uuid4(),
            )

            assert result.items_processed == 5
            assert result.items_created == 1

    # =========================================================================
    # Common Steps Extraction Tests
    # =========================================================================

    def test_extract_common_steps(self, service: MemoryConsolidationService) -> None:
        """Test extracting steps from episodes."""
        episodes = [
            make_episode(
                outcome=Outcome.SUCCESS,
                importance_score=0.8,
                actions=[
                    {"type": "step1", "content": "First step"},
                    {"type": "step2", "content": "Second step"},
                ],
            ),
            make_episode(
                outcome=Outcome.SUCCESS,
                importance_score=0.5,
                actions=[{"type": "simple"}],
            ),
        ]

        steps = service._extract_common_steps(episodes)

        assert len(steps) == 2
        assert steps[0]["order"] == 1
        assert steps[0]["action"] == "step1"

    # =========================================================================
    # Pruning Tests
    # =========================================================================

    def test_should_prune_episode_old_low_importance(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test pruning old, low-importance episode."""
        old_date = _utcnow() - timedelta(days=100)
        episode = make_episode(
            occurred_at=old_date,
            importance_score=0.1,
            outcome=Outcome.SUCCESS,
        )
        cutoff = _utcnow() - timedelta(days=90)

        should_prune = service._should_prune_episode(episode, cutoff, 0.2)

        assert should_prune is True

    def test_should_prune_episode_recent(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test not pruning recent episode."""
        recent_date = _utcnow() - timedelta(days=30)
        episode = make_episode(
            occurred_at=recent_date,
            importance_score=0.1,
        )
        cutoff = _utcnow() - timedelta(days=90)

        should_prune = service._should_prune_episode(episode, cutoff, 0.2)

        assert should_prune is False

    def test_should_prune_episode_failure_protected(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test not pruning failure (with keep_all_failures=True)."""
        old_date = _utcnow() - timedelta(days=100)
        episode = make_episode(
            occurred_at=old_date,
            importance_score=0.1,
            outcome=Outcome.FAILURE,
        )
        cutoff = _utcnow() - timedelta(days=90)

        should_prune = service._should_prune_episode(episode, cutoff, 0.2)

        # Config has keep_all_failures=True by default
        assert should_prune is False

    def test_should_prune_episode_with_lessons_protected(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test not pruning episode with lessons."""
        old_date = _utcnow() - timedelta(days=100)
        episode = make_episode(
            occurred_at=old_date,
            importance_score=0.1,
            lessons_learned=["Important lesson"],
        )
        cutoff = _utcnow() - timedelta(days=90)

        should_prune = service._should_prune_episode(episode, cutoff, 0.2)

        # Config has keep_all_with_lessons=True by default
        assert should_prune is False

    def test_should_prune_episode_high_importance_protected(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test not pruning high importance episode."""
        old_date = _utcnow() - timedelta(days=100)
        episode = make_episode(
            occurred_at=old_date,
            importance_score=0.8,
        )
        cutoff = _utcnow() - timedelta(days=90)

        should_prune = service._should_prune_episode(episode, cutoff, 0.2)

        assert should_prune is False

    @pytest.mark.asyncio
    async def test_prune_old_episodes(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test episode pruning."""
        old_episode = make_episode(
            occurred_at=_utcnow() - timedelta(days=100),
            importance_score=0.1,
            outcome=Outcome.SUCCESS,
            lessons_learned=[],
        )

        with patch.object(
            service, "_get_episodic", new_callable=AsyncMock
        ) as mock_get_episodic:
            mock_episodic = AsyncMock()
            mock_episodic.get_recent.return_value = [old_episode]
            mock_episodic.delete.return_value = True
            mock_get_episodic.return_value = mock_episodic

            result = await service.prune_old_episodes(project_id=uuid4())

            assert result.items_processed == 1
            assert result.items_pruned == 1

    # =========================================================================
    # Nightly Consolidation Tests
    # =========================================================================

    @pytest.mark.asyncio
    async def test_run_nightly_consolidation(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test nightly consolidation workflow."""
        with (
            patch.object(
                service,
                "consolidate_episodes_to_facts",
                new_callable=AsyncMock,
            ) as mock_facts,
            patch.object(
                service,
                "consolidate_episodes_to_procedures",
                new_callable=AsyncMock,
            ) as mock_procedures,
            patch.object(
                service,
                "prune_old_episodes",
                new_callable=AsyncMock,
            ) as mock_prune,
        ):
            mock_facts.return_value = ConsolidationResult(
                source_type="episodic",
                target_type="semantic",
                items_processed=10,
                items_created=5,
            )
            mock_procedures.return_value = ConsolidationResult(
                source_type="episodic",
                target_type="procedural",
                items_processed=10,
                items_created=2,
            )
            mock_prune.return_value = ConsolidationResult(
                source_type="episodic",
                target_type="pruned",
                items_pruned=3,
            )

            result = await service.run_nightly_consolidation(project_id=uuid4())

            assert result.completed_at is not None
            assert result.total_facts_created == 5
            assert result.total_procedures_created == 2
            assert result.total_pruned == 3
            assert result.total_episodes_processed == 20

    @pytest.mark.asyncio
    async def test_run_nightly_consolidation_with_errors(
        self, service: MemoryConsolidationService
    ) -> None:
        """Test nightly consolidation handles errors."""
        with (
            patch.object(
                service,
                "consolidate_episodes_to_facts",
                new_callable=AsyncMock,
            ) as mock_facts,
            patch.object(
                service,
                "consolidate_episodes_to_procedures",
                new_callable=AsyncMock,
            ) as mock_procedures,
            patch.object(
                service,
                "prune_old_episodes",
                new_callable=AsyncMock,
            ) as mock_prune,
        ):
            mock_facts.return_value = ConsolidationResult(
                source_type="episodic",
                target_type="semantic",
                errors=["fact error"],
            )
            mock_procedures.return_value = ConsolidationResult(
                source_type="episodic",
                target_type="procedural",
            )
            mock_prune.return_value = ConsolidationResult(
                source_type="episodic",
                target_type="pruned",
            )

            result = await service.run_nightly_consolidation(project_id=uuid4())

            assert "fact error" in result.errors
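
Taken together, these tests describe a flow of session to episode to facts/procedures plus pruning, all driven by MemoryConsolidationService. A minimal wiring sketch using only the API exercised above; the database session is left as a parameter because the application's session factory is not part of this diff:

from uuid import UUID

from app.services.memory.consolidation.service import (
    ConsolidationConfig,
    MemoryConsolidationService,
)


async def nightly_job(session, project_id: UUID) -> None:
    # session: an async SQLAlchemy session supplied by the app (assumption, not shown here).
    service = MemoryConsolidationService(session=session, config=ConsolidationConfig())
    result = await service.run_nightly_consolidation(project_id=project_id)
    # Per the tests above, to_dict() reports facts created, procedures created, and pruned episodes.
    print(result.to_dict())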
2 backend/tests/unit/services/memory/episodic/__init__.py Normal file
@@ -0,0 +1,2 @@
# tests/unit/services/memory/episodic/__init__.py
"""Unit tests for episodic memory service."""
359 backend/tests/unit/services/memory/episodic/test_memory.py Normal file
@@ -0,0 +1,359 @@
# tests/unit/services/memory/episodic/test_memory.py
"""Unit tests for EpisodicMemory class."""

from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4

import pytest

from app.services.memory.episodic.memory import EpisodicMemory
from app.services.memory.episodic.retrieval import RetrievalStrategy
from app.services.memory.types import EpisodeCreate, Outcome, RetrievalResult


class TestEpisodicMemoryInit:
    """Tests for EpisodicMemory initialization."""

    def test_init_creates_recorder_and_retriever(self) -> None:
        """Test that init creates recorder and retriever."""
        mock_session = AsyncMock()
        memory = EpisodicMemory(session=mock_session)

        assert memory._recorder is not None
        assert memory._retriever is not None
        assert memory._session is mock_session

    def test_init_with_embedding_generator(self) -> None:
        """Test init with embedding generator."""
        mock_session = AsyncMock()
        mock_embedding_gen = AsyncMock()
        memory = EpisodicMemory(
            session=mock_session, embedding_generator=mock_embedding_gen
        )

        assert memory._embedding_generator is mock_embedding_gen

    @pytest.mark.asyncio
    async def test_create_factory_method(self) -> None:
        """Test create factory method."""
        mock_session = AsyncMock()
        memory = await EpisodicMemory.create(session=mock_session)

        assert memory is not None
        assert memory._session is mock_session


class TestEpisodicMemoryRecording:
    """Tests for episode recording methods."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Create a mock database session."""
        session = AsyncMock()
        session.add = MagicMock()
        session.flush = AsyncMock()
        session.refresh = AsyncMock()
        return session

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
        """Create an EpisodicMemory instance."""
        return EpisodicMemory(session=mock_session)

    @pytest.mark.asyncio
    async def test_record_episode(
        self,
        memory: EpisodicMemory,
    ) -> None:
        """Test recording an episode."""
        episode_data = EpisodeCreate(
            project_id=uuid4(),
            session_id="test-session",
            task_type="test_task",
            task_description="Test description",
            actions=[{"action": "test"}],
            context_summary="Test context",
            outcome=Outcome.SUCCESS,
            outcome_details="Success",
            duration_seconds=30.0,
            tokens_used=100,
        )

        result = await memory.record_episode(episode_data)

        assert result.project_id == episode_data.project_id
        assert result.task_type == "test_task"
        assert result.outcome == Outcome.SUCCESS

    @pytest.mark.asyncio
    async def test_record_success(
        self,
        memory: EpisodicMemory,
    ) -> None:
        """Test convenience method for recording success."""
        project_id = uuid4()
        result = await memory.record_success(
            project_id=project_id,
            session_id="test-session",
            task_type="deployment",
            task_description="Deploy to production",
            actions=[{"step": "deploy"}],
            context_summary="Deploying v1.0",
            outcome_details="Deployed successfully",
            duration_seconds=60.0,
            tokens_used=200,
        )

        assert result.outcome == Outcome.SUCCESS
        assert result.task_type == "deployment"

    @pytest.mark.asyncio
    async def test_record_failure(
        self,
        memory: EpisodicMemory,
    ) -> None:
        """Test convenience method for recording failure."""
        project_id = uuid4()
        result = await memory.record_failure(
            project_id=project_id,
            session_id="test-session",
            task_type="deployment",
            task_description="Deploy to production",
            actions=[{"step": "deploy"}],
            context_summary="Deploying v1.0",
            error_details="Connection timeout",
            duration_seconds=30.0,
            tokens_used=100,
        )

        assert result.outcome == Outcome.FAILURE
        assert result.outcome_details == "Connection timeout"


class TestEpisodicMemoryRetrieval:
    """Tests for episode retrieval methods."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Create a mock database session."""
        session = AsyncMock()
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = []
        session.execute.return_value = mock_result
        return session

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
        """Create an EpisodicMemory instance."""
        return EpisodicMemory(session=mock_session)

    @pytest.mark.asyncio
    async def test_search_similar(
        self,
        memory: EpisodicMemory,
    ) -> None:
        """Test semantic search."""
        project_id = uuid4()
        results = await memory.search_similar(project_id, "authentication bug")

        assert isinstance(results, list)

    @pytest.mark.asyncio
    async def test_get_recent(
        self,
        memory: EpisodicMemory,
    ) -> None:
        """Test getting recent episodes."""
        project_id = uuid4()
        results = await memory.get_recent(project_id, limit=5)

        assert isinstance(results, list)

    @pytest.mark.asyncio
    async def test_get_by_outcome(
        self,
        memory: EpisodicMemory,
    ) -> None:
        """Test getting episodes by outcome."""
        project_id = uuid4()
        results = await memory.get_by_outcome(project_id, Outcome.FAILURE, limit=5)

        assert isinstance(results, list)

    @pytest.mark.asyncio
    async def test_get_by_task_type(
        self,
        memory: EpisodicMemory,
    ) -> None:
        """Test getting episodes by task type."""
        project_id = uuid4()
        results = await memory.get_by_task_type(project_id, "code_review", limit=5)

        assert isinstance(results, list)

    @pytest.mark.asyncio
    async def test_get_important(
        self,
        memory: EpisodicMemory,
    ) -> None:
        """Test getting important episodes."""
        project_id = uuid4()
        results = await memory.get_important(project_id, limit=5, min_importance=0.8)

        assert isinstance(results, list)

    @pytest.mark.asyncio
    async def test_retrieve_with_full_result(
        self,
        memory: EpisodicMemory,
    ) -> None:
        """Test retrieve with full result metadata."""
        project_id = uuid4()
        result = await memory.retrieve(project_id, RetrievalStrategy.RECENCY, limit=10)

        assert isinstance(result, RetrievalResult)
        assert result.retrieval_type == "recency"


class TestEpisodicMemorySummarization:
    """Tests for episode summarization."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Create a mock database session."""
        session = AsyncMock()
        return session

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
        """Create an EpisodicMemory instance."""
        return EpisodicMemory(session=mock_session)

    @pytest.mark.asyncio
    async def test_summarize_empty_list(
        self,
        memory: EpisodicMemory,
    ) -> None:
        """Test summarizing empty episode list."""
        summary = await memory.summarize_episodes([])
        assert "No episodes to summarize" in summary

    @pytest.mark.asyncio
    async def test_summarize_not_found(
        self,
        memory: EpisodicMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test summarizing when episodes not found."""
        # Mock get_by_id to return None
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result

        summary = await memory.summarize_episodes([uuid4(), uuid4()])
        assert "No episodes found" in summary


class TestEpisodicMemoryStats:
    """Tests for episode statistics."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Create a mock database session."""
        session = AsyncMock()
        return session

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
        """Create an EpisodicMemory instance."""
        return EpisodicMemory(session=mock_session)

    @pytest.mark.asyncio
    async def test_get_stats(
        self,
        memory: EpisodicMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test getting episode statistics."""
        # Mock empty result
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = []
        mock_session.execute.return_value = mock_result

        stats = await memory.get_stats(uuid4())

        assert "total_count" in stats
        assert "success_count" in stats
        assert "failure_count" in stats

    @pytest.mark.asyncio
    async def test_count(
        self,
        memory: EpisodicMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test counting episodes."""
        # Mock result with 3 episodes
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [1, 2, 3]
        mock_session.execute.return_value = mock_result

        count = await memory.count(uuid4())
        assert count == 3


class TestEpisodicMemoryModification:
    """Tests for episode modification methods."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Create a mock database session."""
        session = AsyncMock()
        return session

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
        """Create an EpisodicMemory instance."""
        return EpisodicMemory(session=mock_session)

    @pytest.mark.asyncio
    async def test_get_by_id_not_found(
        self,
        memory: EpisodicMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test get_by_id returns None when not found."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result

        result = await memory.get_by_id(uuid4())
        assert result is None

    @pytest.mark.asyncio
    async def test_update_importance_not_found(
        self,
        memory: EpisodicMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test update_importance returns None when not found."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result

        result = await memory.update_importance(uuid4(), 0.9)
        assert result is None

    @pytest.mark.asyncio
    async def test_delete_not_found(
        self,
        memory: EpisodicMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test delete returns False when not found."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result

        result = await memory.delete(uuid4())
        assert result is False
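
For orientation, the recording and retrieval API covered by these tests can be used roughly as follows; this is a sketch against the methods exercised above, with the async database session left as an assumption since the app's session setup is not shown in this diff:

from app.services.memory.episodic.memory import EpisodicMemory
from app.services.memory.types import Outcome


async def record_and_review(session, project_id) -> None:
    # session: an AsyncSession supplied by the application (assumption, not part of this diff).
    memory = await EpisodicMemory.create(session=session)

    # Record a successful task the same way the convenience-method tests do.
    await memory.record_success(
        project_id=project_id,
        session_id="demo-session",
        task_type="deployment",
        task_description="Deploy to production",
        actions=[{"step": "deploy"}],
        context_summary="Deploying v1.0",
        outcome_details="Deployed successfully",
        duration_seconds=60.0,
        tokens_used=200,
    )

    # Pull back recent failures for the project, as exercised in the retrieval tests.
    failures = await memory.get_by_outcome(project_id, Outcome.FAILURE, limit=5)
    print(len(failures), "recent failures on record")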
Some files were not shown because too many files have changed in this diff.