forked from cardosofelipe/fast-next-template
Compare commits
42 Commits
main
...
3c6b14d2bf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3c6b14d2bf | ||
|
|
6b21a6fadd | ||
|
|
600657adc4 | ||
|
|
c9d0d079b3 | ||
|
|
4c8f81368c | ||
|
|
efbe91ce14 | ||
|
|
5d646779c9 | ||
|
|
5a4d93df26 | ||
|
|
7ef217be39 | ||
|
|
20159c5865 | ||
|
|
f9a72fcb34 | ||
|
|
fcb0a5f86a | ||
|
|
92782bcb05 | ||
|
|
1dcf99ee38 | ||
|
|
70009676a3 | ||
|
|
192237e69b | ||
|
|
3edce9cd26 | ||
|
|
35aea2d73a | ||
|
|
d0f32d04f7 | ||
|
|
da85a8aba8 | ||
|
|
f8bd1011e9 | ||
|
|
f057c2f0b6 | ||
|
|
33ec889fc4 | ||
|
|
74b8c65741 | ||
|
|
b232298c61 | ||
|
|
cf6291ac8e | ||
|
|
e3fe0439fd | ||
|
|
57680c3772 | ||
|
|
997cfaa03a | ||
|
|
6954774e36 | ||
|
|
30e5c68304 | ||
|
|
0b24d4c6cc | ||
|
|
1670e05e0d | ||
|
|
999b7ac03f | ||
|
|
48ecb40f18 | ||
|
|
b818f17418 | ||
|
|
e946787a61 | ||
|
|
3554efe66a | ||
|
|
bd988f76b0 | ||
|
|
4974233169 | ||
|
|
c9d8c0835c | ||
|
|
085a748929 |
@@ -80,7 +80,7 @@ test:
|
||||
|
||||
test-cov:
|
||||
@echo "🧪 Running tests with coverage..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 20
|
||||
@echo "📊 Coverage report generated in htmlcov/index.html"
|
||||
|
||||
# ============================================================================
|
||||
|
||||
512
backend/app/alembic/versions/0005_add_memory_system_tables.py
Normal file
512
backend/app/alembic/versions/0005_add_memory_system_tables.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""Add Agent Memory System tables
|
||||
|
||||
Revision ID: 0005
|
||||
Revises: 0004
|
||||
Create Date: 2025-01-05
|
||||
|
||||
This migration creates the Agent Memory System tables:
|
||||
- working_memory: Key-value storage with TTL for active sessions
|
||||
- episodes: Experiential memories from task executions
|
||||
- facts: Semantic knowledge triples with confidence scores
|
||||
- procedures: Learned skills and procedures
|
||||
- memory_consolidation_log: Tracks consolidation jobs
|
||||
|
||||
See Issue #88: Database Schema & Storage Layer
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0005"
|
||||
down_revision: str | None = "0004"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create Agent Memory System tables."""
|
||||
|
||||
# =========================================================================
|
||||
# Create ENUM types for memory system
|
||||
# =========================================================================
|
||||
|
||||
# Scope type enum
|
||||
scope_type_enum = postgresql.ENUM(
|
||||
"global",
|
||||
"project",
|
||||
"agent_type",
|
||||
"agent_instance",
|
||||
"session",
|
||||
name="scope_type",
|
||||
create_type=False,
|
||||
)
|
||||
scope_type_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Episode outcome enum
|
||||
episode_outcome_enum = postgresql.ENUM(
|
||||
"success",
|
||||
"failure",
|
||||
"partial",
|
||||
name="episode_outcome",
|
||||
create_type=False,
|
||||
)
|
||||
episode_outcome_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Consolidation type enum
|
||||
consolidation_type_enum = postgresql.ENUM(
|
||||
"working_to_episodic",
|
||||
"episodic_to_semantic",
|
||||
"episodic_to_procedural",
|
||||
"pruning",
|
||||
name="consolidation_type",
|
||||
create_type=False,
|
||||
)
|
||||
consolidation_type_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Consolidation status enum
|
||||
consolidation_status_enum = postgresql.ENUM(
|
||||
"pending",
|
||||
"running",
|
||||
"completed",
|
||||
"failed",
|
||||
name="consolidation_status",
|
||||
create_type=False,
|
||||
)
|
||||
consolidation_status_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# =========================================================================
|
||||
# Create working_memory table
|
||||
# Key-value storage with TTL for active sessions
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"working_memory",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"scope_type",
|
||||
scope_type_enum,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("scope_id", sa.String(255), nullable=False),
|
||||
sa.Column("key", sa.String(255), nullable=False),
|
||||
sa.Column("value", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Working memory indexes
|
||||
op.create_index(
|
||||
"ix_working_memory_scope_type",
|
||||
"working_memory",
|
||||
["scope_type"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_working_memory_scope_id",
|
||||
"working_memory",
|
||||
["scope_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_working_memory_scope_key",
|
||||
"working_memory",
|
||||
["scope_type", "scope_id", "key"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_working_memory_expires",
|
||||
"working_memory",
|
||||
["expires_at"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_working_memory_scope_list",
|
||||
"working_memory",
|
||||
["scope_type", "scope_id"],
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create episodes table
|
||||
# Experiential memories from task executions
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"episodes",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("agent_instance_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("session_id", sa.String(255), nullable=False),
|
||||
sa.Column("task_type", sa.String(100), nullable=False),
|
||||
sa.Column("task_description", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"actions",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
sa.Column("context_summary", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"outcome",
|
||||
episode_outcome_enum,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("outcome_details", sa.Text(), nullable=True),
|
||||
sa.Column("duration_seconds", sa.Float(), nullable=False, server_default="0.0"),
|
||||
sa.Column("tokens_used", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"lessons_learned",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
sa.Column("importance_score", sa.Float(), nullable=False, server_default="0.5"),
|
||||
# Vector embedding - using TEXT as fallback, will be VECTOR(1536) when pgvector is available
|
||||
sa.Column("embedding", sa.Text(), nullable=True),
|
||||
sa.Column("occurred_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["project_id"],
|
||||
["projects.id"],
|
||||
name="fk_episodes_project",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_instance_id"],
|
||||
["agent_instances.id"],
|
||||
name="fk_episodes_agent_instance",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_type_id"],
|
||||
["agent_types.id"],
|
||||
name="fk_episodes_agent_type",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
)
|
||||
|
||||
# Episode indexes
|
||||
op.create_index("ix_episodes_project_id", "episodes", ["project_id"])
|
||||
op.create_index("ix_episodes_agent_instance_id", "episodes", ["agent_instance_id"])
|
||||
op.create_index("ix_episodes_agent_type_id", "episodes", ["agent_type_id"])
|
||||
op.create_index("ix_episodes_session_id", "episodes", ["session_id"])
|
||||
op.create_index("ix_episodes_task_type", "episodes", ["task_type"])
|
||||
op.create_index("ix_episodes_outcome", "episodes", ["outcome"])
|
||||
op.create_index("ix_episodes_importance_score", "episodes", ["importance_score"])
|
||||
op.create_index("ix_episodes_occurred_at", "episodes", ["occurred_at"])
|
||||
op.create_index("ix_episodes_project_task", "episodes", ["project_id", "task_type"])
|
||||
op.create_index(
|
||||
"ix_episodes_project_outcome", "episodes", ["project_id", "outcome"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_episodes_agent_task", "episodes", ["agent_instance_id", "task_type"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_episodes_project_time", "episodes", ["project_id", "occurred_at"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_episodes_importance_time",
|
||||
"episodes",
|
||||
["importance_score", "occurred_at"],
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create facts table
|
||||
# Semantic knowledge triples with confidence scores
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"facts",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"project_id", postgresql.UUID(as_uuid=True), nullable=True
|
||||
), # NULL for global facts
|
||||
sa.Column("subject", sa.String(500), nullable=False),
|
||||
sa.Column("predicate", sa.String(255), nullable=False),
|
||||
sa.Column("object", sa.Text(), nullable=False),
|
||||
sa.Column("confidence", sa.Float(), nullable=False, server_default="0.8"),
|
||||
# Source episode IDs stored as JSON array of UUID strings for cross-db compatibility
|
||||
sa.Column(
|
||||
"source_episode_ids",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
sa.Column("first_learned", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("last_reinforced", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column(
|
||||
"reinforcement_count", sa.Integer(), nullable=False, server_default="1"
|
||||
),
|
||||
# Vector embedding
|
||||
sa.Column("embedding", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["project_id"],
|
||||
["projects.id"],
|
||||
name="fk_facts_project",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
|
||||
# Fact indexes
|
||||
op.create_index("ix_facts_project_id", "facts", ["project_id"])
|
||||
op.create_index("ix_facts_subject", "facts", ["subject"])
|
||||
op.create_index("ix_facts_predicate", "facts", ["predicate"])
|
||||
op.create_index("ix_facts_confidence", "facts", ["confidence"])
|
||||
op.create_index("ix_facts_subject_predicate", "facts", ["subject", "predicate"])
|
||||
op.create_index("ix_facts_project_subject", "facts", ["project_id", "subject"])
|
||||
op.create_index(
|
||||
"ix_facts_confidence_time", "facts", ["confidence", "last_reinforced"]
|
||||
)
|
||||
# Unique constraint for triples within project scope
|
||||
op.create_index(
|
||||
"ix_facts_unique_triple",
|
||||
"facts",
|
||||
["project_id", "subject", "predicate", "object"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("project_id IS NOT NULL"),
|
||||
)
|
||||
# Unique constraint for global facts (project_id IS NULL)
|
||||
op.create_index(
|
||||
"ix_facts_unique_triple_global",
|
||||
"facts",
|
||||
["subject", "predicate", "object"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("project_id IS NULL"),
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create procedures table
|
||||
# Learned skills and procedures
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"procedures",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("trigger_pattern", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"steps",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
sa.Column("success_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("failure_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("last_used", sa.DateTime(timezone=True), nullable=True),
|
||||
# Vector embedding
|
||||
sa.Column("embedding", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["project_id"],
|
||||
["projects.id"],
|
||||
name="fk_procedures_project",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_type_id"],
|
||||
["agent_types.id"],
|
||||
name="fk_procedures_agent_type",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
)
|
||||
|
||||
# Procedure indexes
|
||||
op.create_index("ix_procedures_project_id", "procedures", ["project_id"])
|
||||
op.create_index("ix_procedures_agent_type_id", "procedures", ["agent_type_id"])
|
||||
op.create_index("ix_procedures_name", "procedures", ["name"])
|
||||
op.create_index("ix_procedures_last_used", "procedures", ["last_used"])
|
||||
op.create_index(
|
||||
"ix_procedures_unique_name",
|
||||
"procedures",
|
||||
["project_id", "agent_type_id", "name"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index("ix_procedures_project_name", "procedures", ["project_id", "name"])
|
||||
# Note: agent_type_id already indexed via ix_procedures_agent_type_id (line 354)
|
||||
op.create_index(
|
||||
"ix_procedures_success_rate",
|
||||
"procedures",
|
||||
["success_count", "failure_count"],
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Add check constraints for data integrity
|
||||
# =========================================================================
|
||||
|
||||
# Episode constraints
|
||||
op.create_check_constraint(
|
||||
"ck_episodes_importance_range",
|
||||
"episodes",
|
||||
"importance_score >= 0.0 AND importance_score <= 1.0",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_episodes_duration_positive",
|
||||
"episodes",
|
||||
"duration_seconds >= 0.0",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_episodes_tokens_positive",
|
||||
"episodes",
|
||||
"tokens_used >= 0",
|
||||
)
|
||||
|
||||
# Fact constraints
|
||||
op.create_check_constraint(
|
||||
"ck_facts_confidence_range",
|
||||
"facts",
|
||||
"confidence >= 0.0 AND confidence <= 1.0",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_facts_reinforcement_positive",
|
||||
"facts",
|
||||
"reinforcement_count >= 1",
|
||||
)
|
||||
|
||||
# Procedure constraints
|
||||
op.create_check_constraint(
|
||||
"ck_procedures_success_positive",
|
||||
"procedures",
|
||||
"success_count >= 0",
|
||||
)
|
||||
op.create_check_constraint(
|
||||
"ck_procedures_failure_positive",
|
||||
"procedures",
|
||||
"failure_count >= 0",
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Create memory_consolidation_log table
|
||||
# Tracks consolidation jobs
|
||||
# =========================================================================
|
||||
op.create_table(
|
||||
"memory_consolidation_log",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"consolidation_type",
|
||||
consolidation_type_enum,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("source_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("result_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("started_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
consolidation_status_enum,
|
||||
nullable=False,
|
||||
server_default="pending",
|
||||
),
|
||||
sa.Column("error", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Consolidation log indexes
|
||||
op.create_index(
|
||||
"ix_consolidation_type",
|
||||
"memory_consolidation_log",
|
||||
["consolidation_type"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_consolidation_status",
|
||||
"memory_consolidation_log",
|
||||
["status"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_consolidation_type_status",
|
||||
"memory_consolidation_log",
|
||||
["consolidation_type", "status"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_consolidation_started",
|
||||
"memory_consolidation_log",
|
||||
["started_at"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop Agent Memory System tables."""
|
||||
|
||||
# Drop check constraints first
|
||||
op.drop_constraint("ck_procedures_failure_positive", "procedures", type_="check")
|
||||
op.drop_constraint("ck_procedures_success_positive", "procedures", type_="check")
|
||||
op.drop_constraint("ck_facts_reinforcement_positive", "facts", type_="check")
|
||||
op.drop_constraint("ck_facts_confidence_range", "facts", type_="check")
|
||||
op.drop_constraint("ck_episodes_tokens_positive", "episodes", type_="check")
|
||||
op.drop_constraint("ck_episodes_duration_positive", "episodes", type_="check")
|
||||
op.drop_constraint("ck_episodes_importance_range", "episodes", type_="check")
|
||||
|
||||
# Drop unique indexes for global facts
|
||||
op.drop_index("ix_facts_unique_triple_global", "facts")
|
||||
|
||||
# Drop tables in reverse order (dependencies first)
|
||||
op.drop_table("memory_consolidation_log")
|
||||
op.drop_table("procedures")
|
||||
op.drop_table("facts")
|
||||
op.drop_table("episodes")
|
||||
op.drop_table("working_memory")
|
||||
|
||||
# Drop ENUM types
|
||||
op.execute("DROP TYPE IF EXISTS consolidation_status")
|
||||
op.execute("DROP TYPE IF EXISTS consolidation_type")
|
||||
op.execute("DROP TYPE IF EXISTS episode_outcome")
|
||||
op.execute("DROP TYPE IF EXISTS scope_type")
|
||||
52
backend/app/alembic/versions/0006_add_abandoned_outcome.py
Normal file
52
backend/app/alembic/versions/0006_add_abandoned_outcome.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Add ABANDONED to episode_outcome enum
|
||||
|
||||
Revision ID: 0006
|
||||
Revises: 0005
|
||||
Create Date: 2025-01-06
|
||||
|
||||
This migration adds the 'abandoned' value to the episode_outcome enum type.
|
||||
This allows episodes to track when a task was abandoned (not completed,
|
||||
but not necessarily a failure either - e.g., user cancelled, session timeout).
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0006"
|
||||
down_revision: str | None = "0005"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add 'abandoned' value to episode_outcome enum."""
|
||||
# PostgreSQL ALTER TYPE ADD VALUE is safe and non-blocking
|
||||
op.execute("ALTER TYPE episode_outcome ADD VALUE IF NOT EXISTS 'abandoned'")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove 'abandoned' from episode_outcome enum.
|
||||
|
||||
Note: PostgreSQL doesn't support removing values from enums directly.
|
||||
This downgrade converts any 'abandoned' episodes to 'failure' and
|
||||
recreates the enum without 'abandoned'.
|
||||
"""
|
||||
# Convert any abandoned episodes to failure first
|
||||
op.execute("""
|
||||
UPDATE episodes
|
||||
SET outcome = 'failure'
|
||||
WHERE outcome = 'abandoned'
|
||||
""")
|
||||
|
||||
# Recreate the enum without abandoned
|
||||
# This is complex in PostgreSQL - requires creating new type, updating columns, dropping old
|
||||
op.execute("ALTER TYPE episode_outcome RENAME TO episode_outcome_old")
|
||||
op.execute("CREATE TYPE episode_outcome AS ENUM ('success', 'failure', 'partial')")
|
||||
op.execute("""
|
||||
ALTER TABLE episodes
|
||||
ALTER COLUMN outcome TYPE episode_outcome
|
||||
USING outcome::text::episode_outcome
|
||||
""")
|
||||
op.execute("DROP TYPE episode_outcome_old")
|
||||
@@ -1,366 +0,0 @@
|
||||
{
|
||||
"organizations": [
|
||||
{
|
||||
"name": "Acme Corp",
|
||||
"slug": "acme-corp",
|
||||
"description": "A leading provider of coyote-catching equipment."
|
||||
},
|
||||
{
|
||||
"name": "Globex Corporation",
|
||||
"slug": "globex",
|
||||
"description": "We own the East Coast."
|
||||
},
|
||||
{
|
||||
"name": "Soylent Corp",
|
||||
"slug": "soylent",
|
||||
"description": "Making food for the future."
|
||||
},
|
||||
{
|
||||
"name": "Initech",
|
||||
"slug": "initech",
|
||||
"description": "Software for the soul."
|
||||
},
|
||||
{
|
||||
"name": "Umbrella Corporation",
|
||||
"slug": "umbrella",
|
||||
"description": "Our business is life itself."
|
||||
},
|
||||
{
|
||||
"name": "Massive Dynamic",
|
||||
"slug": "massive-dynamic",
|
||||
"description": "What don't we do?"
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"email": "demo@example.com",
|
||||
"password": "DemoPass1234!",
|
||||
"first_name": "Demo",
|
||||
"last_name": "User",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "alice@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Alice",
|
||||
"last_name": "Smith",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "bob@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Bob",
|
||||
"last_name": "Jones",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "charlie@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Charlie",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "diana@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Diana",
|
||||
"last_name": "Prince",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "carol@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Carol",
|
||||
"last_name": "Williams",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dan@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dan",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ellen@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ellen",
|
||||
"last_name": "Ripley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "fred@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Fred",
|
||||
"last_name": "Flintstone",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dave@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dave",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "gina@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Gina",
|
||||
"last_name": "Torres",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "harry@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Harry",
|
||||
"last_name": "Potter",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "eve@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Eve",
|
||||
"last_name": "Davis",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "iris@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Iris",
|
||||
"last_name": "West",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "jack@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Jack",
|
||||
"last_name": "Sparrow",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "frank@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Frank",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "george@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "George",
|
||||
"last_name": "Costanza",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "kate@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Kate",
|
||||
"last_name": "Bishop",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "leo@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Leo",
|
||||
"last_name": "Messi",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "mary@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Mary",
|
||||
"last_name": "Jane",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "nathan@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Nathan",
|
||||
"last_name": "Drake",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "olivia@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Olivia",
|
||||
"last_name": "Dunham",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "peter@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Peter",
|
||||
"last_name": "Parker",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "quinn@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Quinn",
|
||||
"last_name": "Mallory",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "grace@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Grace",
|
||||
"last_name": "Hopper",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "heidi@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Heidi",
|
||||
"last_name": "Klum",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ivan@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ivan",
|
||||
"last_name": "Drago",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "rachel@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Rachel",
|
||||
"last_name": "Green",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "sam@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Sam",
|
||||
"last_name": "Wilson",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "tony@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Tony",
|
||||
"last_name": "Stark",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "una@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Una",
|
||||
"last_name": "Chin-Riley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "victor@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Victor",
|
||||
"last_name": "Von Doom",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "wanda@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Wanda",
|
||||
"last_name": "Maximoff",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -3,27 +3,48 @@
|
||||
Async database initialization script.
|
||||
|
||||
Creates the first superuser if configured and doesn't already exist.
|
||||
Seeds default agent types (production data) and demo data (when DEMO_MODE is enabled).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import SessionLocal, engine
|
||||
from app.crud.syndarix.agent_type import agent_type as agent_type_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.organization import Organization
|
||||
from app.models.syndarix import AgentInstance, AgentType, Issue, Project, Sprint
|
||||
from app.models.syndarix.enums import (
|
||||
AgentStatus,
|
||||
AutonomyLevel,
|
||||
ClientMode,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
IssueType,
|
||||
ProjectComplexity,
|
||||
ProjectStatus,
|
||||
SprintStatus,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization
|
||||
from app.schemas.syndarix import AgentTypeCreate
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Data file paths
|
||||
DATA_DIR = Path(__file__).parent.parent / "data"
|
||||
DEFAULT_AGENT_TYPES_PATH = DATA_DIR / "default_agent_types.json"
|
||||
DEMO_DATA_PATH = DATA_DIR / "demo_data.json"
|
||||
|
||||
|
||||
async def init_db() -> User | None:
|
||||
"""
|
||||
@@ -54,28 +75,29 @@ async def init_db() -> User | None:
|
||||
|
||||
if existing_user:
|
||||
logger.info(f"Superuser already exists: {existing_user.email}")
|
||||
return existing_user
|
||||
else:
|
||||
# Create superuser if doesn't exist
|
||||
user_in = UserCreate(
|
||||
email=superuser_email,
|
||||
password=superuser_password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True,
|
||||
)
|
||||
|
||||
# Create superuser if doesn't exist
|
||||
user_in = UserCreate(
|
||||
email=superuser_email,
|
||||
password=superuser_password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True,
|
||||
)
|
||||
existing_user = await user_crud.create(session, obj_in=user_in)
|
||||
await session.commit()
|
||||
await session.refresh(existing_user)
|
||||
logger.info(f"Created first superuser: {existing_user.email}")
|
||||
|
||||
user = await user_crud.create(session, obj_in=user_in)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
# ALWAYS load default agent types (production data)
|
||||
await load_default_agent_types(session)
|
||||
|
||||
logger.info(f"Created first superuser: {user.email}")
|
||||
|
||||
# Create demo data if in demo mode
|
||||
# Only load demo data if in demo mode
|
||||
if settings.DEMO_MODE:
|
||||
await load_demo_data(session)
|
||||
|
||||
return user
|
||||
return existing_user
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
@@ -88,26 +110,89 @@ def _load_json_file(path: Path):
|
||||
return json.load(f)
|
||||
|
||||
|
||||
async def load_demo_data(session):
|
||||
"""Load demo data from JSON file."""
|
||||
demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
|
||||
if not demo_data_path.exists():
|
||||
logger.warning(f"Demo data file not found: {demo_data_path}")
|
||||
async def load_default_agent_types(session: AsyncSession) -> None:
|
||||
"""
|
||||
Load default agent types from JSON file.
|
||||
|
||||
These are production defaults - created only if they don't exist, never overwritten.
|
||||
This allows users to customize agent types without worrying about server restarts.
|
||||
"""
|
||||
if not DEFAULT_AGENT_TYPES_PATH.exists():
|
||||
logger.warning(
|
||||
f"Default agent types file not found: {DEFAULT_AGENT_TYPES_PATH}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Use asyncio.to_thread to avoid blocking the event loop
|
||||
data = await asyncio.to_thread(_load_json_file, demo_data_path)
|
||||
data = await asyncio.to_thread(_load_json_file, DEFAULT_AGENT_TYPES_PATH)
|
||||
|
||||
# Create Organizations
|
||||
org_map = {}
|
||||
for org_data in data.get("organizations", []):
|
||||
# Check if org exists
|
||||
result = await session.execute(
|
||||
text("SELECT * FROM organizations WHERE slug = :slug"),
|
||||
{"slug": org_data["slug"]},
|
||||
for agent_type_data in data:
|
||||
slug = agent_type_data["slug"]
|
||||
|
||||
# Check if agent type already exists
|
||||
existing = await agent_type_crud.get_by_slug(session, slug=slug)
|
||||
|
||||
if existing:
|
||||
logger.debug(f"Agent type already exists: {agent_type_data['name']}")
|
||||
continue
|
||||
|
||||
# Create the agent type
|
||||
agent_type_in = AgentTypeCreate(
|
||||
name=agent_type_data["name"],
|
||||
slug=slug,
|
||||
description=agent_type_data.get("description"),
|
||||
expertise=agent_type_data.get("expertise", []),
|
||||
personality_prompt=agent_type_data["personality_prompt"],
|
||||
primary_model=agent_type_data["primary_model"],
|
||||
fallback_models=agent_type_data.get("fallback_models", []),
|
||||
model_params=agent_type_data.get("model_params", {}),
|
||||
mcp_servers=agent_type_data.get("mcp_servers", []),
|
||||
tool_permissions=agent_type_data.get("tool_permissions", {}),
|
||||
is_active=agent_type_data.get("is_active", True),
|
||||
)
|
||||
existing_org = result.first()
|
||||
|
||||
await agent_type_crud.create(session, obj_in=agent_type_in)
|
||||
logger.info(f"Created default agent type: {agent_type_data['name']}")
|
||||
|
||||
logger.info("Default agent types loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading default agent types: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def load_demo_data(session: AsyncSession) -> None:
|
||||
"""
|
||||
Load demo data from JSON file.
|
||||
|
||||
Only runs when DEMO_MODE is enabled. Creates demo organizations, users,
|
||||
projects, sprints, agent instances, and issues.
|
||||
"""
|
||||
if not DEMO_DATA_PATH.exists():
|
||||
logger.warning(f"Demo data file not found: {DEMO_DATA_PATH}")
|
||||
return
|
||||
|
||||
try:
|
||||
data = await asyncio.to_thread(_load_json_file, DEMO_DATA_PATH)
|
||||
|
||||
# Build lookup maps for FK resolution
|
||||
org_map: dict[str, Organization] = {}
|
||||
user_map: dict[str, User] = {}
|
||||
project_map: dict[str, Project] = {}
|
||||
sprint_map: dict[str, Sprint] = {} # key: "project_slug:sprint_number"
|
||||
agent_type_map: dict[str, AgentType] = {}
|
||||
agent_instance_map: dict[
|
||||
str, AgentInstance
|
||||
] = {} # key: "project_slug:agent_name"
|
||||
|
||||
# ========================
|
||||
# 1. Create Organizations
|
||||
# ========================
|
||||
for org_data in data.get("organizations", []):
|
||||
org_result = await session.execute(
|
||||
select(Organization).where(Organization.slug == org_data["slug"])
|
||||
)
|
||||
existing_org = org_result.scalar_one_or_none()
|
||||
|
||||
if not existing_org:
|
||||
org = Organization(
|
||||
@@ -117,29 +202,20 @@ async def load_demo_data(session):
|
||||
is_active=True,
|
||||
)
|
||||
session.add(org)
|
||||
await session.flush() # Flush to get ID
|
||||
org_map[org.slug] = org
|
||||
await session.flush()
|
||||
org_map[str(org.slug)] = org
|
||||
logger.info(f"Created demo organization: {org.name}")
|
||||
else:
|
||||
# We can't easily get the ORM object from raw SQL result for map without querying again or mapping
|
||||
# So let's just query it properly if we need it for relationships
|
||||
# But for simplicity in this script, let's just assume we created it or it exists.
|
||||
# To properly map for users, we need the ID.
|
||||
# Let's use a simpler approach: just try to create, if slug conflict, skip.
|
||||
pass
|
||||
org_map[str(existing_org.slug)] = existing_org
|
||||
|
||||
# Re-query all orgs to build map for users
|
||||
result = await session.execute(select(Organization))
|
||||
orgs = result.scalars().all()
|
||||
org_map = {org.slug: org for org in orgs}
|
||||
|
||||
# Create Users
|
||||
# ========================
|
||||
# 2. Create Users
|
||||
# ========================
|
||||
for user_data in data.get("users", []):
|
||||
existing_user = await user_crud.get_by_email(
|
||||
session, email=user_data["email"]
|
||||
)
|
||||
if not existing_user:
|
||||
# Create user
|
||||
user_in = UserCreate(
|
||||
email=user_data["email"],
|
||||
password=user_data["password"],
|
||||
@@ -151,17 +227,13 @@ async def load_demo_data(session):
|
||||
user = await user_crud.create(session, obj_in=user_in)
|
||||
|
||||
# Randomize created_at for demo data (last 30 days)
|
||||
# This makes the charts look more realistic
|
||||
days_ago = random.randint(0, 30) # noqa: S311
|
||||
random_time = datetime.now(UTC) - timedelta(days=days_ago)
|
||||
# Add some random hours/minutes variation
|
||||
random_time = random_time.replace(
|
||||
hour=random.randint(0, 23), # noqa: S311
|
||||
minute=random.randint(0, 59), # noqa: S311
|
||||
)
|
||||
|
||||
# Update the timestamp and is_active directly in the database
|
||||
# We do this to ensure the values are persisted correctly
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE users SET created_at = :created_at, is_active = :is_active WHERE id = :user_id"
|
||||
@@ -174,7 +246,7 @@ async def load_demo_data(session):
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created demo user: {user.email} (created {days_ago} days ago, active={user_data.get('is_active', True)})"
|
||||
f"Created demo user: {user.email} (created {days_ago} days ago)"
|
||||
)
|
||||
|
||||
# Add to organization if specified
|
||||
@@ -182,19 +254,228 @@ async def load_demo_data(session):
|
||||
role = user_data.get("role")
|
||||
if org_slug and org_slug in org_map and role:
|
||||
org = org_map[org_slug]
|
||||
# Check if membership exists (it shouldn't for new user)
|
||||
member = UserOrganization(
|
||||
user_id=user.id, organization_id=org.id, role=role
|
||||
)
|
||||
session.add(member)
|
||||
logger.info(f"Added {user.email} to {org.name} as {role}")
|
||||
|
||||
user_map[str(user.email)] = user
|
||||
else:
|
||||
logger.info(f"Demo user already exists: {existing_user.email}")
|
||||
user_map[str(existing_user.email)] = existing_user
|
||||
logger.debug(f"Demo user already exists: {existing_user.email}")
|
||||
|
||||
await session.flush()
|
||||
|
||||
# Add admin user to map with special "__admin__" key
|
||||
# This allows demo data to reference the admin user as owner
|
||||
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
||||
admin_user = await user_crud.get_by_email(session, email=superuser_email)
|
||||
if admin_user:
|
||||
user_map["__admin__"] = admin_user
|
||||
user_map[str(admin_user.email)] = admin_user
|
||||
logger.debug(f"Added admin user to map: {admin_user.email}")
|
||||
|
||||
# ========================
|
||||
# 3. Load Agent Types Map (for FK resolution)
|
||||
# ========================
|
||||
agent_types_result = await session.execute(select(AgentType))
|
||||
for at in agent_types_result.scalars().all():
|
||||
agent_type_map[str(at.slug)] = at
|
||||
|
||||
# ========================
|
||||
# 4. Create Projects
|
||||
# ========================
|
||||
for project_data in data.get("projects", []):
|
||||
project_result = await session.execute(
|
||||
select(Project).where(Project.slug == project_data["slug"])
|
||||
)
|
||||
existing_project = project_result.scalar_one_or_none()
|
||||
|
||||
if not existing_project:
|
||||
# Resolve owner email to user ID
|
||||
owner_id = None
|
||||
owner_email = project_data.get("owner_email")
|
||||
if owner_email and owner_email in user_map:
|
||||
owner_id = user_map[owner_email].id
|
||||
|
||||
project = Project(
|
||||
name=project_data["name"],
|
||||
slug=project_data["slug"],
|
||||
description=project_data.get("description"),
|
||||
owner_id=owner_id,
|
||||
autonomy_level=AutonomyLevel(
|
||||
project_data.get("autonomy_level", "milestone")
|
||||
),
|
||||
status=ProjectStatus(project_data.get("status", "active")),
|
||||
complexity=ProjectComplexity(
|
||||
project_data.get("complexity", "medium")
|
||||
),
|
||||
client_mode=ClientMode(project_data.get("client_mode", "auto")),
|
||||
settings=project_data.get("settings", {}),
|
||||
)
|
||||
session.add(project)
|
||||
await session.flush()
|
||||
project_map[str(project.slug)] = project
|
||||
logger.info(f"Created demo project: {project.name}")
|
||||
else:
|
||||
project_map[str(existing_project.slug)] = existing_project
|
||||
logger.debug(f"Demo project already exists: {existing_project.name}")
|
||||
|
||||
# ========================
|
||||
# 5. Create Sprints
|
||||
# ========================
|
||||
for sprint_data in data.get("sprints", []):
|
||||
project_slug = sprint_data["project_slug"]
|
||||
sprint_number = sprint_data["number"]
|
||||
sprint_key = f"{project_slug}:{sprint_number}"
|
||||
|
||||
if project_slug not in project_map:
|
||||
logger.warning(f"Project not found for sprint: {project_slug}")
|
||||
continue
|
||||
|
||||
sprint_project = project_map[project_slug]
|
||||
|
||||
# Check if sprint exists
|
||||
sprint_result = await session.execute(
|
||||
select(Sprint).where(
|
||||
Sprint.project_id == sprint_project.id,
|
||||
Sprint.number == sprint_number,
|
||||
)
|
||||
)
|
||||
existing_sprint = sprint_result.scalar_one_or_none()
|
||||
|
||||
if not existing_sprint:
|
||||
sprint = Sprint(
|
||||
project_id=sprint_project.id,
|
||||
name=sprint_data["name"],
|
||||
number=sprint_number,
|
||||
goal=sprint_data.get("goal"),
|
||||
start_date=date.fromisoformat(sprint_data["start_date"]),
|
||||
end_date=date.fromisoformat(sprint_data["end_date"]),
|
||||
status=SprintStatus(sprint_data.get("status", "planned")),
|
||||
planned_points=sprint_data.get("planned_points"),
|
||||
)
|
||||
session.add(sprint)
|
||||
await session.flush()
|
||||
sprint_map[sprint_key] = sprint
|
||||
logger.info(
|
||||
f"Created demo sprint: {sprint.name} for {sprint_project.name}"
|
||||
)
|
||||
else:
|
||||
sprint_map[sprint_key] = existing_sprint
|
||||
logger.debug(f"Demo sprint already exists: {existing_sprint.name}")
|
||||
|
||||
# ========================
|
||||
# 6. Create Agent Instances
|
||||
# ========================
|
||||
for agent_data in data.get("agent_instances", []):
|
||||
project_slug = agent_data["project_slug"]
|
||||
agent_type_slug = agent_data["agent_type_slug"]
|
||||
agent_name = agent_data["name"]
|
||||
agent_key = f"{project_slug}:{agent_name}"
|
||||
|
||||
if project_slug not in project_map:
|
||||
logger.warning(f"Project not found for agent: {project_slug}")
|
||||
continue
|
||||
|
||||
if agent_type_slug not in agent_type_map:
|
||||
logger.warning(f"Agent type not found: {agent_type_slug}")
|
||||
continue
|
||||
|
||||
agent_project = project_map[project_slug]
|
||||
agent_type = agent_type_map[agent_type_slug]
|
||||
|
||||
# Check if agent instance exists (by name within project)
|
||||
agent_result = await session.execute(
|
||||
select(AgentInstance).where(
|
||||
AgentInstance.project_id == agent_project.id,
|
||||
AgentInstance.name == agent_name,
|
||||
)
|
||||
)
|
||||
existing_agent = agent_result.scalar_one_or_none()
|
||||
|
||||
if not existing_agent:
|
||||
agent_instance = AgentInstance(
|
||||
project_id=agent_project.id,
|
||||
agent_type_id=agent_type.id,
|
||||
name=agent_name,
|
||||
status=AgentStatus(agent_data.get("status", "idle")),
|
||||
current_task=agent_data.get("current_task"),
|
||||
)
|
||||
session.add(agent_instance)
|
||||
await session.flush()
|
||||
agent_instance_map[agent_key] = agent_instance
|
||||
logger.info(
|
||||
f"Created demo agent: {agent_name} ({agent_type.name}) "
|
||||
f"for {agent_project.name}"
|
||||
)
|
||||
else:
|
||||
agent_instance_map[agent_key] = existing_agent
|
||||
logger.debug(f"Demo agent already exists: {existing_agent.name}")
|
||||
|
||||
# ========================
|
||||
# 7. Create Issues
|
||||
# ========================
|
||||
for issue_data in data.get("issues", []):
|
||||
project_slug = issue_data["project_slug"]
|
||||
|
||||
if project_slug not in project_map:
|
||||
logger.warning(f"Project not found for issue: {project_slug}")
|
||||
continue
|
||||
|
||||
issue_project = project_map[project_slug]
|
||||
|
||||
# Check if issue exists (by title within project - simple heuristic)
|
||||
issue_result = await session.execute(
|
||||
select(Issue).where(
|
||||
Issue.project_id == issue_project.id,
|
||||
Issue.title == issue_data["title"],
|
||||
)
|
||||
)
|
||||
existing_issue = issue_result.scalar_one_or_none()
|
||||
|
||||
if not existing_issue:
|
||||
# Resolve sprint
|
||||
sprint_id = None
|
||||
sprint_number = issue_data.get("sprint_number")
|
||||
if sprint_number:
|
||||
sprint_key = f"{project_slug}:{sprint_number}"
|
||||
if sprint_key in sprint_map:
|
||||
sprint_id = sprint_map[sprint_key].id
|
||||
|
||||
# Resolve assigned agent
|
||||
assigned_agent_id = None
|
||||
assigned_agent_name = issue_data.get("assigned_agent_name")
|
||||
if assigned_agent_name:
|
||||
agent_key = f"{project_slug}:{assigned_agent_name}"
|
||||
if agent_key in agent_instance_map:
|
||||
assigned_agent_id = agent_instance_map[agent_key].id
|
||||
|
||||
issue = Issue(
|
||||
project_id=issue_project.id,
|
||||
sprint_id=sprint_id,
|
||||
type=IssueType(issue_data.get("type", "task")),
|
||||
title=issue_data["title"],
|
||||
body=issue_data.get("body", ""),
|
||||
status=IssueStatus(issue_data.get("status", "open")),
|
||||
priority=IssuePriority(issue_data.get("priority", "medium")),
|
||||
labels=issue_data.get("labels", []),
|
||||
story_points=issue_data.get("story_points"),
|
||||
assigned_agent_id=assigned_agent_id,
|
||||
)
|
||||
session.add(issue)
|
||||
logger.info(f"Created demo issue: {issue.title[:50]}...")
|
||||
else:
|
||||
logger.debug(
|
||||
f"Demo issue already exists: {existing_issue.title[:50]}..."
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
logger.info("Demo data loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Error loading demo data: {e}")
|
||||
raise
|
||||
|
||||
@@ -210,12 +491,12 @@ async def main():
|
||||
try:
|
||||
user = await init_db()
|
||||
if user:
|
||||
print("✓ Database initialized successfully")
|
||||
print(f"✓ Superuser: {user.email}")
|
||||
print("Database initialized successfully")
|
||||
print(f"Superuser: {user.email}")
|
||||
else:
|
||||
print("✗ Failed to initialize database")
|
||||
print("Failed to initialize database")
|
||||
except Exception as e:
|
||||
print(f"✗ Error initializing database: {e}")
|
||||
print(f"Error initializing database: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Close the engine
|
||||
|
||||
@@ -8,6 +8,19 @@ from app.core.database import Base
|
||||
|
||||
from .base import TimestampMixin, UUIDMixin
|
||||
|
||||
# Memory system models
|
||||
from .memory import (
|
||||
ConsolidationStatus,
|
||||
ConsolidationType,
|
||||
Episode,
|
||||
EpisodeOutcome,
|
||||
Fact,
|
||||
MemoryConsolidationLog,
|
||||
Procedure,
|
||||
ScopeType,
|
||||
WorkingMemory,
|
||||
)
|
||||
|
||||
# OAuth models (client mode - authenticate via Google/GitHub)
|
||||
from .oauth_account import OAuthAccount
|
||||
|
||||
@@ -37,7 +50,14 @@ __all__ = [
|
||||
"AgentInstance",
|
||||
"AgentType",
|
||||
"Base",
|
||||
# Memory models
|
||||
"ConsolidationStatus",
|
||||
"ConsolidationType",
|
||||
"Episode",
|
||||
"EpisodeOutcome",
|
||||
"Fact",
|
||||
"Issue",
|
||||
"MemoryConsolidationLog",
|
||||
"OAuthAccount",
|
||||
"OAuthAuthorizationCode",
|
||||
"OAuthClient",
|
||||
@@ -46,11 +66,14 @@ __all__ = [
|
||||
"OAuthState",
|
||||
"Organization",
|
||||
"OrganizationRole",
|
||||
"Procedure",
|
||||
"Project",
|
||||
"ScopeType",
|
||||
"Sprint",
|
||||
"TimestampMixin",
|
||||
"UUIDMixin",
|
||||
"User",
|
||||
"UserOrganization",
|
||||
"UserSession",
|
||||
"WorkingMemory",
|
||||
]
|
||||
|
||||
32
backend/app/models/memory/__init__.py
Normal file
32
backend/app/models/memory/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# app/models/memory/__init__.py
|
||||
"""
|
||||
Memory System Database Models.
|
||||
|
||||
Provides SQLAlchemy models for the Agent Memory System:
|
||||
- WorkingMemory: Key-value storage with TTL
|
||||
- Episode: Experiential memories
|
||||
- Fact: Semantic knowledge triples
|
||||
- Procedure: Learned skills
|
||||
- MemoryConsolidationLog: Consolidation job tracking
|
||||
"""
|
||||
|
||||
from .consolidation import MemoryConsolidationLog
|
||||
from .enums import ConsolidationStatus, ConsolidationType, EpisodeOutcome, ScopeType
|
||||
from .episode import Episode
|
||||
from .fact import Fact
|
||||
from .procedure import Procedure
|
||||
from .working_memory import WorkingMemory
|
||||
|
||||
__all__ = [
|
||||
# Enums
|
||||
"ConsolidationStatus",
|
||||
"ConsolidationType",
|
||||
# Models
|
||||
"Episode",
|
||||
"EpisodeOutcome",
|
||||
"Fact",
|
||||
"MemoryConsolidationLog",
|
||||
"Procedure",
|
||||
"ScopeType",
|
||||
"WorkingMemory",
|
||||
]
|
||||
72
backend/app/models/memory/consolidation.py
Normal file
72
backend/app/models/memory/consolidation.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# app/models/memory/consolidation.py
|
||||
"""
|
||||
Memory Consolidation Log database model.
|
||||
|
||||
Tracks memory consolidation jobs that transfer knowledge
|
||||
between memory tiers.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum, Index, Integer, Text
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import ConsolidationStatus, ConsolidationType
|
||||
|
||||
|
||||
class MemoryConsolidationLog(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Memory consolidation job log.
|
||||
|
||||
Tracks consolidation operations:
|
||||
- Working -> Episodic (session end)
|
||||
- Episodic -> Semantic (fact extraction)
|
||||
- Episodic -> Procedural (procedure learning)
|
||||
- Pruning (removing low-value memories)
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_consolidation_log"
|
||||
|
||||
# Consolidation type
|
||||
consolidation_type: Column[ConsolidationType] = Column(
|
||||
Enum(ConsolidationType),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Counts
|
||||
source_count = Column(Integer, nullable=False, default=0)
|
||||
result_count = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# Timing
|
||||
started_at = Column(DateTime(timezone=True), nullable=False)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Status
|
||||
status: Column[ConsolidationStatus] = Column(
|
||||
Enum(ConsolidationStatus),
|
||||
nullable=False,
|
||||
default=ConsolidationStatus.PENDING,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Error details if failed
|
||||
error = Column(Text, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
# Query patterns
|
||||
Index("ix_consolidation_type_status", "consolidation_type", "status"),
|
||||
Index("ix_consolidation_started", "started_at"),
|
||||
)
|
||||
|
||||
@property
|
||||
def duration_seconds(self) -> float | None:
|
||||
"""Calculate duration of the consolidation job."""
|
||||
if self.completed_at is None or self.started_at is None:
|
||||
return None
|
||||
return (self.completed_at - self.started_at).total_seconds()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<MemoryConsolidationLog {self.id} "
|
||||
f"type={self.consolidation_type.value} status={self.status.value}>"
|
||||
)
|
||||
73
backend/app/models/memory/enums.py
Normal file
73
backend/app/models/memory/enums.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# app/models/memory/enums.py
|
||||
"""
|
||||
Enums for Memory System database models.
|
||||
|
||||
These enums define the database-level constraints for memory types
|
||||
and scoping levels.
|
||||
"""
|
||||
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
|
||||
class ScopeType(str, PyEnum):
|
||||
"""
|
||||
Memory scope levels matching the memory service types.
|
||||
|
||||
GLOBAL: System-wide memories accessible by all
|
||||
PROJECT: Project-scoped memories
|
||||
AGENT_TYPE: Type-specific memories (shared by instances of same type)
|
||||
AGENT_INSTANCE: Instance-specific memories
|
||||
SESSION: Session-scoped ephemeral memories
|
||||
"""
|
||||
|
||||
GLOBAL = "global"
|
||||
PROJECT = "project"
|
||||
AGENT_TYPE = "agent_type"
|
||||
AGENT_INSTANCE = "agent_instance"
|
||||
SESSION = "session"
|
||||
|
||||
|
||||
class EpisodeOutcome(str, PyEnum):
|
||||
"""
|
||||
Outcome of an episode (task execution).
|
||||
|
||||
SUCCESS: Task completed successfully
|
||||
FAILURE: Task failed
|
||||
PARTIAL: Task partially completed
|
||||
"""
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
PARTIAL = "partial"
|
||||
|
||||
|
||||
class ConsolidationType(str, PyEnum):
|
||||
"""
|
||||
Types of memory consolidation operations.
|
||||
|
||||
WORKING_TO_EPISODIC: Transfer session state to episodic
|
||||
EPISODIC_TO_SEMANTIC: Extract facts from episodes
|
||||
EPISODIC_TO_PROCEDURAL: Extract procedures from episodes
|
||||
PRUNING: Remove low-value memories
|
||||
"""
|
||||
|
||||
WORKING_TO_EPISODIC = "working_to_episodic"
|
||||
EPISODIC_TO_SEMANTIC = "episodic_to_semantic"
|
||||
EPISODIC_TO_PROCEDURAL = "episodic_to_procedural"
|
||||
PRUNING = "pruning"
|
||||
|
||||
|
||||
class ConsolidationStatus(str, PyEnum):
|
||||
"""
|
||||
Status of a consolidation job.
|
||||
|
||||
PENDING: Job is queued
|
||||
RUNNING: Job is currently executing
|
||||
COMPLETED: Job finished successfully
|
||||
FAILED: Job failed with errors
|
||||
"""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
139
backend/app/models/memory/episode.py
Normal file
139
backend/app/models/memory/episode.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# app/models/memory/episode.py
|
||||
"""
|
||||
Episode database model.
|
||||
|
||||
Stores experiential memories - records of past task executions
|
||||
with context, actions, outcomes, and lessons learned.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
Enum,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Index,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import EpisodeOutcome
|
||||
|
||||
# Import pgvector type - will be available after migration enables extension
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
# Fallback for environments without pgvector
|
||||
Vector = None
|
||||
|
||||
|
||||
class Episode(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Episodic memory model.
|
||||
|
||||
Records experiential memories from agent task execution:
|
||||
- What task was performed
|
||||
- What actions were taken
|
||||
- What was the outcome
|
||||
- What lessons were learned
|
||||
"""
|
||||
|
||||
__tablename__ = "episodes"
|
||||
|
||||
# Foreign keys
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
agent_instance_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_instances.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
agent_type_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_types.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Session reference
|
||||
session_id = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Task information
|
||||
task_type = Column(String(100), nullable=False, index=True)
|
||||
task_description = Column(Text, nullable=False)
|
||||
|
||||
# Actions taken (list of action dictionaries)
|
||||
actions = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Context summary
|
||||
context_summary = Column(Text, nullable=False)
|
||||
|
||||
# Outcome
|
||||
outcome: Column[EpisodeOutcome] = Column(
|
||||
Enum(EpisodeOutcome),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
outcome_details = Column(Text, nullable=True)
|
||||
|
||||
# Metrics
|
||||
duration_seconds = Column(Float, nullable=False, default=0.0)
|
||||
tokens_used = Column(BigInteger, nullable=False, default=0)
|
||||
|
||||
# Learning
|
||||
lessons_learned = Column(JSONB, default=list, nullable=False)
|
||||
importance_score = Column(Float, nullable=False, default=0.5, index=True)
|
||||
|
||||
# Vector embedding for semantic search
|
||||
# Using 1536 dimensions for OpenAI text-embedding-3-small
|
||||
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
|
||||
|
||||
# When the episode occurred
|
||||
occurred_at = Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", foreign_keys=[project_id])
|
||||
agent_instance = relationship("AgentInstance", foreign_keys=[agent_instance_id])
|
||||
agent_type = relationship("AgentType", foreign_keys=[agent_type_id])
|
||||
|
||||
__table_args__ = (
|
||||
# Primary query patterns
|
||||
Index("ix_episodes_project_task", "project_id", "task_type"),
|
||||
Index("ix_episodes_project_outcome", "project_id", "outcome"),
|
||||
Index("ix_episodes_agent_task", "agent_instance_id", "task_type"),
|
||||
Index("ix_episodes_project_time", "project_id", "occurred_at"),
|
||||
# For importance-based pruning
|
||||
Index("ix_episodes_importance_time", "importance_score", "occurred_at"),
|
||||
# Data integrity constraints
|
||||
CheckConstraint(
|
||||
"importance_score >= 0.0 AND importance_score <= 1.0",
|
||||
name="ck_episodes_importance_range",
|
||||
),
|
||||
CheckConstraint(
|
||||
"duration_seconds >= 0.0",
|
||||
name="ck_episodes_duration_positive",
|
||||
),
|
||||
CheckConstraint(
|
||||
"tokens_used >= 0",
|
||||
name="ck_episodes_tokens_positive",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Episode {self.id} task={self.task_type} outcome={self.outcome.value}>"
|
||||
120
backend/app/models/memory/fact.py
Normal file
120
backend/app/models/memory/fact.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# app/models/memory/fact.py
|
||||
"""
|
||||
Fact database model.
|
||||
|
||||
Stores semantic memories - learned facts in subject-predicate-object
|
||||
triple format with confidence scores and source tracking.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
# Import pgvector type
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
Vector = None
|
||||
|
||||
|
||||
class Fact(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Semantic memory model.
|
||||
|
||||
Stores learned facts as subject-predicate-object triples:
|
||||
- "FastAPI" - "uses" - "Starlette framework"
|
||||
- "Project Alpha" - "requires" - "OAuth authentication"
|
||||
|
||||
Facts have confidence scores that decay over time and can be
|
||||
reinforced when the same fact is learned again.
|
||||
"""
|
||||
|
||||
__tablename__ = "facts"
|
||||
|
||||
# Scoping: project_id is NULL for global facts
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Triple format
|
||||
subject = Column(String(500), nullable=False, index=True)
|
||||
predicate = Column(String(255), nullable=False, index=True)
|
||||
object = Column(Text, nullable=False)
|
||||
|
||||
# Confidence score (0.0 to 1.0)
|
||||
confidence = Column(Float, nullable=False, default=0.8, index=True)
|
||||
|
||||
# Source tracking: which episodes contributed to this fact (stored as JSONB array of UUID strings)
|
||||
source_episode_ids: Column[list] = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Learning history
|
||||
first_learned = Column(DateTime(timezone=True), nullable=False)
|
||||
last_reinforced = Column(DateTime(timezone=True), nullable=False)
|
||||
reinforcement_count = Column(Integer, nullable=False, default=1)
|
||||
|
||||
# Vector embedding for semantic search
|
||||
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", foreign_keys=[project_id])
|
||||
|
||||
__table_args__ = (
|
||||
# Unique constraint on triple within project scope
|
||||
Index(
|
||||
"ix_facts_unique_triple",
|
||||
"project_id",
|
||||
"subject",
|
||||
"predicate",
|
||||
"object",
|
||||
unique=True,
|
||||
postgresql_where=text("project_id IS NOT NULL"),
|
||||
),
|
||||
# Unique constraint on triple for global facts (project_id IS NULL)
|
||||
Index(
|
||||
"ix_facts_unique_triple_global",
|
||||
"subject",
|
||||
"predicate",
|
||||
"object",
|
||||
unique=True,
|
||||
postgresql_where=text("project_id IS NULL"),
|
||||
),
|
||||
# Query patterns
|
||||
Index("ix_facts_subject_predicate", "subject", "predicate"),
|
||||
Index("ix_facts_project_subject", "project_id", "subject"),
|
||||
Index("ix_facts_confidence_time", "confidence", "last_reinforced"),
|
||||
# Note: subject already has index=True on Column definition, no need for explicit index
|
||||
# Data integrity constraints
|
||||
CheckConstraint(
|
||||
"confidence >= 0.0 AND confidence <= 1.0",
|
||||
name="ck_facts_confidence_range",
|
||||
),
|
||||
CheckConstraint(
|
||||
"reinforcement_count >= 1",
|
||||
name="ck_facts_reinforcement_positive",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Fact {self.id} '{self.subject}' - '{self.predicate}' - "
|
||||
f"'{self.object[:50]}...' conf={self.confidence:.2f}>"
|
||||
)
|
||||
129
backend/app/models/memory/procedure.py
Normal file
129
backend/app/models/memory/procedure.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# app/models/memory/procedure.py
|
||||
"""
|
||||
Procedure database model.
|
||||
|
||||
Stores procedural memories - learned skills and procedures
|
||||
derived from successful task execution patterns.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
# Import pgvector type
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
Vector = None
|
||||
|
||||
|
||||
class Procedure(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Procedural memory model.
|
||||
|
||||
Stores learned procedures (skills) extracted from successful
|
||||
task execution patterns:
|
||||
- Name and trigger pattern for matching
|
||||
- Step-by-step actions
|
||||
- Success/failure tracking
|
||||
"""
|
||||
|
||||
__tablename__ = "procedures"
|
||||
|
||||
# Scoping
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
agent_type_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_types.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Procedure identification
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
trigger_pattern = Column(Text, nullable=False)
|
||||
|
||||
# Steps as JSON array of step objects
|
||||
# Each step: {order, action, parameters, expected_outcome, fallback_action}
|
||||
steps = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Success tracking
|
||||
success_count = Column(Integer, nullable=False, default=0)
|
||||
failure_count = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# Usage tracking
|
||||
last_used = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Vector embedding for semantic matching
|
||||
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", foreign_keys=[project_id])
|
||||
agent_type = relationship("AgentType", foreign_keys=[agent_type_id])
|
||||
|
||||
__table_args__ = (
|
||||
# Unique procedure name within scope
|
||||
Index(
|
||||
"ix_procedures_unique_name",
|
||||
"project_id",
|
||||
"agent_type_id",
|
||||
"name",
|
||||
unique=True,
|
||||
),
|
||||
# Query patterns
|
||||
Index("ix_procedures_project_name", "project_id", "name"),
|
||||
# Note: agent_type_id already has index=True on Column definition
|
||||
# For finding best procedures
|
||||
Index("ix_procedures_success_rate", "success_count", "failure_count"),
|
||||
# Data integrity constraints
|
||||
CheckConstraint(
|
||||
"success_count >= 0",
|
||||
name="ck_procedures_success_positive",
|
||||
),
|
||||
CheckConstraint(
|
||||
"failure_count >= 0",
|
||||
name="ck_procedures_failure_positive",
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""Calculate the success rate of this procedure."""
|
||||
# Snapshot values to avoid race conditions in concurrent access
|
||||
success = self.success_count
|
||||
failure = self.failure_count
|
||||
total = success + failure
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return success / total
|
||||
|
||||
@property
|
||||
def total_uses(self) -> int:
|
||||
"""Get total number of times this procedure was used."""
|
||||
# Snapshot values for consistency
|
||||
return self.success_count + self.failure_count
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Procedure {self.name} ({self.id}) success_rate={self.success_rate:.2%}>"
|
||||
)
|
||||
58
backend/app/models/memory/working_memory.py
Normal file
58
backend/app/models/memory/working_memory.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# app/models/memory/working_memory.py
|
||||
"""
|
||||
Working Memory database model.
|
||||
|
||||
Stores ephemeral key-value data for active sessions with TTL support.
|
||||
Used as database backup when Redis is unavailable.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum, Index, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import ScopeType
|
||||
|
||||
|
||||
class WorkingMemory(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Working memory storage table.
|
||||
|
||||
Provides database-backed working memory as fallback when
|
||||
Redis is unavailable. Supports TTL-based expiration.
|
||||
"""
|
||||
|
||||
__tablename__ = "working_memory"
|
||||
|
||||
# Scoping
|
||||
scope_type: Column[ScopeType] = Column(
|
||||
Enum(ScopeType),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
scope_id = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Key-value storage
|
||||
key = Column(String(255), nullable=False)
|
||||
value = Column(JSONB, nullable=False)
|
||||
|
||||
# TTL support
|
||||
expires_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
# Primary lookup: scope + key
|
||||
Index(
|
||||
"ix_working_memory_scope_key",
|
||||
"scope_type",
|
||||
"scope_id",
|
||||
"key",
|
||||
unique=True,
|
||||
),
|
||||
# For cleanup of expired entries
|
||||
Index("ix_working_memory_expires", "expires_at"),
|
||||
# For listing all keys in a scope
|
||||
Index("ix_working_memory_scope_list", "scope_type", "scope_id"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<WorkingMemory {self.scope_type.value}:{self.scope_id}:{self.key}>"
|
||||
@@ -62,7 +62,11 @@ class AgentInstance(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Status tracking
|
||||
status: Column[AgentStatus] = Column(
|
||||
Enum(AgentStatus),
|
||||
Enum(
|
||||
AgentStatus,
|
||||
name="agent_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=AgentStatus.IDLE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
|
||||
@@ -59,7 +59,9 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Issue type (Epic, Story, Task, Bug)
|
||||
type: Column[IssueType] = Column(
|
||||
Enum(IssueType),
|
||||
Enum(
|
||||
IssueType, name="issue_type", values_callable=lambda x: [e.value for e in x]
|
||||
),
|
||||
default=IssueType.TASK,
|
||||
nullable=False,
|
||||
index=True,
|
||||
@@ -78,14 +80,22 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Status and priority
|
||||
status: Column[IssueStatus] = Column(
|
||||
Enum(IssueStatus),
|
||||
Enum(
|
||||
IssueStatus,
|
||||
name="issue_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=IssueStatus.OPEN,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
priority: Column[IssuePriority] = Column(
|
||||
Enum(IssuePriority),
|
||||
Enum(
|
||||
IssuePriority,
|
||||
name="issue_priority",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=IssuePriority.MEDIUM,
|
||||
nullable=False,
|
||||
index=True,
|
||||
@@ -132,7 +142,11 @@ class Issue(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Sync status with external tracker
|
||||
sync_status: Column[SyncStatus] = Column(
|
||||
Enum(SyncStatus),
|
||||
Enum(
|
||||
SyncStatus,
|
||||
name="sync_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=SyncStatus.SYNCED,
|
||||
nullable=False,
|
||||
# Note: Index defined in __table_args__ as ix_issues_sync_status
|
||||
|
||||
@@ -35,28 +35,44 @@ class Project(Base, UUIDMixin, TimestampMixin):
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
autonomy_level: Column[AutonomyLevel] = Column(
|
||||
Enum(AutonomyLevel),
|
||||
Enum(
|
||||
AutonomyLevel,
|
||||
name="autonomy_level",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=AutonomyLevel.MILESTONE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
status: Column[ProjectStatus] = Column(
|
||||
Enum(ProjectStatus),
|
||||
Enum(
|
||||
ProjectStatus,
|
||||
name="project_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=ProjectStatus.ACTIVE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
complexity: Column[ProjectComplexity] = Column(
|
||||
Enum(ProjectComplexity),
|
||||
Enum(
|
||||
ProjectComplexity,
|
||||
name="project_complexity",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=ProjectComplexity.MEDIUM,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
client_mode: Column[ClientMode] = Column(
|
||||
Enum(ClientMode),
|
||||
Enum(
|
||||
ClientMode,
|
||||
name="client_mode",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=ClientMode.AUTO,
|
||||
nullable=False,
|
||||
index=True,
|
||||
|
||||
@@ -57,7 +57,11 @@ class Sprint(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Status
|
||||
status: Column[SprintStatus] = Column(
|
||||
Enum(SprintStatus),
|
||||
Enum(
|
||||
SprintStatus,
|
||||
name="sprint_status",
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=SprintStatus.PLANNED,
|
||||
nullable=False,
|
||||
index=True,
|
||||
|
||||
@@ -114,6 +114,8 @@ from .types import (
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MemoryContext,
|
||||
MemorySubtype,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskComplexity,
|
||||
@@ -149,6 +151,8 @@ __all__ = [
|
||||
"FormattingError",
|
||||
"InvalidContextError",
|
||||
"KnowledgeContext",
|
||||
"MemoryContext",
|
||||
"MemorySubtype",
|
||||
"MessageRole",
|
||||
"ModelAdapter",
|
||||
"OpenAIAdapter",
|
||||
|
||||
@@ -30,6 +30,7 @@ class TokenBudget:
|
||||
knowledge: int = 0
|
||||
conversation: int = 0
|
||||
tools: int = 0
|
||||
memory: int = 0 # Agent memory (working, episodic, semantic, procedural)
|
||||
response_reserve: int = 0
|
||||
buffer: int = 0
|
||||
|
||||
@@ -60,6 +61,7 @@ class TokenBudget:
|
||||
"knowledge": self.knowledge,
|
||||
"conversation": self.conversation,
|
||||
"tool": self.tools,
|
||||
"memory": self.memory,
|
||||
}
|
||||
return allocation_map.get(context_type, 0)
|
||||
|
||||
@@ -211,6 +213,7 @@ class TokenBudget:
|
||||
"knowledge": self.knowledge,
|
||||
"conversation": self.conversation,
|
||||
"tools": self.tools,
|
||||
"memory": self.memory,
|
||||
"response_reserve": self.response_reserve,
|
||||
"buffer": self.buffer,
|
||||
},
|
||||
@@ -264,9 +267,10 @@ class BudgetAllocator:
|
||||
total=total_tokens,
|
||||
system=int(total_tokens * alloc.get("system", 0.05)),
|
||||
task=int(total_tokens * alloc.get("task", 0.10)),
|
||||
knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
|
||||
conversation=int(total_tokens * alloc.get("conversation", 0.20)),
|
||||
knowledge=int(total_tokens * alloc.get("knowledge", 0.30)),
|
||||
conversation=int(total_tokens * alloc.get("conversation", 0.15)),
|
||||
tools=int(total_tokens * alloc.get("tools", 0.05)),
|
||||
memory=int(total_tokens * alloc.get("memory", 0.15)),
|
||||
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
|
||||
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
|
||||
)
|
||||
@@ -317,6 +321,8 @@ class BudgetAllocator:
|
||||
budget.conversation = max(0, budget.conversation + actual_adjustment)
|
||||
elif context_type == "tool":
|
||||
budget.tools = max(0, budget.tools + actual_adjustment)
|
||||
elif context_type == "memory":
|
||||
budget.memory = max(0, budget.memory + actual_adjustment)
|
||||
|
||||
return budget
|
||||
|
||||
@@ -338,7 +344,12 @@ class BudgetAllocator:
|
||||
Rebalanced budget
|
||||
"""
|
||||
if prioritize is None:
|
||||
prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]
|
||||
prioritize = [
|
||||
ContextType.KNOWLEDGE,
|
||||
ContextType.MEMORY,
|
||||
ContextType.TASK,
|
||||
ContextType.SYSTEM,
|
||||
]
|
||||
|
||||
# Calculate unused tokens per type
|
||||
unused: dict[str, int] = {}
|
||||
|
||||
@@ -7,6 +7,7 @@ Provides a high-level API for assembling optimized context for LLM requests.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from .assembly import ContextPipeline
|
||||
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
@@ -20,6 +21,7 @@ from .types import (
|
||||
BaseContext,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MemoryContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
@@ -30,6 +32,7 @@ if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
from app.services.memory.integration import MemoryContextSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,6 +67,7 @@ class ContextEngine:
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
memory_source: "MemoryContextSource | None" = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the context engine.
|
||||
@@ -72,9 +76,11 @@ class ContextEngine:
|
||||
mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
|
||||
redis: Redis connection for caching
|
||||
settings: Context settings
|
||||
memory_source: Optional memory context source for agent memory
|
||||
"""
|
||||
self._mcp = mcp_manager
|
||||
self._settings = settings or get_context_settings()
|
||||
self._memory_source = memory_source
|
||||
|
||||
# Initialize components
|
||||
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
|
||||
@@ -115,6 +121,15 @@ class ContextEngine:
|
||||
"""
|
||||
self._cache.set_redis(redis)
|
||||
|
||||
def set_memory_source(self, memory_source: "MemoryContextSource") -> None:
|
||||
"""
|
||||
Set memory context source for agent memory integration.
|
||||
|
||||
Args:
|
||||
memory_source: Memory context source
|
||||
"""
|
||||
self._memory_source = memory_source
|
||||
|
||||
async def assemble_context(
|
||||
self,
|
||||
project_id: str,
|
||||
@@ -126,6 +141,10 @@ class ContextEngine:
|
||||
task_description: str | None = None,
|
||||
knowledge_query: str | None = None,
|
||||
knowledge_limit: int = 10,
|
||||
memory_query: str | None = None,
|
||||
memory_limit: int = 20,
|
||||
session_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
conversation_history: list[dict[str, str]] | None = None,
|
||||
tool_results: list[dict[str, Any]] | None = None,
|
||||
custom_contexts: list[BaseContext] | None = None,
|
||||
@@ -151,6 +170,10 @@ class ContextEngine:
|
||||
task_description: Current task description
|
||||
knowledge_query: Query for knowledge base search
|
||||
knowledge_limit: Max number of knowledge results
|
||||
memory_query: Query for agent memory search
|
||||
memory_limit: Max number of memory results
|
||||
session_id: Session ID for working memory access
|
||||
agent_type_id: Agent type ID for procedural memory
|
||||
conversation_history: List of {"role": str, "content": str}
|
||||
tool_results: List of tool results to include
|
||||
custom_contexts: Additional custom contexts
|
||||
@@ -197,15 +220,27 @@ class ContextEngine:
|
||||
)
|
||||
contexts.extend(knowledge_contexts)
|
||||
|
||||
# 4. Conversation history
|
||||
# 4. Memory context from Agent Memory System
|
||||
if memory_query and self._memory_source:
|
||||
memory_contexts = await self._fetch_memory(
|
||||
project_id=project_id,
|
||||
agent_id=agent_id,
|
||||
query=memory_query,
|
||||
limit=memory_limit,
|
||||
session_id=session_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
contexts.extend(memory_contexts)
|
||||
|
||||
# 5. Conversation history
|
||||
if conversation_history:
|
||||
contexts.extend(self._convert_conversation(conversation_history))
|
||||
|
||||
# 5. Tool results
|
||||
# 6. Tool results
|
||||
if tool_results:
|
||||
contexts.extend(self._convert_tool_results(tool_results))
|
||||
|
||||
# 6. Custom contexts
|
||||
# 7. Custom contexts
|
||||
if custom_contexts:
|
||||
contexts.extend(custom_contexts)
|
||||
|
||||
@@ -308,6 +343,65 @@ class ContextEngine:
|
||||
logger.warning(f"Failed to fetch knowledge: {e}")
|
||||
return []
|
||||
|
||||
async def _fetch_memory(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
limit: int = 20,
|
||||
session_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
) -> list[MemoryContext]:
|
||||
"""
|
||||
Fetch relevant memories from Agent Memory System.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
agent_id: Agent identifier
|
||||
query: Search query
|
||||
limit: Maximum results
|
||||
session_id: Session ID for working memory
|
||||
agent_type_id: Agent type ID for procedural memory
|
||||
|
||||
Returns:
|
||||
List of MemoryContext instances
|
||||
"""
|
||||
if not self._memory_source:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
|
||||
# Configure fetch limits
|
||||
from app.services.memory.integration.context_source import MemoryFetchConfig
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
working_limit=min(limit // 4, 5),
|
||||
episodic_limit=min(limit // 2, 10),
|
||||
semantic_limit=min(limit // 2, 10),
|
||||
procedural_limit=min(limit // 4, 5),
|
||||
include_working=session_id is not None,
|
||||
)
|
||||
|
||||
result = await self._memory_source.fetch_context(
|
||||
query=query,
|
||||
project_id=UUID(project_id),
|
||||
agent_instance_id=UUID(agent_id) if agent_id else None,
|
||||
agent_type_id=UUID(agent_type_id) if agent_type_id else None,
|
||||
session_id=session_id,
|
||||
config=config,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Fetched {len(result.contexts)} memory contexts for query: {query}, "
|
||||
f"by_type: {result.by_type}"
|
||||
)
|
||||
return result.contexts[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch memory: {e}")
|
||||
return []
|
||||
|
||||
def _convert_conversation(
|
||||
self,
|
||||
history: list[dict[str, str]],
|
||||
@@ -466,6 +560,7 @@ def create_context_engine(
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
memory_source: "MemoryContextSource | None" = None,
|
||||
) -> ContextEngine:
|
||||
"""
|
||||
Create a context engine instance.
|
||||
@@ -474,6 +569,7 @@ def create_context_engine(
|
||||
mcp_manager: MCP client manager
|
||||
redis: Redis connection
|
||||
settings: Context settings
|
||||
memory_source: Optional memory context source
|
||||
|
||||
Returns:
|
||||
Configured ContextEngine instance
|
||||
@@ -482,4 +578,5 @@ def create_context_engine(
|
||||
mcp_manager=mcp_manager,
|
||||
redis=redis,
|
||||
settings=settings,
|
||||
memory_source=memory_source,
|
||||
)
|
||||
|
||||
@@ -15,6 +15,10 @@ from .conversation import (
|
||||
MessageRole,
|
||||
)
|
||||
from .knowledge import KnowledgeContext
|
||||
from .memory import (
|
||||
MemoryContext,
|
||||
MemorySubtype,
|
||||
)
|
||||
from .system import SystemContext
|
||||
from .task import (
|
||||
TaskComplexity,
|
||||
@@ -33,6 +37,8 @@ __all__ = [
|
||||
"ContextType",
|
||||
"ConversationContext",
|
||||
"KnowledgeContext",
|
||||
"MemoryContext",
|
||||
"MemorySubtype",
|
||||
"MessageRole",
|
||||
"SystemContext",
|
||||
"TaskComplexity",
|
||||
|
||||
@@ -26,6 +26,7 @@ class ContextType(str, Enum):
|
||||
KNOWLEDGE = "knowledge"
|
||||
CONVERSATION = "conversation"
|
||||
TOOL = "tool"
|
||||
MEMORY = "memory" # Agent memory (working, episodic, semantic, procedural)
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str) -> "ContextType":
|
||||
|
||||
282
backend/app/services/context/types/memory.py
Normal file
282
backend/app/services/context/types/memory.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Memory Context Type.
|
||||
|
||||
Represents agent memory as context for LLM requests.
|
||||
Includes working, episodic, semantic, and procedural memories.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class MemorySubtype(str, Enum):
|
||||
"""Types of agent memory."""
|
||||
|
||||
WORKING = "working" # Session-scoped temporary data
|
||||
EPISODIC = "episodic" # Task history and outcomes
|
||||
SEMANTIC = "semantic" # Facts and knowledge
|
||||
PROCEDURAL = "procedural" # Learned procedures
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MemoryContext(BaseContext):
|
||||
"""
|
||||
Context from agent memory system.
|
||||
|
||||
Memory context represents data retrieved from the agent
|
||||
memory system, including:
|
||||
- Working memory: Current session state
|
||||
- Episodic memory: Past task experiences
|
||||
- Semantic memory: Learned facts and knowledge
|
||||
- Procedural memory: Known procedures and workflows
|
||||
|
||||
Each memory item includes relevance scoring from search.
|
||||
"""
|
||||
|
||||
# Memory-specific fields
|
||||
memory_subtype: MemorySubtype = field(default=MemorySubtype.EPISODIC)
|
||||
memory_id: str | None = field(default=None)
|
||||
relevance_score: float = field(default=0.0)
|
||||
importance: float = field(default=0.5)
|
||||
search_query: str = field(default="")
|
||||
|
||||
# Type-specific fields (populated based on memory_subtype)
|
||||
key: str | None = field(default=None) # For working memory
|
||||
task_type: str | None = field(default=None) # For episodic
|
||||
outcome: str | None = field(default=None) # For episodic
|
||||
subject: str | None = field(default=None) # For semantic
|
||||
predicate: str | None = field(default=None) # For semantic
|
||||
object_value: str | None = field(default=None) # For semantic
|
||||
trigger: str | None = field(default=None) # For procedural
|
||||
success_rate: float | None = field(default=None) # For procedural
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return MEMORY context type."""
|
||||
return ContextType.MEMORY
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with memory-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"memory_subtype": self.memory_subtype.value,
|
||||
"memory_id": self.memory_id,
|
||||
"relevance_score": self.relevance_score,
|
||||
"importance": self.importance,
|
||||
"search_query": self.search_query,
|
||||
"key": self.key,
|
||||
"task_type": self.task_type,
|
||||
"outcome": self.outcome,
|
||||
"subject": self.subject,
|
||||
"predicate": self.predicate,
|
||||
"object_value": self.object_value,
|
||||
"trigger": self.trigger,
|
||||
"success_rate": self.success_rate,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "MemoryContext":
|
||||
"""Create MemoryContext from dictionary."""
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data["source"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
memory_subtype=MemorySubtype(data.get("memory_subtype", "episodic")),
|
||||
memory_id=data.get("memory_id"),
|
||||
relevance_score=data.get("relevance_score", 0.0),
|
||||
importance=data.get("importance", 0.5),
|
||||
search_query=data.get("search_query", ""),
|
||||
key=data.get("key"),
|
||||
task_type=data.get("task_type"),
|
||||
outcome=data.get("outcome"),
|
||||
subject=data.get("subject"),
|
||||
predicate=data.get("predicate"),
|
||||
object_value=data.get("object_value"),
|
||||
trigger=data.get("trigger"),
|
||||
success_rate=data.get("success_rate"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_working_memory(
|
||||
cls,
|
||||
key: str,
|
||||
value: Any,
|
||||
source: str = "working_memory",
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from working memory entry.
|
||||
|
||||
Args:
|
||||
key: Working memory key
|
||||
value: Value stored at key
|
||||
source: Source identifier
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
return cls(
|
||||
content=str(value),
|
||||
source=source,
|
||||
memory_subtype=MemorySubtype.WORKING,
|
||||
key=key,
|
||||
relevance_score=1.0, # Working memory is always relevant
|
||||
importance=0.8, # Higher importance for current session state
|
||||
search_query=query,
|
||||
priority=ContextPriority.HIGH.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_episodic_memory(
|
||||
cls,
|
||||
episode: Any,
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from episodic memory episode.
|
||||
|
||||
Args:
|
||||
episode: Episode object from episodic memory
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
outcome_val = None
|
||||
if hasattr(episode, "outcome") and episode.outcome:
|
||||
outcome_val = (
|
||||
episode.outcome.value
|
||||
if hasattr(episode.outcome, "value")
|
||||
else str(episode.outcome)
|
||||
)
|
||||
|
||||
return cls(
|
||||
content=episode.task_description,
|
||||
source=f"episodic:{episode.id}",
|
||||
memory_subtype=MemorySubtype.EPISODIC,
|
||||
memory_id=str(episode.id),
|
||||
relevance_score=getattr(episode, "importance_score", 0.5),
|
||||
importance=getattr(episode, "importance_score", 0.5),
|
||||
search_query=query,
|
||||
task_type=getattr(episode, "task_type", None),
|
||||
outcome=outcome_val,
|
||||
metadata={
|
||||
"session_id": getattr(episode, "session_id", None),
|
||||
"occurred_at": episode.occurred_at.isoformat()
|
||||
if hasattr(episode, "occurred_at") and episode.occurred_at
|
||||
else None,
|
||||
"lessons_learned": getattr(episode, "lessons_learned", []),
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_semantic_memory(
|
||||
cls,
|
||||
fact: Any,
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from semantic memory fact.
|
||||
|
||||
Args:
|
||||
fact: Fact object from semantic memory
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
triple = f"{fact.subject} {fact.predicate} {fact.object}"
|
||||
return cls(
|
||||
content=triple,
|
||||
source=f"semantic:{fact.id}",
|
||||
memory_subtype=MemorySubtype.SEMANTIC,
|
||||
memory_id=str(fact.id),
|
||||
relevance_score=getattr(fact, "confidence", 0.5),
|
||||
importance=getattr(fact, "confidence", 0.5),
|
||||
search_query=query,
|
||||
subject=fact.subject,
|
||||
predicate=fact.predicate,
|
||||
object_value=fact.object,
|
||||
priority=ContextPriority.NORMAL.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_procedural_memory(
|
||||
cls,
|
||||
procedure: Any,
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from procedural memory procedure.
|
||||
|
||||
Args:
|
||||
procedure: Procedure object from procedural memory
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
# Format steps as content
|
||||
steps = getattr(procedure, "steps", [])
|
||||
steps_content = "\n".join(
|
||||
f" {i + 1}. {step.get('action', step) if isinstance(step, dict) else step}"
|
||||
for i, step in enumerate(steps)
|
||||
)
|
||||
content = f"Procedure: {procedure.name}\nTrigger: {procedure.trigger_pattern}\nSteps:\n{steps_content}"
|
||||
|
||||
return cls(
|
||||
content=content,
|
||||
source=f"procedural:{procedure.id}",
|
||||
memory_subtype=MemorySubtype.PROCEDURAL,
|
||||
memory_id=str(procedure.id),
|
||||
relevance_score=getattr(procedure, "success_rate", 0.5),
|
||||
importance=0.7, # Procedures are moderately important
|
||||
search_query=query,
|
||||
trigger=procedure.trigger_pattern,
|
||||
success_rate=getattr(procedure, "success_rate", None),
|
||||
metadata={
|
||||
"steps_count": len(steps),
|
||||
"execution_count": getattr(procedure, "success_count", 0)
|
||||
+ getattr(procedure, "failure_count", 0),
|
||||
},
|
||||
)
|
||||
|
||||
def is_working_memory(self) -> bool:
|
||||
"""Check if this is working memory."""
|
||||
return self.memory_subtype == MemorySubtype.WORKING
|
||||
|
||||
def is_episodic_memory(self) -> bool:
|
||||
"""Check if this is episodic memory."""
|
||||
return self.memory_subtype == MemorySubtype.EPISODIC
|
||||
|
||||
def is_semantic_memory(self) -> bool:
|
||||
"""Check if this is semantic memory."""
|
||||
return self.memory_subtype == MemorySubtype.SEMANTIC
|
||||
|
||||
def is_procedural_memory(self) -> bool:
|
||||
"""Check if this is procedural memory."""
|
||||
return self.memory_subtype == MemorySubtype.PROCEDURAL
|
||||
|
||||
def get_formatted_source(self) -> str:
|
||||
"""
|
||||
Get a formatted source string for display.
|
||||
|
||||
Returns:
|
||||
Formatted source string
|
||||
"""
|
||||
parts = [f"[{self.memory_subtype.value}]", self.source]
|
||||
if self.memory_id:
|
||||
parts.append(f"({self.memory_id[:8]}...)")
|
||||
return " ".join(parts)
|
||||
@@ -122,16 +122,24 @@ class MCPClientManager:
|
||||
)
|
||||
|
||||
async def _connect_all_servers(self) -> None:
|
||||
"""Connect to all enabled MCP servers."""
|
||||
"""Connect to all enabled MCP servers concurrently."""
|
||||
import asyncio
|
||||
|
||||
enabled_servers = self._registry.get_enabled_configs()
|
||||
|
||||
for name, config in enabled_servers.items():
|
||||
async def connect_server(name: str, config: "MCPServerConfig") -> None:
|
||||
try:
|
||||
await self._pool.get_connection(name, config)
|
||||
logger.info("Connected to MCP server: %s", name)
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to MCP server %s: %s", name, e)
|
||||
|
||||
# Connect to all servers concurrently for faster startup
|
||||
await asyncio.gather(
|
||||
*(connect_server(name, config) for name, config in enabled_servers.items()),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Shutdown the MCP client manager.
|
||||
|
||||
@@ -179,6 +179,8 @@ def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
|
||||
2. MCP_CONFIG_PATH environment variable
|
||||
3. Default path (backend/mcp_servers.yaml)
|
||||
4. Empty config if no file exists
|
||||
|
||||
In test mode (IS_TEST=True), retry settings are reduced for faster tests.
|
||||
"""
|
||||
if path is None:
|
||||
path = os.environ.get("MCP_CONFIG_PATH", str(DEFAULT_CONFIG_PATH))
|
||||
@@ -189,7 +191,18 @@ def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
|
||||
# Return empty config if no file exists (allows runtime registration)
|
||||
return MCPConfig()
|
||||
|
||||
return MCPConfig.from_yaml(path)
|
||||
config = MCPConfig.from_yaml(path)
|
||||
|
||||
# In test mode, reduce retry settings to speed up tests
|
||||
is_test = os.environ.get("IS_TEST", "").lower() in ("true", "1", "yes")
|
||||
if is_test:
|
||||
for server_config in config.mcp_servers.values():
|
||||
server_config.retry_attempts = 1 # Single attempt
|
||||
server_config.retry_delay = 0.1 # 100ms instead of 1s
|
||||
server_config.retry_max_delay = 0.5 # 500ms max
|
||||
server_config.timeout = 2 # 2s timeout instead of 30-120s
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_default_config() -> MCPConfig:
|
||||
|
||||
141
backend/app/services/memory/__init__.py
Normal file
141
backend/app/services/memory/__init__.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Agent Memory System
|
||||
|
||||
Multi-tier cognitive memory for AI agents, providing:
|
||||
- Working Memory: Session-scoped ephemeral state (Redis/In-memory)
|
||||
- Episodic Memory: Experiential records of past tasks (PostgreSQL)
|
||||
- Semantic Memory: Learned facts and knowledge (PostgreSQL + pgvector)
|
||||
- Procedural Memory: Learned skills and procedures (PostgreSQL)
|
||||
|
||||
Usage:
|
||||
from app.services.memory import (
|
||||
MemoryManager,
|
||||
MemorySettings,
|
||||
get_memory_settings,
|
||||
MemoryType,
|
||||
ScopeLevel,
|
||||
)
|
||||
|
||||
# Create a manager for a session
|
||||
manager = MemoryManager.for_session(
|
||||
session_id="sess-123",
|
||||
project_id=uuid,
|
||||
)
|
||||
|
||||
async with manager:
|
||||
# Working memory
|
||||
await manager.set_working("key", {"data": "value"})
|
||||
value = await manager.get_working("key")
|
||||
|
||||
# Episodic memory
|
||||
episode = await manager.record_episode(episode_data)
|
||||
similar = await manager.search_episodes("query")
|
||||
|
||||
# Semantic memory
|
||||
fact = await manager.store_fact(fact_data)
|
||||
facts = await manager.search_facts("query")
|
||||
|
||||
# Procedural memory
|
||||
procedure = await manager.record_procedure(procedure_data)
|
||||
procedures = await manager.find_procedures("context")
|
||||
"""
|
||||
|
||||
# Configuration
|
||||
from .config import (
|
||||
MemorySettings,
|
||||
get_default_settings,
|
||||
get_memory_settings,
|
||||
reset_memory_settings,
|
||||
)
|
||||
|
||||
# Exceptions
|
||||
from .exceptions import (
|
||||
CheckpointError,
|
||||
EmbeddingError,
|
||||
MemoryCapacityError,
|
||||
MemoryConflictError,
|
||||
MemoryConsolidationError,
|
||||
MemoryError,
|
||||
MemoryExpiredError,
|
||||
MemoryNotFoundError,
|
||||
MemoryRetrievalError,
|
||||
MemoryScopeError,
|
||||
MemorySerializationError,
|
||||
MemoryStorageError,
|
||||
)
|
||||
|
||||
# Manager
|
||||
from .manager import MemoryManager
|
||||
|
||||
# Types
|
||||
from .types import (
|
||||
ConsolidationStatus,
|
||||
ConsolidationType,
|
||||
Episode,
|
||||
EpisodeCreate,
|
||||
Fact,
|
||||
FactCreate,
|
||||
MemoryItem,
|
||||
MemoryStats,
|
||||
MemoryStore,
|
||||
MemoryType,
|
||||
Outcome,
|
||||
Procedure,
|
||||
ProcedureCreate,
|
||||
RetrievalResult,
|
||||
ScopeContext,
|
||||
ScopeLevel,
|
||||
Step,
|
||||
TaskState,
|
||||
WorkingMemoryItem,
|
||||
)
|
||||
|
||||
# Reflection (lazy import available)
|
||||
# Import directly: from app.services.memory.reflection import MemoryReflection
|
||||
|
||||
__all__ = [
|
||||
"CheckpointError",
|
||||
"ConsolidationStatus",
|
||||
"ConsolidationType",
|
||||
"EmbeddingError",
|
||||
"Episode",
|
||||
"EpisodeCreate",
|
||||
"Fact",
|
||||
"FactCreate",
|
||||
"MemoryCapacityError",
|
||||
"MemoryConflictError",
|
||||
"MemoryConsolidationError",
|
||||
# Exceptions
|
||||
"MemoryError",
|
||||
"MemoryExpiredError",
|
||||
"MemoryItem",
|
||||
# Manager
|
||||
"MemoryManager",
|
||||
"MemoryNotFoundError",
|
||||
"MemoryRetrievalError",
|
||||
"MemoryScopeError",
|
||||
"MemorySerializationError",
|
||||
# Configuration
|
||||
"MemorySettings",
|
||||
"MemoryStats",
|
||||
"MemoryStorageError",
|
||||
# Types - Abstract
|
||||
"MemoryStore",
|
||||
# Types - Enums
|
||||
"MemoryType",
|
||||
"Outcome",
|
||||
"Procedure",
|
||||
"ProcedureCreate",
|
||||
"RetrievalResult",
|
||||
# Types - Data Classes
|
||||
"ScopeContext",
|
||||
"ScopeLevel",
|
||||
"Step",
|
||||
"TaskState",
|
||||
"WorkingMemoryItem",
|
||||
"get_default_settings",
|
||||
"get_memory_settings",
|
||||
"reset_memory_settings",
|
||||
# MCP Tools - lazy import to avoid circular dependencies
|
||||
# Import directly: from app.services.memory.mcp import MemoryToolService
|
||||
]
|
||||
21
backend/app/services/memory/cache/__init__.py
vendored
Normal file
21
backend/app/services/memory/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
# app/services/memory/cache/__init__.py
|
||||
"""
|
||||
Memory Caching Layer.
|
||||
|
||||
Provides caching for memory operations:
|
||||
- Hot Memory Cache: LRU cache for frequently accessed memories
|
||||
- Embedding Cache: Cache embeddings by content hash
|
||||
- Cache Manager: Unified cache management with invalidation
|
||||
"""
|
||||
|
||||
from .cache_manager import CacheManager, CacheStats, get_cache_manager
|
||||
from .embedding_cache import EmbeddingCache
|
||||
from .hot_cache import HotMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"CacheManager",
|
||||
"CacheStats",
|
||||
"EmbeddingCache",
|
||||
"HotMemoryCache",
|
||||
"get_cache_manager",
|
||||
]
|
||||
505
backend/app/services/memory/cache/cache_manager.py
vendored
Normal file
505
backend/app/services/memory/cache/cache_manager.py
vendored
Normal file
@@ -0,0 +1,505 @@
|
||||
# app/services/memory/cache/cache_manager.py
|
||||
"""
|
||||
Cache Manager.
|
||||
|
||||
Unified cache management for memory operations.
|
||||
Coordinates hot cache, embedding cache, and retrieval cache.
|
||||
Provides centralized invalidation and statistics.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.config import get_memory_settings
|
||||
|
||||
from .embedding_cache import EmbeddingCache, create_embedding_cache
|
||||
from .hot_cache import CacheKey, HotMemoryCache, create_hot_cache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.services.memory.indexing.retrieval import RetrievalCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
"""Aggregated cache statistics."""
|
||||
|
||||
hot_cache: dict[str, Any] = field(default_factory=dict)
|
||||
embedding_cache: dict[str, Any] = field(default_factory=dict)
|
||||
retrieval_cache: dict[str, Any] = field(default_factory=dict)
|
||||
overall_hit_rate: float = 0.0
|
||||
last_cleanup: datetime | None = None
|
||||
cleanup_count: int = 0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"hot_cache": self.hot_cache,
|
||||
"embedding_cache": self.embedding_cache,
|
||||
"retrieval_cache": self.retrieval_cache,
|
||||
"overall_hit_rate": self.overall_hit_rate,
|
||||
"last_cleanup": self.last_cleanup.isoformat()
|
||||
if self.last_cleanup
|
||||
else None,
|
||||
"cleanup_count": self.cleanup_count,
|
||||
}
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""
|
||||
Unified cache manager for memory operations.
|
||||
|
||||
Provides:
|
||||
- Centralized cache configuration
|
||||
- Coordinated invalidation across caches
|
||||
- Aggregated statistics
|
||||
- Automatic cleanup scheduling
|
||||
|
||||
Performance targets:
|
||||
- Overall cache hit rate > 80%
|
||||
- Cache operations < 1ms (memory), < 5ms (Redis)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hot_cache: HotMemoryCache[Any] | None = None,
|
||||
embedding_cache: EmbeddingCache | None = None,
|
||||
retrieval_cache: "RetrievalCache | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the cache manager.
|
||||
|
||||
Args:
|
||||
hot_cache: Optional pre-configured hot cache
|
||||
embedding_cache: Optional pre-configured embedding cache
|
||||
retrieval_cache: Optional pre-configured retrieval cache
|
||||
redis: Optional Redis connection for persistence
|
||||
"""
|
||||
self._settings = get_memory_settings()
|
||||
self._redis = redis
|
||||
self._enabled = self._settings.cache_enabled
|
||||
|
||||
# Initialize caches
|
||||
if hot_cache:
|
||||
self._hot_cache = hot_cache
|
||||
else:
|
||||
self._hot_cache = create_hot_cache(
|
||||
max_size=self._settings.cache_max_items,
|
||||
default_ttl_seconds=self._settings.cache_ttl_seconds,
|
||||
)
|
||||
|
||||
if embedding_cache:
|
||||
self._embedding_cache = embedding_cache
|
||||
else:
|
||||
self._embedding_cache = create_embedding_cache(
|
||||
max_size=self._settings.cache_max_items,
|
||||
default_ttl_seconds=self._settings.cache_ttl_seconds
|
||||
* 12, # 1hr for embeddings
|
||||
redis=redis,
|
||||
)
|
||||
|
||||
self._retrieval_cache = retrieval_cache
|
||||
|
||||
# Stats tracking
|
||||
self._last_cleanup: datetime | None = None
|
||||
self._cleanup_count = 0
|
||||
self._lock = threading.RLock()
|
||||
|
||||
logger.info(
|
||||
f"Initialized CacheManager: enabled={self._enabled}, "
|
||||
f"redis={'connected' if redis else 'disabled'}"
|
||||
)
|
||||
|
||||
def set_redis(self, redis: "Redis") -> None:
|
||||
"""Set Redis connection for all caches."""
|
||||
self._redis = redis
|
||||
self._embedding_cache.set_redis(redis)
|
||||
|
||||
def set_retrieval_cache(self, cache: "RetrievalCache") -> None:
|
||||
"""Set retrieval cache instance."""
|
||||
self._retrieval_cache = cache
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""Check if caching is enabled."""
|
||||
return self._enabled
|
||||
|
||||
@property
|
||||
def hot_cache(self) -> HotMemoryCache[Any]:
|
||||
"""Get the hot memory cache."""
|
||||
return self._hot_cache
|
||||
|
||||
@property
|
||||
def embedding_cache(self) -> EmbeddingCache:
|
||||
"""Get the embedding cache."""
|
||||
return self._embedding_cache
|
||||
|
||||
@property
|
||||
def retrieval_cache(self) -> "RetrievalCache | None":
|
||||
"""Get the retrieval cache."""
|
||||
return self._retrieval_cache
|
||||
|
||||
# =========================================================================
|
||||
# Hot Memory Cache Operations
|
||||
# =========================================================================
|
||||
|
||||
def get_memory(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
scope: str | None = None,
|
||||
) -> Any | None:
|
||||
"""
|
||||
Get a memory from hot cache.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
scope: Optional scope
|
||||
|
||||
Returns:
|
||||
Cached memory or None
|
||||
"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
return self._hot_cache.get_by_id(memory_type, memory_id, scope)
|
||||
|
||||
def cache_memory(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
memory: Any,
|
||||
scope: str | None = None,
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache a memory in hot cache.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
memory: Memory object
|
||||
scope: Optional scope
|
||||
ttl_seconds: Optional TTL override
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
self._hot_cache.put_by_id(memory_type, memory_id, memory, scope, ttl_seconds)
|
||||
|
||||
# =========================================================================
|
||||
# Embedding Cache Operations
|
||||
# =========================================================================
|
||||
|
||||
async def get_embedding(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "default",
|
||||
) -> list[float] | None:
|
||||
"""
|
||||
Get a cached embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Cached embedding or None
|
||||
"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
return await self._embedding_cache.get(content, model)
|
||||
|
||||
async def cache_embedding(
|
||||
self,
|
||||
content: str,
|
||||
embedding: list[float],
|
||||
model: str = "default",
|
||||
ttl_seconds: float | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Cache an embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
embedding: Embedding vector
|
||||
model: Model name
|
||||
ttl_seconds: Optional TTL override
|
||||
|
||||
Returns:
|
||||
Content hash
|
||||
"""
|
||||
if not self._enabled:
|
||||
return EmbeddingCache.hash_content(content)
|
||||
return await self._embedding_cache.put(content, embedding, model, ttl_seconds)
|
||||
|
||||
# =========================================================================
|
||||
# Invalidation
|
||||
# =========================================================================
|
||||
|
||||
async def invalidate_memory(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
scope: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Invalidate a memory across all caches.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
scope: Optional scope
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
count = 0
|
||||
|
||||
# Invalidate hot cache
|
||||
if self._hot_cache.invalidate_by_id(memory_type, memory_id, scope):
|
||||
count += 1
|
||||
|
||||
# Invalidate retrieval cache
|
||||
if self._retrieval_cache:
|
||||
uuid_id = (
|
||||
UUID(str(memory_id)) if not isinstance(memory_id, UUID) else memory_id
|
||||
)
|
||||
count += self._retrieval_cache.invalidate_by_memory(uuid_id)
|
||||
|
||||
logger.debug(f"Invalidated {count} cache entries for {memory_type}:{memory_id}")
|
||||
return count
|
||||
|
||||
async def invalidate_by_type(self, memory_type: str) -> int:
|
||||
"""
|
||||
Invalidate all entries of a memory type.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
count = self._hot_cache.invalidate_by_type(memory_type)
|
||||
|
||||
if self._retrieval_cache:
|
||||
count += self._retrieval_cache.clear()
|
||||
|
||||
logger.info(f"Invalidated {count} cache entries for type {memory_type}")
|
||||
return count
|
||||
|
||||
async def invalidate_by_scope(self, scope: str) -> int:
|
||||
"""
|
||||
Invalidate all entries in a scope.
|
||||
|
||||
Args:
|
||||
scope: Scope to invalidate (e.g., project_id)
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
count = self._hot_cache.invalidate_by_scope(scope)
|
||||
|
||||
# Retrieval cache doesn't support scope-based invalidation
|
||||
# so we clear it entirely for safety
|
||||
if self._retrieval_cache:
|
||||
count += self._retrieval_cache.clear()
|
||||
|
||||
logger.info(f"Invalidated {count} cache entries for scope {scope}")
|
||||
return count
|
||||
|
||||
async def invalidate_embedding(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "default",
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate a cached embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
return await self._embedding_cache.invalidate(content, model)
|
||||
|
||||
async def clear_all(self) -> int:
|
||||
"""
|
||||
Clear all caches.
|
||||
|
||||
Returns:
|
||||
Total number of entries cleared
|
||||
"""
|
||||
count = 0
|
||||
|
||||
count += self._hot_cache.clear()
|
||||
count += await self._embedding_cache.clear()
|
||||
|
||||
if self._retrieval_cache:
|
||||
count += self._retrieval_cache.clear()
|
||||
|
||||
logger.info(f"Cleared {count} entries from all caches")
|
||||
return count
|
||||
|
||||
# =========================================================================
|
||||
# Cleanup
|
||||
# =========================================================================
|
||||
|
||||
async def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Clean up expired entries from all caches.
|
||||
|
||||
Returns:
|
||||
Number of entries cleaned up
|
||||
"""
|
||||
with self._lock:
|
||||
count = 0
|
||||
|
||||
count += self._hot_cache.cleanup_expired()
|
||||
count += self._embedding_cache.cleanup_expired()
|
||||
|
||||
# Retrieval cache doesn't have a cleanup method,
|
||||
# but entries expire on access
|
||||
|
||||
self._last_cleanup = _utcnow()
|
||||
self._cleanup_count += 1
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired cache entries")
|
||||
|
||||
return count
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
def get_stats(self) -> CacheStats:
|
||||
"""
|
||||
Get aggregated cache statistics.
|
||||
|
||||
Returns:
|
||||
CacheStats with all cache metrics
|
||||
"""
|
||||
hot_stats = self._hot_cache.get_stats().to_dict()
|
||||
emb_stats = self._embedding_cache.get_stats().to_dict()
|
||||
|
||||
retrieval_stats: dict[str, Any] = {}
|
||||
if self._retrieval_cache:
|
||||
retrieval_stats = self._retrieval_cache.get_stats()
|
||||
|
||||
# Calculate overall hit rate
|
||||
total_hits = hot_stats.get("hits", 0) + emb_stats.get("hits", 0)
|
||||
total_misses = hot_stats.get("misses", 0) + emb_stats.get("misses", 0)
|
||||
|
||||
if retrieval_stats:
|
||||
# Retrieval cache doesn't track hits/misses the same way
|
||||
pass
|
||||
|
||||
total_requests = total_hits + total_misses
|
||||
overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0
|
||||
|
||||
return CacheStats(
|
||||
hot_cache=hot_stats,
|
||||
embedding_cache=emb_stats,
|
||||
retrieval_cache=retrieval_stats,
|
||||
overall_hit_rate=overall_hit_rate,
|
||||
last_cleanup=self._last_cleanup,
|
||||
cleanup_count=self._cleanup_count,
|
||||
)
|
||||
|
||||
def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
|
||||
"""
|
||||
Get the most frequently accessed memories.
|
||||
|
||||
Args:
|
||||
limit: Maximum number to return
|
||||
|
||||
Returns:
|
||||
List of (key, access_count) tuples
|
||||
"""
|
||||
return self._hot_cache.get_hot_memories(limit)
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset all cache statistics."""
|
||||
self._hot_cache.reset_stats()
|
||||
self._embedding_cache.reset_stats()
|
||||
|
||||
# =========================================================================
|
||||
# Warmup
|
||||
# =========================================================================
|
||||
|
||||
async def warmup(
|
||||
self,
|
||||
memories: list[tuple[str, UUID | str, Any]],
|
||||
scope: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Warm up the hot cache with memories.
|
||||
|
||||
Args:
|
||||
memories: List of (memory_type, memory_id, memory) tuples
|
||||
scope: Optional scope for all memories
|
||||
|
||||
Returns:
|
||||
Number of memories cached
|
||||
"""
|
||||
if not self._enabled:
|
||||
return 0
|
||||
|
||||
for memory_type, memory_id, memory in memories:
|
||||
self._hot_cache.put_by_id(memory_type, memory_id, memory, scope)
|
||||
|
||||
logger.info(f"Warmed up cache with {len(memories)} memories")
|
||||
return len(memories)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_cache_manager: CacheManager | None = None
|
||||
_cache_manager_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_cache_manager(
|
||||
redis: "Redis | None" = None,
|
||||
reset: bool = False,
|
||||
) -> CacheManager:
|
||||
"""
|
||||
Get the global CacheManager instance.
|
||||
|
||||
Thread-safe with double-checked locking pattern.
|
||||
|
||||
Args:
|
||||
redis: Optional Redis connection
|
||||
reset: Force create a new instance
|
||||
|
||||
Returns:
|
||||
CacheManager instance
|
||||
"""
|
||||
global _cache_manager
|
||||
|
||||
if reset or _cache_manager is None:
|
||||
with _cache_manager_lock:
|
||||
if reset or _cache_manager is None:
|
||||
_cache_manager = CacheManager(redis=redis)
|
||||
|
||||
return _cache_manager
|
||||
|
||||
|
||||
def reset_cache_manager() -> None:
|
||||
"""Reset the global cache manager instance."""
|
||||
global _cache_manager
|
||||
with _cache_manager_lock:
|
||||
_cache_manager = None
|
||||
623
backend/app/services/memory/cache/embedding_cache.py
vendored
Normal file
623
backend/app/services/memory/cache/embedding_cache.py
vendored
Normal file
@@ -0,0 +1,623 @@
|
||||
# app/services/memory/cache/embedding_cache.py
|
||||
"""
|
||||
Embedding Cache.
|
||||
|
||||
Caches embeddings by content hash to avoid recomputing.
|
||||
Provides significant performance improvement for repeated content.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingEntry:
|
||||
"""A cached embedding entry."""
|
||||
|
||||
embedding: list[float]
|
||||
content_hash: str
|
||||
model: str
|
||||
created_at: datetime
|
||||
ttl_seconds: float = 3600.0 # 1 hour default
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this entry has expired."""
|
||||
age = (_utcnow() - self.created_at).total_seconds()
|
||||
return age > self.ttl_seconds
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingCacheStats:
|
||||
"""Statistics for the embedding cache."""
|
||||
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
evictions: int = 0
|
||||
expirations: int = 0
|
||||
current_size: int = 0
|
||||
max_size: int = 0
|
||||
bytes_saved: int = 0 # Estimated bytes saved by caching
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate."""
|
||||
total = self.hits + self.misses
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return self.hits / total
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"evictions": self.evictions,
|
||||
"expirations": self.expirations,
|
||||
"current_size": self.current_size,
|
||||
"max_size": self.max_size,
|
||||
"hit_rate": self.hit_rate,
|
||||
"bytes_saved": self.bytes_saved,
|
||||
}
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""
|
||||
Cache for embeddings by content hash.
|
||||
|
||||
Features:
|
||||
- Content-hash based deduplication
|
||||
- LRU eviction
|
||||
- TTL-based expiration
|
||||
- Optional Redis backing for persistence
|
||||
- Thread-safe operations
|
||||
|
||||
Performance targets:
|
||||
- Cache hit rate > 90% for repeated content
|
||||
- Get/put operations < 1ms (memory), < 5ms (Redis)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int = 50000,
|
||||
default_ttl_seconds: float = 3600.0,
|
||||
redis: "Redis | None" = None,
|
||||
redis_prefix: str = "mem:emb",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the embedding cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries in memory cache
|
||||
default_ttl_seconds: Default TTL for entries (1 hour)
|
||||
redis: Optional Redis connection for persistence
|
||||
redis_prefix: Prefix for Redis keys
|
||||
"""
|
||||
self._max_size = max_size
|
||||
self._default_ttl = default_ttl_seconds
|
||||
self._cache: OrderedDict[str, EmbeddingEntry] = OrderedDict()
|
||||
self._lock = threading.RLock()
|
||||
self._stats = EmbeddingCacheStats(max_size=max_size)
|
||||
self._redis = redis
|
||||
self._redis_prefix = redis_prefix
|
||||
|
||||
logger.info(
|
||||
f"Initialized EmbeddingCache with max_size={max_size}, "
|
||||
f"ttl={default_ttl_seconds}s, redis={'enabled' if redis else 'disabled'}"
|
||||
)
|
||||
|
||||
def set_redis(self, redis: "Redis") -> None:
|
||||
"""Set Redis connection for persistence."""
|
||||
self._redis = redis
|
||||
|
||||
@staticmethod
|
||||
def hash_content(content: str) -> str:
|
||||
"""
|
||||
Compute hash of content for cache key.
|
||||
|
||||
Args:
|
||||
content: Content to hash
|
||||
|
||||
Returns:
|
||||
32-character hex hash
|
||||
"""
|
||||
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||
|
||||
def _cache_key(self, content_hash: str, model: str) -> str:
|
||||
"""Build cache key from content hash and model."""
|
||||
return f"{content_hash}:{model}"
|
||||
|
||||
def _redis_key(self, content_hash: str, model: str) -> str:
|
||||
"""Build Redis key from content hash and model."""
|
||||
return f"{self._redis_prefix}:{content_hash}:{model}"
|
||||
|
||||
async def get(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "default",
|
||||
) -> list[float] | None:
|
||||
"""
|
||||
Get a cached embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Cached embedding or None if not found/expired
|
||||
"""
|
||||
content_hash = self.hash_content(content)
|
||||
cache_key = self._cache_key(content_hash, model)
|
||||
|
||||
# Check memory cache first
|
||||
with self._lock:
|
||||
if cache_key in self._cache:
|
||||
entry = self._cache[cache_key]
|
||||
if entry.is_expired():
|
||||
del self._cache[cache_key]
|
||||
self._stats.expirations += 1
|
||||
self._stats.current_size = len(self._cache)
|
||||
else:
|
||||
# Move to end (most recently used)
|
||||
self._cache.move_to_end(cache_key)
|
||||
self._stats.hits += 1
|
||||
return entry.embedding
|
||||
|
||||
# Check Redis if available
|
||||
if self._redis:
|
||||
try:
|
||||
redis_key = self._redis_key(content_hash, model)
|
||||
data = await self._redis.get(redis_key)
|
||||
if data:
|
||||
import json
|
||||
|
||||
embedding = json.loads(data)
|
||||
# Store in memory cache for faster access
|
||||
self._put_memory(content_hash, model, embedding)
|
||||
self._stats.hits += 1
|
||||
return embedding
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis get error: {e}")
|
||||
|
||||
self._stats.misses += 1
|
||||
return None
|
||||
|
||||
async def get_by_hash(
|
||||
self,
|
||||
content_hash: str,
|
||||
model: str = "default",
|
||||
) -> list[float] | None:
|
||||
"""
|
||||
Get a cached embedding by hash.
|
||||
|
||||
Args:
|
||||
content_hash: Content hash
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Cached embedding or None if not found/expired
|
||||
"""
|
||||
cache_key = self._cache_key(content_hash, model)
|
||||
|
||||
with self._lock:
|
||||
if cache_key in self._cache:
|
||||
entry = self._cache[cache_key]
|
||||
if entry.is_expired():
|
||||
del self._cache[cache_key]
|
||||
self._stats.expirations += 1
|
||||
self._stats.current_size = len(self._cache)
|
||||
else:
|
||||
self._cache.move_to_end(cache_key)
|
||||
self._stats.hits += 1
|
||||
return entry.embedding
|
||||
|
||||
# Check Redis
|
||||
if self._redis:
|
||||
try:
|
||||
redis_key = self._redis_key(content_hash, model)
|
||||
data = await self._redis.get(redis_key)
|
||||
if data:
|
||||
import json
|
||||
|
||||
embedding = json.loads(data)
|
||||
self._put_memory(content_hash, model, embedding)
|
||||
self._stats.hits += 1
|
||||
return embedding
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis get error: {e}")
|
||||
|
||||
self._stats.misses += 1
|
||||
return None
|
||||
|
||||
async def put(
|
||||
self,
|
||||
content: str,
|
||||
embedding: list[float],
|
||||
model: str = "default",
|
||||
ttl_seconds: float | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Cache an embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
embedding: Embedding vector
|
||||
model: Model name
|
||||
ttl_seconds: Optional TTL override
|
||||
|
||||
Returns:
|
||||
Content hash
|
||||
"""
|
||||
content_hash = self.hash_content(content)
|
||||
ttl = ttl_seconds or self._default_ttl
|
||||
|
||||
# Store in memory
|
||||
self._put_memory(content_hash, model, embedding, ttl)
|
||||
|
||||
# Store in Redis if available
|
||||
if self._redis:
|
||||
try:
|
||||
import json
|
||||
|
||||
redis_key = self._redis_key(content_hash, model)
|
||||
await self._redis.setex(
|
||||
redis_key,
|
||||
int(ttl),
|
||||
json.dumps(embedding),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis put error: {e}")
|
||||
|
||||
return content_hash
|
||||
|
||||
def _put_memory(
|
||||
self,
|
||||
content_hash: str,
|
||||
model: str,
|
||||
embedding: list[float],
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""Store in memory cache."""
|
||||
with self._lock:
|
||||
# Evict if at capacity
|
||||
self._evict_if_needed()
|
||||
|
||||
cache_key = self._cache_key(content_hash, model)
|
||||
entry = EmbeddingEntry(
|
||||
embedding=embedding,
|
||||
content_hash=content_hash,
|
||||
model=model,
|
||||
created_at=_utcnow(),
|
||||
ttl_seconds=ttl_seconds or self._default_ttl,
|
||||
)
|
||||
|
||||
self._cache[cache_key] = entry
|
||||
self._cache.move_to_end(cache_key)
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict entries if cache is at capacity."""
|
||||
while len(self._cache) >= self._max_size:
|
||||
if self._cache:
|
||||
self._cache.popitem(last=False)
|
||||
self._stats.evictions += 1
|
||||
|
||||
async def put_batch(
|
||||
self,
|
||||
items: list[tuple[str, list[float]]],
|
||||
model: str = "default",
|
||||
ttl_seconds: float | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Cache multiple embeddings.
|
||||
|
||||
Args:
|
||||
items: List of (content, embedding) tuples
|
||||
model: Model name
|
||||
ttl_seconds: Optional TTL override
|
||||
|
||||
Returns:
|
||||
List of content hashes
|
||||
"""
|
||||
hashes = []
|
||||
for content, embedding in items:
|
||||
content_hash = await self.put(content, embedding, model, ttl_seconds)
|
||||
hashes.append(content_hash)
|
||||
return hashes
|
||||
|
||||
async def invalidate(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "default",
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate a cached embedding.
|
||||
|
||||
Args:
|
||||
content: Content text
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
content_hash = self.hash_content(content)
|
||||
return await self.invalidate_by_hash(content_hash, model)
|
||||
|
||||
async def invalidate_by_hash(
|
||||
self,
|
||||
content_hash: str,
|
||||
model: str = "default",
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate a cached embedding by hash.
|
||||
|
||||
Args:
|
||||
content_hash: Content hash
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
cache_key = self._cache_key(content_hash, model)
|
||||
removed = False
|
||||
|
||||
with self._lock:
|
||||
if cache_key in self._cache:
|
||||
del self._cache[cache_key]
|
||||
self._stats.current_size = len(self._cache)
|
||||
removed = True
|
||||
|
||||
# Remove from Redis
|
||||
if self._redis:
|
||||
try:
|
||||
redis_key = self._redis_key(content_hash, model)
|
||||
await self._redis.delete(redis_key)
|
||||
removed = True
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis delete error: {e}")
|
||||
|
||||
return removed
|
||||
|
||||
async def invalidate_by_model(self, model: str) -> int:
|
||||
"""
|
||||
Invalidate all embeddings for a model.
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
count = 0
|
||||
|
||||
with self._lock:
|
||||
keys_to_remove = [k for k, v in self._cache.items() if v.model == model]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
count += 1
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
# Note: Redis pattern deletion would require SCAN which is expensive
|
||||
# For now, we only clear memory cache for model-based invalidation
|
||||
|
||||
return count
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""
|
||||
Clear all cache entries.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
with self._lock:
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._stats.current_size = 0
|
||||
|
||||
# Clear Redis entries
|
||||
if self._redis:
|
||||
try:
|
||||
pattern = f"{self._redis_prefix}:*"
|
||||
deleted = 0
|
||||
async for key in self._redis.scan_iter(match=pattern):
|
||||
await self._redis.delete(key)
|
||||
deleted += 1
|
||||
count = max(count, deleted)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis clear error: {e}")
|
||||
|
||||
logger.info(f"Cleared {count} entries from embedding cache")
|
||||
return count
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Remove all expired entries from memory cache.
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_remove = [k for k, v in self._cache.items() if v.is_expired()]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
self._stats.expirations += 1
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
if keys_to_remove:
|
||||
logger.debug(f"Cleaned up {len(keys_to_remove)} expired embeddings")
|
||||
|
||||
return len(keys_to_remove)
|
||||
|
||||
def get_stats(self) -> EmbeddingCacheStats:
|
||||
"""Get cache statistics."""
|
||||
with self._lock:
|
||||
self._stats.current_size = len(self._cache)
|
||||
return self._stats
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset cache statistics."""
|
||||
with self._lock:
|
||||
self._stats = EmbeddingCacheStats(
|
||||
max_size=self._max_size,
|
||||
current_size=len(self._cache),
|
||||
)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Get current cache size."""
|
||||
return len(self._cache)
|
||||
|
||||
@property
|
||||
def max_size(self) -> int:
|
||||
"""Get maximum cache size."""
|
||||
return self._max_size
|
||||
|
||||
|
||||
class CachedEmbeddingGenerator:
|
||||
"""
|
||||
Wrapper for embedding generators with caching.
|
||||
|
||||
Wraps an embedding generator to cache results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
generator: Any,
|
||||
cache: EmbeddingCache,
|
||||
model: str = "default",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the cached embedding generator.
|
||||
|
||||
Args:
|
||||
generator: Underlying embedding generator
|
||||
cache: Embedding cache
|
||||
model: Model name for cache keys
|
||||
"""
|
||||
self._generator = generator
|
||||
self._cache = cache
|
||||
self._model = model
|
||||
self._call_count = 0
|
||||
self._cache_hit_count = 0
|
||||
|
||||
async def generate(self, text: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding with caching.
|
||||
|
||||
Args:
|
||||
text: Text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector
|
||||
"""
|
||||
self._call_count += 1
|
||||
|
||||
# Check cache first
|
||||
cached = await self._cache.get(text, self._model)
|
||||
if cached is not None:
|
||||
self._cache_hit_count += 1
|
||||
return cached
|
||||
|
||||
# Generate and cache
|
||||
embedding = await self._generator.generate(text)
|
||||
await self._cache.put(text, embedding, self._model)
|
||||
|
||||
return embedding
|
||||
|
||||
async def generate_batch(
|
||||
self,
|
||||
texts: list[str],
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Generate embeddings for multiple texts with caching.
|
||||
|
||||
Args:
|
||||
texts: Texts to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
"""
|
||||
results: list[list[float] | None] = [None] * len(texts)
|
||||
to_generate: list[tuple[int, str]] = []
|
||||
|
||||
# Check cache for each text
|
||||
for i, text in enumerate(texts):
|
||||
cached = await self._cache.get(text, self._model)
|
||||
if cached is not None:
|
||||
results[i] = cached
|
||||
self._cache_hit_count += 1
|
||||
else:
|
||||
to_generate.append((i, text))
|
||||
|
||||
self._call_count += len(texts)
|
||||
|
||||
# Generate missing embeddings
|
||||
if to_generate:
|
||||
if hasattr(self._generator, "generate_batch"):
|
||||
texts_to_gen = [t for _, t in to_generate]
|
||||
embeddings = await self._generator.generate_batch(texts_to_gen)
|
||||
|
||||
for (idx, text), embedding in zip(to_generate, embeddings, strict=True):
|
||||
results[idx] = embedding
|
||||
await self._cache.put(text, embedding, self._model)
|
||||
else:
|
||||
# Fallback to individual generation
|
||||
for idx, text in to_generate:
|
||||
embedding = await self._generator.generate(text)
|
||||
results[idx] = embedding
|
||||
await self._cache.put(text, embedding, self._model)
|
||||
|
||||
return results # type: ignore[return-value]
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get generator statistics."""
|
||||
return {
|
||||
"call_count": self._call_count,
|
||||
"cache_hit_count": self._cache_hit_count,
|
||||
"cache_hit_rate": (
|
||||
self._cache_hit_count / self._call_count
|
||||
if self._call_count > 0
|
||||
else 0.0
|
||||
),
|
||||
"cache_stats": self._cache.get_stats().to_dict(),
|
||||
}
|
||||
|
||||
|
||||
# Factory function
|
||||
def create_embedding_cache(
|
||||
max_size: int = 50000,
|
||||
default_ttl_seconds: float = 3600.0,
|
||||
redis: "Redis | None" = None,
|
||||
) -> EmbeddingCache:
|
||||
"""
|
||||
Create an embedding cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries
|
||||
default_ttl_seconds: Default TTL for entries
|
||||
redis: Optional Redis connection
|
||||
|
||||
Returns:
|
||||
Configured EmbeddingCache instance
|
||||
"""
|
||||
return EmbeddingCache(
|
||||
max_size=max_size,
|
||||
default_ttl_seconds=default_ttl_seconds,
|
||||
redis=redis,
|
||||
)
|
||||
461
backend/app/services/memory/cache/hot_cache.py
vendored
Normal file
461
backend/app/services/memory/cache/hot_cache.py
vendored
Normal file
@@ -0,0 +1,461 @@
|
||||
# app/services/memory/cache/hot_cache.py
|
||||
"""
|
||||
Hot Memory Cache.
|
||||
|
||||
LRU cache for frequently accessed memories.
|
||||
Provides fast access to recently used memories without database queries.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry[T]:
|
||||
"""A cached memory entry with metadata."""
|
||||
|
||||
value: T
|
||||
created_at: datetime
|
||||
last_accessed_at: datetime
|
||||
access_count: int = 1
|
||||
ttl_seconds: float = 300.0
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this entry has expired."""
|
||||
age = (_utcnow() - self.created_at).total_seconds()
|
||||
return age > self.ttl_seconds
|
||||
|
||||
def touch(self) -> None:
|
||||
"""Update access time and count."""
|
||||
self.last_accessed_at = _utcnow()
|
||||
self.access_count += 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheKey:
|
||||
"""A structured cache key with components."""
|
||||
|
||||
memory_type: str
|
||||
memory_id: str
|
||||
scope: str | None = None
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.memory_type, self.memory_id, self.scope))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, CacheKey):
|
||||
return False
|
||||
return (
|
||||
self.memory_type == other.memory_type
|
||||
and self.memory_id == other.memory_id
|
||||
and self.scope == other.scope
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.scope:
|
||||
return f"{self.memory_type}:{self.scope}:{self.memory_id}"
|
||||
return f"{self.memory_type}:{self.memory_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HotCacheStats:
|
||||
"""Statistics for the hot memory cache."""
|
||||
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
evictions: int = 0
|
||||
expirations: int = 0
|
||||
current_size: int = 0
|
||||
max_size: int = 0
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate."""
|
||||
total = self.hits + self.misses
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return self.hits / total
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"evictions": self.evictions,
|
||||
"expirations": self.expirations,
|
||||
"current_size": self.current_size,
|
||||
"max_size": self.max_size,
|
||||
"hit_rate": self.hit_rate,
|
||||
}
|
||||
|
||||
|
||||
class HotMemoryCache[T]:
|
||||
"""
|
||||
LRU cache for frequently accessed memories.
|
||||
|
||||
Features:
|
||||
- LRU eviction when capacity is reached
|
||||
- TTL-based expiration
|
||||
- Access count tracking for hot memory identification
|
||||
- Thread-safe operations
|
||||
- Scoped invalidation
|
||||
|
||||
Performance targets:
|
||||
- Cache hit rate > 80% for hot memories
|
||||
- Get/put operations < 1ms
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int = 10000,
|
||||
default_ttl_seconds: float = 300.0,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the hot memory cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries
|
||||
default_ttl_seconds: Default TTL for entries (5 minutes)
|
||||
"""
|
||||
self._max_size = max_size
|
||||
self._default_ttl = default_ttl_seconds
|
||||
self._cache: OrderedDict[CacheKey, CacheEntry[T]] = OrderedDict()
|
||||
self._lock = threading.RLock()
|
||||
self._stats = HotCacheStats(max_size=max_size)
|
||||
logger.info(
|
||||
f"Initialized HotMemoryCache with max_size={max_size}, "
|
||||
f"ttl={default_ttl_seconds}s"
|
||||
)
|
||||
|
||||
def get(self, key: CacheKey) -> T | None:
|
||||
"""
|
||||
Get a memory from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found/expired
|
||||
"""
|
||||
with self._lock:
|
||||
if key not in self._cache:
|
||||
self._stats.misses += 1
|
||||
return None
|
||||
|
||||
entry = self._cache[key]
|
||||
|
||||
# Check expiration
|
||||
if entry.is_expired():
|
||||
del self._cache[key]
|
||||
self._stats.expirations += 1
|
||||
self._stats.misses += 1
|
||||
self._stats.current_size = len(self._cache)
|
||||
return None
|
||||
|
||||
# Move to end (most recently used)
|
||||
self._cache.move_to_end(key)
|
||||
entry.touch()
|
||||
|
||||
self._stats.hits += 1
|
||||
return entry.value
|
||||
|
||||
def get_by_id(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
scope: str | None = None,
|
||||
) -> T | None:
|
||||
"""
|
||||
Get a memory by type and ID.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory (episodic, semantic, procedural)
|
||||
memory_id: Memory ID
|
||||
scope: Optional scope (project_id, agent_id)
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found/expired
|
||||
"""
|
||||
key = CacheKey(
|
||||
memory_type=memory_type,
|
||||
memory_id=str(memory_id),
|
||||
scope=scope,
|
||||
)
|
||||
return self.get(key)
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: CacheKey,
|
||||
value: T,
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Put a memory into cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl_seconds: Optional TTL override
|
||||
"""
|
||||
with self._lock:
|
||||
# Evict if at capacity
|
||||
self._evict_if_needed()
|
||||
|
||||
now = _utcnow()
|
||||
entry = CacheEntry(
|
||||
value=value,
|
||||
created_at=now,
|
||||
last_accessed_at=now,
|
||||
access_count=1,
|
||||
ttl_seconds=ttl_seconds or self._default_ttl,
|
||||
)
|
||||
|
||||
self._cache[key] = entry
|
||||
self._cache.move_to_end(key)
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
def put_by_id(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
value: T,
|
||||
scope: str | None = None,
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Put a memory by type and ID.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
value: Value to cache
|
||||
scope: Optional scope
|
||||
ttl_seconds: Optional TTL override
|
||||
"""
|
||||
key = CacheKey(
|
||||
memory_type=memory_type,
|
||||
memory_id=str(memory_id),
|
||||
scope=scope,
|
||||
)
|
||||
self.put(key, value, ttl_seconds)
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict entries if cache is at capacity."""
|
||||
while len(self._cache) >= self._max_size:
|
||||
# Remove least recently used (first item)
|
||||
if self._cache:
|
||||
self._cache.popitem(last=False)
|
||||
self._stats.evictions += 1
|
||||
|
||||
def invalidate(self, key: CacheKey) -> bool:
|
||||
"""
|
||||
Invalidate a specific cache entry.
|
||||
|
||||
Args:
|
||||
key: Cache key to invalidate
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
del self._cache[key]
|
||||
self._stats.current_size = len(self._cache)
|
||||
return True
|
||||
return False
|
||||
|
||||
def invalidate_by_id(
|
||||
self,
|
||||
memory_type: str,
|
||||
memory_id: UUID | str,
|
||||
scope: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Invalidate a memory by type and ID.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
memory_id: Memory ID
|
||||
scope: Optional scope
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
key = CacheKey(
|
||||
memory_type=memory_type,
|
||||
memory_id=str(memory_id),
|
||||
scope=scope,
|
||||
)
|
||||
return self.invalidate(key)
|
||||
|
||||
def invalidate_by_type(self, memory_type: str) -> int:
|
||||
"""
|
||||
Invalidate all entries of a memory type.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory to invalidate
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_remove = [
|
||||
k for k in self._cache.keys() if k.memory_type == memory_type
|
||||
]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def invalidate_by_scope(self, scope: str) -> int:
|
||||
"""
|
||||
Invalidate all entries in a scope.
|
||||
|
||||
Args:
|
||||
scope: Scope to invalidate (e.g., project_id)
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_remove = [k for k in self._cache.keys() if k.scope == scope]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def invalidate_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
Invalidate entries matching a pattern.
|
||||
|
||||
Pattern can include * as wildcard.
|
||||
|
||||
Args:
|
||||
pattern: Pattern to match (e.g., "episodic:*")
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
import fnmatch
|
||||
|
||||
with self._lock:
|
||||
keys_to_remove = [
|
||||
k for k in self._cache.keys() if fnmatch.fnmatch(str(k), pattern)
|
||||
]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def clear(self) -> int:
|
||||
"""
|
||||
Clear all cache entries.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
with self._lock:
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._stats.current_size = 0
|
||||
logger.info(f"Cleared {count} entries from hot cache")
|
||||
return count
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""
|
||||
Remove all expired entries.
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
with self._lock:
|
||||
keys_to_remove = [k for k, v in self._cache.items() if v.is_expired()]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
self._stats.expirations += 1
|
||||
|
||||
self._stats.current_size = len(self._cache)
|
||||
|
||||
if keys_to_remove:
|
||||
logger.debug(f"Cleaned up {len(keys_to_remove)} expired entries")
|
||||
|
||||
return len(keys_to_remove)
|
||||
|
||||
def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
|
||||
"""
|
||||
Get the most frequently accessed memories.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of memories to return
|
||||
|
||||
Returns:
|
||||
List of (key, access_count) tuples sorted by access count
|
||||
"""
|
||||
with self._lock:
|
||||
entries = [
|
||||
(k, v.access_count)
|
||||
for k, v in self._cache.items()
|
||||
if not v.is_expired()
|
||||
]
|
||||
entries.sort(key=lambda x: x[1], reverse=True)
|
||||
return entries[:limit]
|
||||
|
||||
def get_stats(self) -> HotCacheStats:
|
||||
"""Get cache statistics."""
|
||||
with self._lock:
|
||||
self._stats.current_size = len(self._cache)
|
||||
return self._stats
|
||||
|
||||
def reset_stats(self) -> None:
|
||||
"""Reset cache statistics."""
|
||||
with self._lock:
|
||||
self._stats = HotCacheStats(
|
||||
max_size=self._max_size,
|
||||
current_size=len(self._cache),
|
||||
)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Get current cache size."""
|
||||
return len(self._cache)
|
||||
|
||||
@property
|
||||
def max_size(self) -> int:
|
||||
"""Get maximum cache size."""
|
||||
return self._max_size
|
||||
|
||||
|
||||
# Factory function for typed caches
|
||||
def create_hot_cache(
|
||||
max_size: int = 10000,
|
||||
default_ttl_seconds: float = 300.0,
|
||||
) -> HotMemoryCache[Any]:
|
||||
"""
|
||||
Create a hot memory cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of entries
|
||||
default_ttl_seconds: Default TTL for entries
|
||||
|
||||
Returns:
|
||||
Configured HotMemoryCache instance
|
||||
"""
|
||||
return HotMemoryCache(
|
||||
max_size=max_size,
|
||||
default_ttl_seconds=default_ttl_seconds,
|
||||
)
|
||||
410
backend/app/services/memory/config.py
Normal file
410
backend/app/services/memory/config.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Memory System Configuration.
|
||||
|
||||
Provides Pydantic settings for the Agent Memory System,
|
||||
including storage backends, capacity limits, and consolidation policies.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class MemorySettings(BaseSettings):
|
||||
"""
|
||||
Configuration for the Agent Memory System.
|
||||
|
||||
All settings can be overridden via environment variables
|
||||
with the MEM_ prefix.
|
||||
"""
|
||||
|
||||
# Working Memory Settings
|
||||
working_memory_backend: str = Field(
|
||||
default="redis",
|
||||
description="Backend for working memory: 'redis' or 'memory'",
|
||||
)
|
||||
working_memory_default_ttl_seconds: int = Field(
|
||||
default=3600,
|
||||
ge=60,
|
||||
le=86400,
|
||||
description="Default TTL for working memory items (1 hour default)",
|
||||
)
|
||||
working_memory_max_items_per_session: int = Field(
|
||||
default=1000,
|
||||
ge=100,
|
||||
le=100000,
|
||||
description="Maximum items per session in working memory",
|
||||
)
|
||||
working_memory_max_value_size_bytes: int = Field(
|
||||
default=1048576, # 1MB
|
||||
ge=1024,
|
||||
le=104857600, # 100MB
|
||||
description="Maximum size of a single value in working memory",
|
||||
)
|
||||
working_memory_checkpoint_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable checkpointing for working memory recovery",
|
||||
)
|
||||
|
||||
# Redis Settings (for working memory)
|
||||
redis_url: str = Field(
|
||||
default="redis://localhost:6379/0",
|
||||
description="Redis connection URL",
|
||||
)
|
||||
redis_prefix: str = Field(
|
||||
default="mem",
|
||||
description="Redis key prefix for memory items",
|
||||
)
|
||||
redis_connection_timeout_seconds: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=60,
|
||||
description="Redis connection timeout",
|
||||
)
|
||||
|
||||
# Episodic Memory Settings
|
||||
episodic_max_episodes_per_project: int = Field(
|
||||
default=10000,
|
||||
ge=100,
|
||||
le=1000000,
|
||||
description="Maximum episodes to retain per project",
|
||||
)
|
||||
episodic_default_importance: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Default importance score for new episodes",
|
||||
)
|
||||
episodic_retention_days: int = Field(
|
||||
default=365,
|
||||
ge=7,
|
||||
le=3650,
|
||||
description="Days to retain episodes before archival",
|
||||
)
|
||||
|
||||
# Semantic Memory Settings
|
||||
semantic_max_facts_per_project: int = Field(
|
||||
default=50000,
|
||||
ge=1000,
|
||||
le=10000000,
|
||||
description="Maximum facts to retain per project",
|
||||
)
|
||||
semantic_confidence_decay_days: int = Field(
|
||||
default=90,
|
||||
ge=7,
|
||||
le=365,
|
||||
description="Days until confidence decays by 50%",
|
||||
)
|
||||
semantic_min_confidence: float = Field(
|
||||
default=0.1,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum confidence before fact is pruned",
|
||||
)
|
||||
|
||||
# Procedural Memory Settings
|
||||
procedural_max_procedures_per_project: int = Field(
|
||||
default=1000,
|
||||
ge=10,
|
||||
le=100000,
|
||||
description="Maximum procedures per project",
|
||||
)
|
||||
procedural_min_success_rate: float = Field(
|
||||
default=0.3,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum success rate before procedure is pruned",
|
||||
)
|
||||
procedural_min_uses_before_suggest: int = Field(
|
||||
default=3,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Minimum uses before procedure is suggested",
|
||||
)
|
||||
|
||||
# Embedding Settings
|
||||
embedding_model: str = Field(
|
||||
default="text-embedding-3-small",
|
||||
description="Model to use for embeddings",
|
||||
)
|
||||
embedding_dimensions: int = Field(
|
||||
default=1536,
|
||||
ge=256,
|
||||
le=4096,
|
||||
description="Embedding vector dimensions",
|
||||
)
|
||||
embedding_batch_size: int = Field(
|
||||
default=100,
|
||||
ge=1,
|
||||
le=1000,
|
||||
description="Batch size for embedding generation",
|
||||
)
|
||||
embedding_cache_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable caching of embeddings",
|
||||
)
|
||||
|
||||
# Retrieval Settings
|
||||
retrieval_default_limit: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Default limit for retrieval queries",
|
||||
)
|
||||
retrieval_max_limit: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=1000,
|
||||
description="Maximum limit for retrieval queries",
|
||||
)
|
||||
retrieval_min_similarity: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum similarity score for retrieval",
|
||||
)
|
||||
|
||||
# Consolidation Settings
|
||||
consolidation_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable automatic memory consolidation",
|
||||
)
|
||||
consolidation_batch_size: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=1000,
|
||||
description="Batch size for consolidation jobs",
|
||||
)
|
||||
consolidation_schedule_cron: str = Field(
|
||||
default="0 3 * * *",
|
||||
description="Cron expression for nightly consolidation (3 AM)",
|
||||
)
|
||||
consolidation_working_to_episodic_delay_minutes: int = Field(
|
||||
default=30,
|
||||
ge=5,
|
||||
le=1440,
|
||||
description="Minutes after session end before consolidating to episodic",
|
||||
)
|
||||
|
||||
# Pruning Settings
|
||||
pruning_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable automatic memory pruning",
|
||||
)
|
||||
pruning_min_age_days: int = Field(
|
||||
default=7,
|
||||
ge=1,
|
||||
le=365,
|
||||
description="Minimum age before memory can be pruned",
|
||||
)
|
||||
pruning_importance_threshold: float = Field(
|
||||
default=0.2,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Importance threshold below which memory can be pruned",
|
||||
)
|
||||
|
||||
# Caching Settings
|
||||
cache_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable caching for memory retrieval",
|
||||
)
|
||||
cache_ttl_seconds: int = Field(
|
||||
default=300,
|
||||
ge=10,
|
||||
le=3600,
|
||||
description="Cache TTL for retrieval results",
|
||||
)
|
||||
cache_max_items: int = Field(
|
||||
default=10000,
|
||||
ge=100,
|
||||
le=1000000,
|
||||
description="Maximum items in memory cache",
|
||||
)
|
||||
|
||||
# Performance Settings
|
||||
max_retrieval_time_ms: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=5000,
|
||||
description="Target maximum retrieval time in milliseconds",
|
||||
)
|
||||
parallel_retrieval: bool = Field(
|
||||
default=True,
|
||||
description="Enable parallel retrieval from multiple memory types",
|
||||
)
|
||||
max_parallel_retrievals: int = Field(
|
||||
default=4,
|
||||
ge=1,
|
||||
le=10,
|
||||
description="Maximum concurrent retrieval operations",
|
||||
)
|
||||
|
||||
@field_validator("working_memory_backend")
|
||||
@classmethod
|
||||
def validate_backend(cls, v: str) -> str:
|
||||
"""Validate working memory backend."""
|
||||
valid_backends = {"redis", "memory"}
|
||||
if v not in valid_backends:
|
||||
raise ValueError(f"backend must be one of: {valid_backends}")
|
||||
return v
|
||||
|
||||
@field_validator("embedding_model")
|
||||
@classmethod
|
||||
def validate_embedding_model(cls, v: str) -> str:
|
||||
"""Validate embedding model name."""
|
||||
valid_models = {
|
||||
"text-embedding-3-small",
|
||||
"text-embedding-3-large",
|
||||
"text-embedding-ada-002",
|
||||
}
|
||||
if v not in valid_models:
|
||||
raise ValueError(f"embedding_model must be one of: {valid_models}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_limits(self) -> "MemorySettings":
|
||||
"""Validate that limits are consistent."""
|
||||
if self.retrieval_default_limit > self.retrieval_max_limit:
|
||||
raise ValueError(
|
||||
f"retrieval_default_limit ({self.retrieval_default_limit}) "
|
||||
f"cannot exceed retrieval_max_limit ({self.retrieval_max_limit})"
|
||||
)
|
||||
return self
|
||||
|
||||
def get_working_memory_config(self) -> dict[str, Any]:
|
||||
"""Get working memory configuration as a dictionary."""
|
||||
return {
|
||||
"backend": self.working_memory_backend,
|
||||
"default_ttl_seconds": self.working_memory_default_ttl_seconds,
|
||||
"max_items_per_session": self.working_memory_max_items_per_session,
|
||||
"max_value_size_bytes": self.working_memory_max_value_size_bytes,
|
||||
"checkpoint_enabled": self.working_memory_checkpoint_enabled,
|
||||
}
|
||||
|
||||
def get_redis_config(self) -> dict[str, Any]:
|
||||
"""Get Redis configuration as a dictionary."""
|
||||
return {
|
||||
"url": self.redis_url,
|
||||
"prefix": self.redis_prefix,
|
||||
"connection_timeout_seconds": self.redis_connection_timeout_seconds,
|
||||
}
|
||||
|
||||
def get_embedding_config(self) -> dict[str, Any]:
|
||||
"""Get embedding configuration as a dictionary."""
|
||||
return {
|
||||
"model": self.embedding_model,
|
||||
"dimensions": self.embedding_dimensions,
|
||||
"batch_size": self.embedding_batch_size,
|
||||
"cache_enabled": self.embedding_cache_enabled,
|
||||
}
|
||||
|
||||
def get_consolidation_config(self) -> dict[str, Any]:
|
||||
"""Get consolidation configuration as a dictionary."""
|
||||
return {
|
||||
"enabled": self.consolidation_enabled,
|
||||
"batch_size": self.consolidation_batch_size,
|
||||
"schedule_cron": self.consolidation_schedule_cron,
|
||||
"working_to_episodic_delay_minutes": (
|
||||
self.consolidation_working_to_episodic_delay_minutes
|
||||
),
|
||||
}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert settings to dictionary for logging/debugging."""
|
||||
return {
|
||||
"working_memory": self.get_working_memory_config(),
|
||||
"redis": self.get_redis_config(),
|
||||
"episodic": {
|
||||
"max_episodes_per_project": self.episodic_max_episodes_per_project,
|
||||
"default_importance": self.episodic_default_importance,
|
||||
"retention_days": self.episodic_retention_days,
|
||||
},
|
||||
"semantic": {
|
||||
"max_facts_per_project": self.semantic_max_facts_per_project,
|
||||
"confidence_decay_days": self.semantic_confidence_decay_days,
|
||||
"min_confidence": self.semantic_min_confidence,
|
||||
},
|
||||
"procedural": {
|
||||
"max_procedures_per_project": self.procedural_max_procedures_per_project,
|
||||
"min_success_rate": self.procedural_min_success_rate,
|
||||
"min_uses_before_suggest": self.procedural_min_uses_before_suggest,
|
||||
},
|
||||
"embedding": self.get_embedding_config(),
|
||||
"retrieval": {
|
||||
"default_limit": self.retrieval_default_limit,
|
||||
"max_limit": self.retrieval_max_limit,
|
||||
"min_similarity": self.retrieval_min_similarity,
|
||||
},
|
||||
"consolidation": self.get_consolidation_config(),
|
||||
"pruning": {
|
||||
"enabled": self.pruning_enabled,
|
||||
"min_age_days": self.pruning_min_age_days,
|
||||
"importance_threshold": self.pruning_importance_threshold,
|
||||
},
|
||||
"cache": {
|
||||
"enabled": self.cache_enabled,
|
||||
"ttl_seconds": self.cache_ttl_seconds,
|
||||
"max_items": self.cache_max_items,
|
||||
},
|
||||
"performance": {
|
||||
"max_retrieval_time_ms": self.max_retrieval_time_ms,
|
||||
"parallel_retrieval": self.parallel_retrieval,
|
||||
"max_parallel_retrievals": self.max_parallel_retrievals,
|
||||
},
|
||||
}
|
||||
|
||||
model_config = {
|
||||
"env_prefix": "MEM_",
|
||||
"env_file": ".env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": False,
|
||||
"extra": "ignore",
|
||||
}
|
||||
|
||||
|
||||
# Thread-safe singleton pattern
|
||||
_settings: MemorySettings | None = None
|
||||
_settings_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_memory_settings() -> MemorySettings:
|
||||
"""
|
||||
Get the global MemorySettings instance.
|
||||
|
||||
Thread-safe with double-checked locking pattern.
|
||||
|
||||
Returns:
|
||||
MemorySettings instance
|
||||
"""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
with _settings_lock:
|
||||
if _settings is None:
|
||||
_settings = MemorySettings()
|
||||
return _settings
|
||||
|
||||
|
||||
def reset_memory_settings() -> None:
|
||||
"""
|
||||
Reset the global settings instance.
|
||||
|
||||
Primarily used for testing.
|
||||
"""
|
||||
global _settings
|
||||
with _settings_lock:
|
||||
_settings = None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_default_settings() -> MemorySettings:
|
||||
"""
|
||||
Get default settings (cached).
|
||||
|
||||
Use this for read-only access to defaults.
|
||||
For mutable access, use get_memory_settings().
|
||||
"""
|
||||
return MemorySettings()
|
||||
29
backend/app/services/memory/consolidation/__init__.py
Normal file
29
backend/app/services/memory/consolidation/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# app/services/memory/consolidation/__init__.py
|
||||
"""
|
||||
Memory Consolidation.
|
||||
|
||||
Transfers and extracts knowledge between memory tiers:
|
||||
- Working -> Episodic (session end)
|
||||
- Episodic -> Semantic (learn facts)
|
||||
- Episodic -> Procedural (learn procedures)
|
||||
|
||||
Also handles memory pruning and importance-based retention.
|
||||
"""
|
||||
|
||||
from .service import (
|
||||
ConsolidationConfig,
|
||||
ConsolidationResult,
|
||||
MemoryConsolidationService,
|
||||
NightlyConsolidationResult,
|
||||
SessionConsolidationResult,
|
||||
get_consolidation_service,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConsolidationConfig",
|
||||
"ConsolidationResult",
|
||||
"MemoryConsolidationService",
|
||||
"NightlyConsolidationResult",
|
||||
"SessionConsolidationResult",
|
||||
"get_consolidation_service",
|
||||
]
|
||||
913
backend/app/services/memory/consolidation/service.py
Normal file
913
backend/app/services/memory/consolidation/service.py
Normal file
@@ -0,0 +1,913 @@
|
||||
# app/services/memory/consolidation/service.py
|
||||
"""
|
||||
Memory Consolidation Service.
|
||||
|
||||
Transfers and extracts knowledge between memory tiers:
|
||||
- Working -> Episodic (session end)
|
||||
- Episodic -> Semantic (learn facts)
|
||||
- Episodic -> Procedural (learn procedures)
|
||||
|
||||
Also handles memory pruning and importance-based retention.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.memory.episodic.memory import EpisodicMemory
|
||||
from app.services.memory.procedural.memory import ProceduralMemory
|
||||
from app.services.memory.semantic.extraction import FactExtractor, get_fact_extractor
|
||||
from app.services.memory.semantic.memory import SemanticMemory
|
||||
from app.services.memory.types import (
|
||||
Episode,
|
||||
EpisodeCreate,
|
||||
Outcome,
|
||||
ProcedureCreate,
|
||||
TaskState,
|
||||
)
|
||||
from app.services.memory.working.memory import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsolidationConfig:
|
||||
"""Configuration for memory consolidation."""
|
||||
|
||||
# Working -> Episodic thresholds
|
||||
min_steps_for_episode: int = 2
|
||||
min_duration_seconds: float = 5.0
|
||||
|
||||
# Episodic -> Semantic thresholds
|
||||
min_confidence_for_fact: float = 0.6
|
||||
max_facts_per_episode: int = 10
|
||||
reinforce_existing_facts: bool = True
|
||||
|
||||
# Episodic -> Procedural thresholds
|
||||
min_episodes_for_procedure: int = 3
|
||||
min_success_rate_for_procedure: float = 0.7
|
||||
min_steps_for_procedure: int = 2
|
||||
|
||||
# Pruning thresholds
|
||||
max_episode_age_days: int = 90
|
||||
min_importance_to_keep: float = 0.2
|
||||
keep_all_failures: bool = True
|
||||
keep_all_with_lessons: bool = True
|
||||
|
||||
# Batch sizes
|
||||
batch_size: int = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsolidationResult:
|
||||
"""Result of a consolidation operation."""
|
||||
|
||||
source_type: str
|
||||
target_type: str
|
||||
items_processed: int = 0
|
||||
items_created: int = 0
|
||||
items_updated: int = 0
|
||||
items_skipped: int = 0
|
||||
items_pruned: int = 0
|
||||
errors: list[str] = field(default_factory=list)
|
||||
duration_seconds: float = 0.0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"source_type": self.source_type,
|
||||
"target_type": self.target_type,
|
||||
"items_processed": self.items_processed,
|
||||
"items_created": self.items_created,
|
||||
"items_updated": self.items_updated,
|
||||
"items_skipped": self.items_skipped,
|
||||
"items_pruned": self.items_pruned,
|
||||
"errors": self.errors,
|
||||
"duration_seconds": self.duration_seconds,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionConsolidationResult:
|
||||
"""Result of consolidating a session's working memory to episodic."""
|
||||
|
||||
session_id: str
|
||||
episode_created: bool = False
|
||||
episode_id: UUID | None = None
|
||||
scratchpad_entries: int = 0
|
||||
variables_captured: int = 0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NightlyConsolidationResult:
|
||||
"""Result of nightly consolidation run."""
|
||||
|
||||
started_at: datetime
|
||||
completed_at: datetime | None = None
|
||||
episodic_to_semantic: ConsolidationResult | None = None
|
||||
episodic_to_procedural: ConsolidationResult | None = None
|
||||
pruning: ConsolidationResult | None = None
|
||||
total_episodes_processed: int = 0
|
||||
total_facts_created: int = 0
|
||||
total_procedures_created: int = 0
|
||||
total_pruned: int = 0
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"started_at": self.started_at.isoformat(),
|
||||
"completed_at": self.completed_at.isoformat()
|
||||
if self.completed_at
|
||||
else None,
|
||||
"episodic_to_semantic": (
|
||||
self.episodic_to_semantic.to_dict()
|
||||
if self.episodic_to_semantic
|
||||
else None
|
||||
),
|
||||
"episodic_to_procedural": (
|
||||
self.episodic_to_procedural.to_dict()
|
||||
if self.episodic_to_procedural
|
||||
else None
|
||||
),
|
||||
"pruning": self.pruning.to_dict() if self.pruning else None,
|
||||
"total_episodes_processed": self.total_episodes_processed,
|
||||
"total_facts_created": self.total_facts_created,
|
||||
"total_procedures_created": self.total_procedures_created,
|
||||
"total_pruned": self.total_pruned,
|
||||
"errors": self.errors,
|
||||
}
|
||||
|
||||
|
||||
class MemoryConsolidationService:
|
||||
"""
|
||||
Service for consolidating memories between tiers.
|
||||
|
||||
Responsibilities:
|
||||
- Transfer working memory to episodic at session end
|
||||
- Extract facts from episodes to semantic memory
|
||||
- Learn procedures from successful episode patterns
|
||||
- Prune old/low-value memories
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
config: ConsolidationConfig | None = None,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize consolidation service.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
config: Consolidation configuration
|
||||
embedding_generator: Optional embedding generator
|
||||
"""
|
||||
self._session = session
|
||||
self._config = config or ConsolidationConfig()
|
||||
self._embedding_generator = embedding_generator
|
||||
self._fact_extractor: FactExtractor = get_fact_extractor()
|
||||
|
||||
# Memory services (lazy initialized)
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
self._semantic: SemanticMemory | None = None
|
||||
self._procedural: ProceduralMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
async def _get_semantic(self) -> SemanticMemory:
|
||||
"""Get or create semantic memory service."""
|
||||
if self._semantic is None:
|
||||
self._semantic = await SemanticMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._semantic
|
||||
|
||||
async def _get_procedural(self) -> ProceduralMemory:
|
||||
"""Get or create procedural memory service."""
|
||||
if self._procedural is None:
|
||||
self._procedural = await ProceduralMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._procedural
|
||||
|
||||
# =========================================================================
|
||||
# Working -> Episodic Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_session(
|
||||
self,
|
||||
working_memory: WorkingMemory,
|
||||
project_id: UUID,
|
||||
session_id: str,
|
||||
task_type: str = "session_task",
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> SessionConsolidationResult:
|
||||
"""
|
||||
Consolidate a session's working memory to episodic memory.
|
||||
|
||||
Called at session end to transfer relevant session data
|
||||
into a persistent episode.
|
||||
|
||||
Args:
|
||||
working_memory: The session's working memory
|
||||
project_id: Project ID
|
||||
session_id: Session ID
|
||||
task_type: Type of task performed
|
||||
agent_instance_id: Optional agent instance
|
||||
agent_type_id: Optional agent type
|
||||
|
||||
Returns:
|
||||
SessionConsolidationResult with outcome details
|
||||
"""
|
||||
result = SessionConsolidationResult(session_id=session_id)
|
||||
|
||||
try:
|
||||
# Get task state
|
||||
task_state = await working_memory.get_task_state()
|
||||
|
||||
# Check if there's enough content to consolidate
|
||||
if not self._should_consolidate_session(task_state):
|
||||
logger.debug(
|
||||
f"Skipping consolidation for session {session_id}: insufficient content"
|
||||
)
|
||||
return result
|
||||
|
||||
# Gather scratchpad entries
|
||||
scratchpad = await working_memory.get_scratchpad()
|
||||
result.scratchpad_entries = len(scratchpad)
|
||||
|
||||
# Gather user variables
|
||||
all_data = await working_memory.get_all()
|
||||
result.variables_captured = len(all_data)
|
||||
|
||||
# Determine outcome
|
||||
outcome = self._determine_session_outcome(task_state)
|
||||
|
||||
# Build actions from scratchpad and variables
|
||||
actions = self._build_actions_from_session(scratchpad, all_data, task_state)
|
||||
|
||||
# Build context summary
|
||||
context_summary = self._build_context_summary(task_state, all_data)
|
||||
|
||||
# Extract lessons learned
|
||||
lessons = self._extract_session_lessons(task_state, outcome)
|
||||
|
||||
# Calculate importance
|
||||
importance = self._calculate_session_importance(
|
||||
task_state, outcome, actions
|
||||
)
|
||||
|
||||
# Create episode
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
task_description=task_state.description
|
||||
if task_state
|
||||
else "Session task",
|
||||
actions=actions,
|
||||
context_summary=context_summary,
|
||||
outcome=outcome,
|
||||
outcome_details=task_state.status if task_state else "",
|
||||
duration_seconds=self._calculate_duration(task_state),
|
||||
tokens_used=0, # Would need to track this in working memory
|
||||
lessons_learned=lessons,
|
||||
importance_score=importance,
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
|
||||
episodic = await self._get_episodic()
|
||||
episode = await episodic.record_episode(episode_data)
|
||||
|
||||
result.episode_created = True
|
||||
result.episode_id = episode.id
|
||||
|
||||
logger.info(
|
||||
f"Consolidated session {session_id} to episode {episode.id} "
|
||||
f"({len(actions)} actions, outcome={outcome.value})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result.error = str(e)
|
||||
logger.exception(f"Failed to consolidate session {session_id}")
|
||||
|
||||
return result
|
||||
|
||||
def _should_consolidate_session(self, task_state: TaskState | None) -> bool:
|
||||
"""Check if session has enough content to consolidate."""
|
||||
if task_state is None:
|
||||
return False
|
||||
|
||||
# Check minimum steps
|
||||
if task_state.current_step < self._config.min_steps_for_episode:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _determine_session_outcome(self, task_state: TaskState | None) -> Outcome:
|
||||
"""Determine outcome from task state."""
|
||||
if task_state is None:
|
||||
return Outcome.PARTIAL
|
||||
|
||||
status = task_state.status.lower() if task_state.status else ""
|
||||
progress = task_state.progress_percent
|
||||
|
||||
if "success" in status or "complete" in status or progress >= 100:
|
||||
return Outcome.SUCCESS
|
||||
if "fail" in status or "error" in status:
|
||||
return Outcome.FAILURE
|
||||
if progress >= 50:
|
||||
return Outcome.PARTIAL
|
||||
return Outcome.FAILURE
|
||||
|
||||
def _build_actions_from_session(
|
||||
self,
|
||||
scratchpad: list[str],
|
||||
variables: dict[str, Any],
|
||||
task_state: TaskState | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build action list from session data."""
|
||||
actions: list[dict[str, Any]] = []
|
||||
|
||||
# Add scratchpad entries as actions
|
||||
for i, entry in enumerate(scratchpad):
|
||||
actions.append(
|
||||
{
|
||||
"step": i + 1,
|
||||
"type": "reasoning",
|
||||
"content": entry[:500], # Truncate long entries
|
||||
}
|
||||
)
|
||||
|
||||
# Add final state
|
||||
if task_state:
|
||||
actions.append(
|
||||
{
|
||||
"step": len(scratchpad) + 1,
|
||||
"type": "final_state",
|
||||
"current_step": task_state.current_step,
|
||||
"total_steps": task_state.total_steps,
|
||||
"progress": task_state.progress_percent,
|
||||
"status": task_state.status,
|
||||
}
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
def _build_context_summary(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
variables: dict[str, Any],
|
||||
) -> str:
|
||||
"""Build context summary from session data."""
|
||||
parts = []
|
||||
|
||||
if task_state:
|
||||
parts.append(f"Task: {task_state.description}")
|
||||
parts.append(f"Progress: {task_state.progress_percent:.1f}%")
|
||||
parts.append(f"Steps: {task_state.current_step}/{task_state.total_steps}")
|
||||
|
||||
# Include key variables
|
||||
key_vars = {k: v for k, v in variables.items() if len(str(v)) < 100}
|
||||
if key_vars:
|
||||
var_str = ", ".join(f"{k}={v}" for k, v in list(key_vars.items())[:5])
|
||||
parts.append(f"Variables: {var_str}")
|
||||
|
||||
return "; ".join(parts) if parts else "Session completed"
|
||||
|
||||
def _extract_session_lessons(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
outcome: Outcome,
|
||||
) -> list[str]:
|
||||
"""Extract lessons from session."""
|
||||
lessons: list[str] = []
|
||||
|
||||
if task_state and task_state.status:
|
||||
if outcome == Outcome.FAILURE:
|
||||
lessons.append(
|
||||
f"Task failed at step {task_state.current_step}: {task_state.status}"
|
||||
)
|
||||
elif outcome == Outcome.SUCCESS:
|
||||
lessons.append(
|
||||
f"Successfully completed in {task_state.current_step} steps"
|
||||
)
|
||||
|
||||
return lessons
|
||||
|
||||
def _calculate_session_importance(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
outcome: Outcome,
|
||||
actions: list[dict[str, Any]],
|
||||
) -> float:
|
||||
"""Calculate importance score for session."""
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Failures are important to learn from
|
||||
if outcome == Outcome.FAILURE:
|
||||
score += 0.3
|
||||
|
||||
# Many steps means complex task
|
||||
if task_state and task_state.total_steps >= 5:
|
||||
score += 0.1
|
||||
|
||||
# Many actions means detailed reasoning
|
||||
if len(actions) >= 5:
|
||||
score += 0.1
|
||||
|
||||
return min(1.0, score)
|
||||
|
||||
def _calculate_duration(self, task_state: TaskState | None) -> float:
|
||||
"""Calculate session duration."""
|
||||
if task_state is None:
|
||||
return 0.0
|
||||
|
||||
if task_state.started_at and task_state.updated_at:
|
||||
delta = task_state.updated_at - task_state.started_at
|
||||
return delta.total_seconds()
|
||||
|
||||
return 0.0
|
||||
|
||||
# =========================================================================
|
||||
# Episodic -> Semantic Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_episodes_to_facts(
|
||||
self,
|
||||
project_id: UUID,
|
||||
since: datetime | None = None,
|
||||
limit: int | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Extract facts from episodic memories to semantic memory.
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
since: Only process episodes since this time
|
||||
limit: Maximum episodes to process
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with extraction statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
semantic = await self._get_semantic()
|
||||
|
||||
# Get episodes to process
|
||||
since_time = since or datetime.now(UTC) - timedelta(days=1)
|
||||
episodes = await episodic.get_recent(
|
||||
project_id,
|
||||
limit=limit or self._config.batch_size,
|
||||
since=since_time,
|
||||
)
|
||||
|
||||
for episode in episodes:
|
||||
result.items_processed += 1
|
||||
|
||||
try:
|
||||
# Extract facts using the extractor
|
||||
extracted_facts = self._fact_extractor.extract_from_episode(episode)
|
||||
|
||||
for extracted_fact in extracted_facts:
|
||||
if (
|
||||
extracted_fact.confidence
|
||||
< self._config.min_confidence_for_fact
|
||||
):
|
||||
result.items_skipped += 1
|
||||
continue
|
||||
|
||||
# Create fact (store_fact handles deduplication/reinforcement)
|
||||
fact_create = extracted_fact.to_fact_create(
|
||||
project_id=project_id,
|
||||
source_episode_ids=[episode.id],
|
||||
)
|
||||
|
||||
# store_fact automatically reinforces if fact already exists
|
||||
fact = await semantic.store_fact(fact_create)
|
||||
|
||||
# Check if this was a new fact or reinforced existing
|
||||
if fact.reinforcement_count == 1:
|
||||
result.items_created += 1
|
||||
else:
|
||||
result.items_updated += 1
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Episode {episode.id}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to extract facts from episode {episode.id}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Consolidation failed: {e}")
|
||||
logger.exception("Failed episodic -> semantic consolidation")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episodic -> Semantic consolidation: "
|
||||
f"{result.items_processed} processed, "
|
||||
f"{result.items_created} created, "
|
||||
f"{result.items_updated} reinforced"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
# =========================================================================
|
||||
# Episodic -> Procedural Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_episodes_to_procedures(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Learn procedures from patterns in episodic memories.
|
||||
|
||||
Identifies recurring successful patterns and creates/updates
|
||||
procedures to capture them.
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
agent_type_id: Optional filter by agent type
|
||||
since: Only process episodes since this time
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with procedure statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="procedural",
|
||||
)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
procedural = await self._get_procedural()
|
||||
|
||||
# Get successful episodes
|
||||
since_time = since or datetime.now(UTC) - timedelta(days=7)
|
||||
episodes = await episodic.get_by_outcome(
|
||||
project_id,
|
||||
outcome=Outcome.SUCCESS,
|
||||
limit=self._config.batch_size,
|
||||
agent_instance_id=None, # Get all agent instances
|
||||
)
|
||||
|
||||
# Group by task type
|
||||
task_groups: dict[str, list[Episode]] = {}
|
||||
for episode in episodes:
|
||||
if episode.occurred_at >= since_time:
|
||||
if episode.task_type not in task_groups:
|
||||
task_groups[episode.task_type] = []
|
||||
task_groups[episode.task_type].append(episode)
|
||||
|
||||
result.items_processed = len(episodes)
|
||||
|
||||
# Process each task type group
|
||||
for task_type, group in task_groups.items():
|
||||
if len(group) < self._config.min_episodes_for_procedure:
|
||||
result.items_skipped += len(group)
|
||||
continue
|
||||
|
||||
try:
|
||||
procedure_result = await self._learn_procedure_from_episodes(
|
||||
procedural,
|
||||
project_id,
|
||||
agent_type_id,
|
||||
task_type,
|
||||
group,
|
||||
)
|
||||
|
||||
if procedure_result == "created":
|
||||
result.items_created += 1
|
||||
elif procedure_result == "updated":
|
||||
result.items_updated += 1
|
||||
else:
|
||||
result.items_skipped += 1
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Task type '{task_type}': {e}")
|
||||
logger.warning(f"Failed to learn procedure for '{task_type}': {e}")
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Consolidation failed: {e}")
|
||||
logger.exception("Failed episodic -> procedural consolidation")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episodic -> Procedural consolidation: "
|
||||
f"{result.items_processed} processed, "
|
||||
f"{result.items_created} created, "
|
||||
f"{result.items_updated} updated"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _learn_procedure_from_episodes(
|
||||
self,
|
||||
procedural: ProceduralMemory,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None,
|
||||
task_type: str,
|
||||
episodes: list[Episode],
|
||||
) -> str:
|
||||
"""Learn or update a procedure from a set of episodes."""
|
||||
# Calculate success rate for this pattern
|
||||
success_count = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
|
||||
total_count = len(episodes)
|
||||
success_rate = success_count / total_count if total_count > 0 else 0
|
||||
|
||||
if success_rate < self._config.min_success_rate_for_procedure:
|
||||
return "skipped"
|
||||
|
||||
# Extract common steps from episodes
|
||||
steps = self._extract_common_steps(episodes)
|
||||
|
||||
if len(steps) < self._config.min_steps_for_procedure:
|
||||
return "skipped"
|
||||
|
||||
# Check for existing procedure
|
||||
matching = await procedural.find_matching(
|
||||
context=task_type,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=1,
|
||||
)
|
||||
|
||||
if matching:
|
||||
# Update existing procedure with new success
|
||||
await procedural.record_outcome(
|
||||
matching[0].id,
|
||||
success=True,
|
||||
)
|
||||
return "updated"
|
||||
else:
|
||||
# Create new procedure
|
||||
# Note: success_count starts at 1 in record_procedure
|
||||
procedure_data = ProcedureCreate(
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
name=f"Procedure for {task_type}",
|
||||
trigger_pattern=task_type,
|
||||
steps=steps,
|
||||
)
|
||||
await procedural.record_procedure(procedure_data)
|
||||
return "created"
|
||||
|
||||
def _extract_common_steps(self, episodes: list[Episode]) -> list[dict[str, Any]]:
|
||||
"""Extract common action steps from multiple episodes."""
|
||||
# Simple heuristic: take the steps from the most successful episode
|
||||
# with the most detailed actions
|
||||
|
||||
best_episode = max(
|
||||
episodes,
|
||||
key=lambda e: (
|
||||
e.outcome == Outcome.SUCCESS,
|
||||
len(e.actions),
|
||||
e.importance_score,
|
||||
),
|
||||
)
|
||||
|
||||
steps: list[dict[str, Any]] = []
|
||||
for i, action in enumerate(best_episode.actions):
|
||||
step = {
|
||||
"order": i + 1,
|
||||
"action": action.get("type", "action"),
|
||||
"description": action.get("content", str(action))[:500],
|
||||
"parameters": action,
|
||||
}
|
||||
steps.append(step)
|
||||
|
||||
return steps
|
||||
|
||||
# =========================================================================
|
||||
# Memory Pruning
|
||||
# =========================================================================
|
||||
|
||||
async def prune_old_episodes(
|
||||
self,
|
||||
project_id: UUID,
|
||||
max_age_days: int | None = None,
|
||||
min_importance: float | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Prune old, low-value episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to prune
|
||||
max_age_days: Maximum age in days (default from config)
|
||||
min_importance: Minimum importance to keep (default from config)
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with pruning statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="pruned",
|
||||
)
|
||||
|
||||
max_age = max_age_days or self._config.max_episode_age_days
|
||||
min_imp = min_importance or self._config.min_importance_to_keep
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=max_age)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
# Get old episodes
|
||||
# Note: In production, this would use a more efficient query
|
||||
all_episodes = await episodic.get_recent(
|
||||
project_id,
|
||||
limit=self._config.batch_size * 10,
|
||||
since=cutoff_date - timedelta(days=365), # Search past year
|
||||
)
|
||||
|
||||
for episode in all_episodes:
|
||||
result.items_processed += 1
|
||||
|
||||
# Check if should be pruned
|
||||
if not self._should_prune_episode(episode, cutoff_date, min_imp):
|
||||
result.items_skipped += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
deleted = await episodic.delete(episode.id)
|
||||
if deleted:
|
||||
result.items_pruned += 1
|
||||
else:
|
||||
result.items_skipped += 1
|
||||
except Exception as e:
|
||||
result.errors.append(f"Episode {episode.id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Pruning failed: {e}")
|
||||
logger.exception("Failed episode pruning")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episode pruning: {result.items_processed} processed, "
|
||||
f"{result.items_pruned} pruned"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _should_prune_episode(
|
||||
self,
|
||||
episode: Episode,
|
||||
cutoff_date: datetime,
|
||||
min_importance: float,
|
||||
) -> bool:
|
||||
"""Determine if an episode should be pruned."""
|
||||
# Keep recent episodes
|
||||
if episode.occurred_at >= cutoff_date:
|
||||
return False
|
||||
|
||||
# Keep failures if configured
|
||||
if self._config.keep_all_failures and episode.outcome == Outcome.FAILURE:
|
||||
return False
|
||||
|
||||
# Keep episodes with lessons if configured
|
||||
if self._config.keep_all_with_lessons and episode.lessons_learned:
|
||||
return False
|
||||
|
||||
# Keep high-importance episodes
|
||||
if episode.importance_score >= min_importance:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# =========================================================================
|
||||
# Nightly Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def run_nightly_consolidation(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> NightlyConsolidationResult:
|
||||
"""
|
||||
Run full nightly consolidation workflow.
|
||||
|
||||
This includes:
|
||||
1. Extract facts from recent episodes
|
||||
2. Learn procedures from successful patterns
|
||||
3. Prune old, low-value memories
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
NightlyConsolidationResult with all outcomes
|
||||
"""
|
||||
result = NightlyConsolidationResult(started_at=datetime.now(UTC))
|
||||
|
||||
logger.info(f"Starting nightly consolidation for project {project_id}")
|
||||
|
||||
try:
|
||||
# Step 1: Episodic -> Semantic (last 24 hours)
|
||||
since_yesterday = datetime.now(UTC) - timedelta(days=1)
|
||||
result.episodic_to_semantic = await self.consolidate_episodes_to_facts(
|
||||
project_id=project_id,
|
||||
since=since_yesterday,
|
||||
)
|
||||
result.total_facts_created = result.episodic_to_semantic.items_created
|
||||
|
||||
# Step 2: Episodic -> Procedural (last 7 days)
|
||||
since_week = datetime.now(UTC) - timedelta(days=7)
|
||||
result.episodic_to_procedural = (
|
||||
await self.consolidate_episodes_to_procedures(
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
since=since_week,
|
||||
)
|
||||
)
|
||||
result.total_procedures_created = (
|
||||
result.episodic_to_procedural.items_created
|
||||
)
|
||||
|
||||
# Step 3: Prune old memories
|
||||
result.pruning = await self.prune_old_episodes(project_id=project_id)
|
||||
result.total_pruned = result.pruning.items_pruned
|
||||
|
||||
# Calculate totals
|
||||
result.total_episodes_processed = (
|
||||
result.episodic_to_semantic.items_processed
|
||||
if result.episodic_to_semantic
|
||||
else 0
|
||||
) + (
|
||||
result.episodic_to_procedural.items_processed
|
||||
if result.episodic_to_procedural
|
||||
else 0
|
||||
)
|
||||
|
||||
# Collect all errors
|
||||
if result.episodic_to_semantic and result.episodic_to_semantic.errors:
|
||||
result.errors.extend(result.episodic_to_semantic.errors)
|
||||
if result.episodic_to_procedural and result.episodic_to_procedural.errors:
|
||||
result.errors.extend(result.episodic_to_procedural.errors)
|
||||
if result.pruning and result.pruning.errors:
|
||||
result.errors.extend(result.pruning.errors)
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Nightly consolidation failed: {e}")
|
||||
logger.exception("Nightly consolidation failed")
|
||||
|
||||
result.completed_at = datetime.now(UTC)
|
||||
|
||||
duration = (result.completed_at - result.started_at).total_seconds()
|
||||
logger.info(
|
||||
f"Nightly consolidation completed in {duration:.1f}s: "
|
||||
f"{result.total_facts_created} facts, "
|
||||
f"{result.total_procedures_created} procedures, "
|
||||
f"{result.total_pruned} pruned"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Factory function - no singleton to avoid stale session issues
|
||||
async def get_consolidation_service(
|
||||
session: AsyncSession,
|
||||
config: ConsolidationConfig | None = None,
|
||||
) -> MemoryConsolidationService:
|
||||
"""
|
||||
Create a memory consolidation service for the given session.
|
||||
|
||||
Note: This creates a new instance each time to avoid stale session issues.
|
||||
The service is lightweight and safe to recreate per-request.
|
||||
|
||||
Args:
|
||||
session: Database session (must be active)
|
||||
config: Optional configuration
|
||||
|
||||
Returns:
|
||||
MemoryConsolidationService instance
|
||||
"""
|
||||
return MemoryConsolidationService(session=session, config=config)
|
||||
17
backend/app/services/memory/episodic/__init__.py
Normal file
17
backend/app/services/memory/episodic/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# app/services/memory/episodic/__init__.py
|
||||
"""
|
||||
Episodic Memory Package.
|
||||
|
||||
Provides experiential memory storage and retrieval for agent learning.
|
||||
"""
|
||||
|
||||
from .memory import EpisodicMemory
|
||||
from .recorder import EpisodeRecorder
|
||||
from .retrieval import EpisodeRetriever, RetrievalStrategy
|
||||
|
||||
__all__ = [
|
||||
"EpisodeRecorder",
|
||||
"EpisodeRetriever",
|
||||
"EpisodicMemory",
|
||||
"RetrievalStrategy",
|
||||
]
|
||||
490
backend/app/services/memory/episodic/memory.py
Normal file
490
backend/app/services/memory/episodic/memory.py
Normal file
@@ -0,0 +1,490 @@
|
||||
# app/services/memory/episodic/memory.py
|
||||
"""
|
||||
Episodic Memory Implementation.
|
||||
|
||||
Provides experiential memory storage and retrieval for agent learning.
|
||||
Combines episode recording and retrieval into a unified interface.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.memory.types import Episode, EpisodeCreate, Outcome, RetrievalResult
|
||||
|
||||
from .recorder import EpisodeRecorder
|
||||
from .retrieval import EpisodeRetriever, RetrievalStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodicMemory:
|
||||
"""
|
||||
Episodic Memory Service.
|
||||
|
||||
Provides experiential memory for agent learning:
|
||||
- Record task completions with context
|
||||
- Store failures with error context
|
||||
- Retrieve by semantic similarity
|
||||
- Retrieve by recency, outcome, task type
|
||||
- Track importance scores
|
||||
- Extract lessons learned
|
||||
|
||||
Performance target: <100ms P95 for retrieval
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize episodic memory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic search
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._recorder = EpisodeRecorder(session, embedding_generator)
|
||||
self._retriever = EpisodeRetriever(session, embedding_generator)
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> "EpisodicMemory":
|
||||
"""
|
||||
Factory method to create EpisodicMemory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
|
||||
Returns:
|
||||
Configured EpisodicMemory instance
|
||||
"""
|
||||
return cls(session=session, embedding_generator=embedding_generator)
|
||||
|
||||
# =========================================================================
|
||||
# Recording Operations
|
||||
# =========================================================================
|
||||
|
||||
async def record_episode(self, episode: EpisodeCreate) -> Episode:
|
||||
"""
|
||||
Record a new episode.
|
||||
|
||||
Args:
|
||||
episode: Episode data to record
|
||||
|
||||
Returns:
|
||||
The created episode with assigned ID
|
||||
"""
|
||||
return await self._recorder.record(episode)
|
||||
|
||||
async def record_success(
|
||||
self,
|
||||
project_id: UUID,
|
||||
session_id: str,
|
||||
task_type: str,
|
||||
task_description: str,
|
||||
actions: list[dict[str, Any]],
|
||||
context_summary: str,
|
||||
outcome_details: str = "",
|
||||
duration_seconds: float = 0.0,
|
||||
tokens_used: int = 0,
|
||||
lessons_learned: list[str] | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> Episode:
|
||||
"""
|
||||
Convenience method to record a successful episode.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
session_id: Session ID
|
||||
task_type: Type of task
|
||||
task_description: Task description
|
||||
actions: Actions taken
|
||||
context_summary: Context summary
|
||||
outcome_details: Optional outcome details
|
||||
duration_seconds: Task duration
|
||||
tokens_used: Tokens consumed
|
||||
lessons_learned: Optional lessons
|
||||
agent_instance_id: Optional agent instance
|
||||
agent_type_id: Optional agent type
|
||||
|
||||
Returns:
|
||||
The created episode
|
||||
"""
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
task_description=task_description,
|
||||
actions=actions,
|
||||
context_summary=context_summary,
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details=outcome_details,
|
||||
duration_seconds=duration_seconds,
|
||||
tokens_used=tokens_used,
|
||||
lessons_learned=lessons_learned or [],
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
return await self.record_episode(episode_data)
|
||||
|
||||
async def record_failure(
|
||||
self,
|
||||
project_id: UUID,
|
||||
session_id: str,
|
||||
task_type: str,
|
||||
task_description: str,
|
||||
actions: list[dict[str, Any]],
|
||||
context_summary: str,
|
||||
error_details: str,
|
||||
duration_seconds: float = 0.0,
|
||||
tokens_used: int = 0,
|
||||
lessons_learned: list[str] | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> Episode:
|
||||
"""
|
||||
Convenience method to record a failed episode.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
session_id: Session ID
|
||||
task_type: Type of task
|
||||
task_description: Task description
|
||||
actions: Actions taken before failure
|
||||
context_summary: Context summary
|
||||
error_details: Error details
|
||||
duration_seconds: Task duration
|
||||
tokens_used: Tokens consumed
|
||||
lessons_learned: Optional lessons from failure
|
||||
agent_instance_id: Optional agent instance
|
||||
agent_type_id: Optional agent type
|
||||
|
||||
Returns:
|
||||
The created episode
|
||||
"""
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
task_description=task_description,
|
||||
actions=actions,
|
||||
context_summary=context_summary,
|
||||
outcome=Outcome.FAILURE,
|
||||
outcome_details=error_details,
|
||||
duration_seconds=duration_seconds,
|
||||
tokens_used=tokens_used,
|
||||
lessons_learned=lessons_learned or [],
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
return await self.record_episode(episode_data)
|
||||
|
||||
# =========================================================================
|
||||
# Retrieval Operations
|
||||
# =========================================================================
|
||||
|
||||
async def search_similar(
|
||||
self,
|
||||
project_id: UUID,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Search for semantically similar episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
query: Search query
|
||||
limit: Maximum results
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of similar episodes
|
||||
"""
|
||||
result = await self._retriever.search_similar(
|
||||
project_id, query, limit, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def get_recent(
|
||||
self,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
since: datetime | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get recent episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
limit: Maximum results
|
||||
since: Optional time filter
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of recent episodes
|
||||
"""
|
||||
result = await self._retriever.get_recent(
|
||||
project_id, limit, since, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def get_by_outcome(
|
||||
self,
|
||||
project_id: UUID,
|
||||
outcome: Outcome,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get episodes by outcome.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
outcome: Outcome to filter by
|
||||
limit: Maximum results
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of episodes with specified outcome
|
||||
"""
|
||||
result = await self._retriever.get_by_outcome(
|
||||
project_id, outcome, limit, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def get_by_task_type(
|
||||
self,
|
||||
project_id: UUID,
|
||||
task_type: str,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get episodes by task type.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
task_type: Task type to filter by
|
||||
limit: Maximum results
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of episodes with specified task type
|
||||
"""
|
||||
result = await self._retriever.get_by_task_type(
|
||||
project_id, task_type, limit, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def get_important(
|
||||
self,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
min_importance: float = 0.7,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get high-importance episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
limit: Maximum results
|
||||
min_importance: Minimum importance score
|
||||
agent_instance_id: Optional filter by agent instance
|
||||
|
||||
Returns:
|
||||
List of important episodes
|
||||
"""
|
||||
result = await self._retriever.get_important(
|
||||
project_id, limit, min_importance, agent_instance_id
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
project_id: UUID,
|
||||
strategy: RetrievalStrategy = RetrievalStrategy.RECENCY,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""
|
||||
Retrieve episodes with full result metadata.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
strategy: Retrieval strategy
|
||||
limit: Maximum results
|
||||
**kwargs: Strategy-specific parameters
|
||||
|
||||
Returns:
|
||||
RetrievalResult with episodes and metadata
|
||||
"""
|
||||
return await self._retriever.retrieve(project_id, strategy, limit, **kwargs)
|
||||
|
||||
# =========================================================================
|
||||
# Modification Operations
|
||||
# =========================================================================
|
||||
|
||||
async def get_by_id(self, episode_id: UUID) -> Episode | None:
|
||||
"""Get an episode by ID."""
|
||||
return await self._recorder.get_by_id(episode_id)
|
||||
|
||||
async def update_importance(
|
||||
self,
|
||||
episode_id: UUID,
|
||||
importance_score: float,
|
||||
) -> Episode | None:
|
||||
"""
|
||||
Update an episode's importance score.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to update
|
||||
importance_score: New importance score (0.0 to 1.0)
|
||||
|
||||
Returns:
|
||||
Updated episode or None if not found
|
||||
"""
|
||||
return await self._recorder.update_importance(episode_id, importance_score)
|
||||
|
||||
async def add_lessons(
|
||||
self,
|
||||
episode_id: UUID,
|
||||
lessons: list[str],
|
||||
) -> Episode | None:
|
||||
"""
|
||||
Add lessons learned to an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to update
|
||||
lessons: Lessons to add
|
||||
|
||||
Returns:
|
||||
Updated episode or None if not found
|
||||
"""
|
||||
return await self._recorder.add_lessons(episode_id, lessons)
|
||||
|
||||
async def delete(self, episode_id: UUID) -> bool:
|
||||
"""
|
||||
Delete an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to delete
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
return await self._recorder.delete(episode_id)
|
||||
|
||||
# =========================================================================
|
||||
# Summarization
|
||||
# =========================================================================
|
||||
|
||||
async def summarize_episodes(
|
||||
self,
|
||||
episode_ids: list[UUID],
|
||||
) -> str:
|
||||
"""
|
||||
Summarize multiple episodes into a consolidated view.
|
||||
|
||||
Args:
|
||||
episode_ids: Episodes to summarize
|
||||
|
||||
Returns:
|
||||
Summary text
|
||||
"""
|
||||
if not episode_ids:
|
||||
return "No episodes to summarize."
|
||||
|
||||
episodes: list[Episode] = []
|
||||
for episode_id in episode_ids:
|
||||
episode = await self.get_by_id(episode_id)
|
||||
if episode:
|
||||
episodes.append(episode)
|
||||
|
||||
if not episodes:
|
||||
return "No episodes found."
|
||||
|
||||
# Build summary
|
||||
lines = [f"Summary of {len(episodes)} episodes:", ""]
|
||||
|
||||
# Outcome breakdown
|
||||
success = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
|
||||
failure = sum(1 for e in episodes if e.outcome == Outcome.FAILURE)
|
||||
partial = sum(1 for e in episodes if e.outcome == Outcome.PARTIAL)
|
||||
lines.append(
|
||||
f"Outcomes: {success} success, {failure} failure, {partial} partial"
|
||||
)
|
||||
|
||||
# Task types
|
||||
task_types = {e.task_type for e in episodes}
|
||||
lines.append(f"Task types: {', '.join(sorted(task_types))}")
|
||||
|
||||
# Aggregate lessons
|
||||
all_lessons: list[str] = []
|
||||
for e in episodes:
|
||||
all_lessons.extend(e.lessons_learned)
|
||||
|
||||
if all_lessons:
|
||||
lines.append("")
|
||||
lines.append("Key lessons learned:")
|
||||
# Deduplicate lessons
|
||||
unique_lessons = list(dict.fromkeys(all_lessons))
|
||||
for lesson in unique_lessons[:10]: # Top 10
|
||||
lines.append(f" - {lesson}")
|
||||
|
||||
# Duration and tokens
|
||||
total_duration = sum(e.duration_seconds for e in episodes)
|
||||
total_tokens = sum(e.tokens_used for e in episodes)
|
||||
lines.append("")
|
||||
lines.append(f"Total duration: {total_duration:.1f}s")
|
||||
lines.append(f"Total tokens: {total_tokens:,}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
async def get_stats(self, project_id: UUID) -> dict[str, Any]:
|
||||
"""
|
||||
Get episode statistics for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project to get stats for
|
||||
|
||||
Returns:
|
||||
Dictionary with episode statistics
|
||||
"""
|
||||
return await self._recorder.get_stats(project_id)
|
||||
|
||||
async def count(
|
||||
self,
|
||||
project_id: UUID,
|
||||
since: datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Count episodes for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project to count for
|
||||
since: Optional time filter
|
||||
|
||||
Returns:
|
||||
Number of episodes
|
||||
"""
|
||||
return await self._recorder.count_by_project(project_id, since)
|
||||
357
backend/app/services/memory/episodic/recorder.py
Normal file
357
backend/app/services/memory/episodic/recorder.py
Normal file
@@ -0,0 +1,357 @@
|
||||
# app/services/memory/episodic/recorder.py
|
||||
"""
|
||||
Episode Recording.
|
||||
|
||||
Handles the creation and storage of episodic memories
|
||||
during agent task execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.enums import EpisodeOutcome
|
||||
from app.models.memory.episode import Episode as EpisodeModel
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.types import Episode, EpisodeCreate, Outcome
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _outcome_to_db(outcome: Outcome) -> EpisodeOutcome:
|
||||
"""Convert service Outcome to database EpisodeOutcome."""
|
||||
return EpisodeOutcome(outcome.value)
|
||||
|
||||
|
||||
def _db_to_outcome(db_outcome: EpisodeOutcome) -> Outcome:
|
||||
"""Convert database EpisodeOutcome to service Outcome."""
|
||||
return Outcome(db_outcome.value)
|
||||
|
||||
|
||||
def _model_to_episode(model: EpisodeModel) -> Episode:
|
||||
"""Convert SQLAlchemy model to Episode dataclass."""
|
||||
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
|
||||
# they return actual values. We use type: ignore to handle this mismatch.
|
||||
return Episode(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
agent_instance_id=model.agent_instance_id, # type: ignore[arg-type]
|
||||
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
|
||||
session_id=model.session_id, # type: ignore[arg-type]
|
||||
task_type=model.task_type, # type: ignore[arg-type]
|
||||
task_description=model.task_description, # type: ignore[arg-type]
|
||||
actions=model.actions or [], # type: ignore[arg-type]
|
||||
context_summary=model.context_summary, # type: ignore[arg-type]
|
||||
outcome=_db_to_outcome(model.outcome), # type: ignore[arg-type]
|
||||
outcome_details=model.outcome_details or "", # type: ignore[arg-type]
|
||||
duration_seconds=model.duration_seconds, # type: ignore[arg-type]
|
||||
tokens_used=model.tokens_used, # type: ignore[arg-type]
|
||||
lessons_learned=model.lessons_learned or [], # type: ignore[arg-type]
|
||||
importance_score=model.importance_score, # type: ignore[arg-type]
|
||||
embedding=None, # Don't expose raw embedding
|
||||
occurred_at=model.occurred_at, # type: ignore[arg-type]
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class EpisodeRecorder:
|
||||
"""
|
||||
Records episodes to the database.
|
||||
|
||||
Handles episode creation, importance scoring,
|
||||
and lesson extraction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize recorder.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic indexing
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._settings = get_memory_settings()
|
||||
|
||||
async def record(self, episode_data: EpisodeCreate) -> Episode:
|
||||
"""
|
||||
Record a new episode.
|
||||
|
||||
Args:
|
||||
episode_data: Episode data to record
|
||||
|
||||
Returns:
|
||||
The created episode
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Calculate importance score if not provided
|
||||
importance = episode_data.importance_score
|
||||
if importance == 0.5: # Default value, calculate
|
||||
importance = self._calculate_importance(episode_data)
|
||||
|
||||
# Create the model
|
||||
model = EpisodeModel(
|
||||
id=uuid4(),
|
||||
project_id=episode_data.project_id,
|
||||
agent_instance_id=episode_data.agent_instance_id,
|
||||
agent_type_id=episode_data.agent_type_id,
|
||||
session_id=episode_data.session_id,
|
||||
task_type=episode_data.task_type,
|
||||
task_description=episode_data.task_description,
|
||||
actions=episode_data.actions,
|
||||
context_summary=episode_data.context_summary,
|
||||
outcome=_outcome_to_db(episode_data.outcome),
|
||||
outcome_details=episode_data.outcome_details,
|
||||
duration_seconds=episode_data.duration_seconds,
|
||||
tokens_used=episode_data.tokens_used,
|
||||
lessons_learned=episode_data.lessons_learned,
|
||||
importance_score=importance,
|
||||
occurred_at=now,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
# Generate embedding if generator available
|
||||
if self._embedding_generator is not None:
|
||||
try:
|
||||
text_for_embedding = self._create_embedding_text(episode_data)
|
||||
embedding = await self._embedding_generator.generate(text_for_embedding)
|
||||
model.embedding = embedding
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate embedding: {e}")
|
||||
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
|
||||
logger.debug(f"Recorded episode {model.id} for task {model.task_type}")
|
||||
return _model_to_episode(model)
|
||||
|
||||
def _calculate_importance(self, episode_data: EpisodeCreate) -> float:
|
||||
"""
|
||||
Calculate importance score for an episode.
|
||||
|
||||
Factors:
|
||||
- Outcome: Failures are more important to learn from
|
||||
- Duration: Longer tasks may be more significant
|
||||
- Token usage: Higher usage may indicate complexity
|
||||
- Lessons learned: Episodes with lessons are more valuable
|
||||
"""
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Outcome factor
|
||||
if episode_data.outcome == Outcome.FAILURE:
|
||||
score += 0.2 # Failures are important for learning
|
||||
elif episode_data.outcome == Outcome.PARTIAL:
|
||||
score += 0.1
|
||||
# Success is default, no adjustment
|
||||
|
||||
# Lessons learned factor
|
||||
if episode_data.lessons_learned:
|
||||
score += min(0.15, len(episode_data.lessons_learned) * 0.05)
|
||||
|
||||
# Duration factor (longer tasks may be more significant)
|
||||
if episode_data.duration_seconds > 60:
|
||||
score += 0.05
|
||||
if episode_data.duration_seconds > 300:
|
||||
score += 0.05
|
||||
|
||||
# Token usage factor (complex tasks)
|
||||
if episode_data.tokens_used > 1000:
|
||||
score += 0.05
|
||||
|
||||
# Clamp to valid range
|
||||
return min(1.0, max(0.0, score))
|
||||
|
||||
def _create_embedding_text(self, episode_data: EpisodeCreate) -> str:
|
||||
"""Create text representation for embedding generation."""
|
||||
parts = [
|
||||
f"Task: {episode_data.task_type}",
|
||||
f"Description: {episode_data.task_description}",
|
||||
f"Context: {episode_data.context_summary}",
|
||||
f"Outcome: {episode_data.outcome.value}",
|
||||
]
|
||||
|
||||
if episode_data.outcome_details:
|
||||
parts.append(f"Details: {episode_data.outcome_details}")
|
||||
|
||||
if episode_data.lessons_learned:
|
||||
parts.append(f"Lessons: {', '.join(episode_data.lessons_learned)}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
async def get_by_id(self, episode_id: UUID) -> Episode | None:
|
||||
"""Get an episode by ID."""
|
||||
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
return _model_to_episode(model)
|
||||
|
||||
async def update_importance(
|
||||
self,
|
||||
episode_id: UUID,
|
||||
importance_score: float,
|
||||
) -> Episode | None:
|
||||
"""
|
||||
Update the importance score of an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to update
|
||||
importance_score: New importance score (0.0 to 1.0)
|
||||
|
||||
Returns:
|
||||
Updated episode or None if not found
|
||||
"""
|
||||
# Validate score
|
||||
importance_score = min(1.0, max(0.0, importance_score))
|
||||
|
||||
stmt = (
|
||||
update(EpisodeModel)
|
||||
.where(EpisodeModel.id == episode_id)
|
||||
.values(
|
||||
importance_score=importance_score,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
.returning(EpisodeModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
await self._session.flush()
|
||||
return _model_to_episode(model)
|
||||
|
||||
async def add_lessons(
|
||||
self,
|
||||
episode_id: UUID,
|
||||
lessons: list[str],
|
||||
) -> Episode | None:
|
||||
"""
|
||||
Add lessons learned to an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to update
|
||||
lessons: New lessons to add
|
||||
|
||||
Returns:
|
||||
Updated episode or None if not found
|
||||
"""
|
||||
# Get current episode
|
||||
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
# Append lessons
|
||||
current_lessons: list[str] = model.lessons_learned or [] # type: ignore[assignment]
|
||||
updated_lessons = current_lessons + lessons
|
||||
|
||||
stmt = (
|
||||
update(EpisodeModel)
|
||||
.where(EpisodeModel.id == episode_id)
|
||||
.values(
|
||||
lessons_learned=updated_lessons,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
.returning(EpisodeModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
model = result.scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
|
||||
return _model_to_episode(model) if model else None
|
||||
|
||||
async def delete(self, episode_id: UUID) -> bool:
|
||||
"""
|
||||
Delete an episode.
|
||||
|
||||
Args:
|
||||
episode_id: Episode to delete
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return False
|
||||
|
||||
await self._session.delete(model)
|
||||
await self._session.flush()
|
||||
return True
|
||||
|
||||
async def count_by_project(
|
||||
self,
|
||||
project_id: UUID,
|
||||
since: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count episodes for a project."""
|
||||
query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
if since is not None:
|
||||
query = query.where(EpisodeModel.occurred_at >= since)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return len(list(result.scalars().all()))
|
||||
|
||||
async def get_stats(self, project_id: UUID) -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics for a project's episodes.
|
||||
|
||||
Returns:
|
||||
Dictionary with episode statistics
|
||||
"""
|
||||
query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
result = await self._session.execute(query)
|
||||
episodes = list(result.scalars().all())
|
||||
|
||||
if not episodes:
|
||||
return {
|
||||
"total_count": 0,
|
||||
"success_count": 0,
|
||||
"failure_count": 0,
|
||||
"partial_count": 0,
|
||||
"avg_importance": 0.0,
|
||||
"avg_duration": 0.0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
|
||||
success_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.SUCCESS)
|
||||
failure_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.FAILURE)
|
||||
partial_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.PARTIAL)
|
||||
|
||||
avg_importance = sum(e.importance_score for e in episodes) / len(episodes)
|
||||
avg_duration = sum(e.duration_seconds for e in episodes) / len(episodes)
|
||||
total_tokens = sum(e.tokens_used for e in episodes)
|
||||
|
||||
return {
|
||||
"total_count": len(episodes),
|
||||
"success_count": success_count,
|
||||
"failure_count": failure_count,
|
||||
"partial_count": partial_count,
|
||||
"avg_importance": avg_importance,
|
||||
"avg_duration": avg_duration,
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
503
backend/app/services/memory/episodic/retrieval.py
Normal file
503
backend/app/services/memory/episodic/retrieval.py
Normal file
@@ -0,0 +1,503 @@
|
||||
# app/services/memory/episodic/retrieval.py
|
||||
"""
|
||||
Episode Retrieval Strategies.
|
||||
|
||||
Provides different retrieval strategies for finding relevant episodes:
|
||||
- Semantic similarity (vector search)
|
||||
- Recency-based
|
||||
- Outcome-based filtering
|
||||
- Importance-based ranking
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.enums import EpisodeOutcome
|
||||
from app.models.memory.episode import Episode as EpisodeModel
|
||||
from app.services.memory.types import Episode, Outcome, RetrievalResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetrievalStrategy(str, Enum):
|
||||
"""Retrieval strategy types."""
|
||||
|
||||
SEMANTIC = "semantic"
|
||||
RECENCY = "recency"
|
||||
OUTCOME = "outcome"
|
||||
IMPORTANCE = "importance"
|
||||
HYBRID = "hybrid"
|
||||
|
||||
|
||||
def _model_to_episode(model: EpisodeModel) -> Episode:
|
||||
"""Convert SQLAlchemy model to Episode dataclass."""
|
||||
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
|
||||
# they return actual values. We use type: ignore to handle this mismatch.
|
||||
return Episode(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
agent_instance_id=model.agent_instance_id, # type: ignore[arg-type]
|
||||
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
|
||||
session_id=model.session_id, # type: ignore[arg-type]
|
||||
task_type=model.task_type, # type: ignore[arg-type]
|
||||
task_description=model.task_description, # type: ignore[arg-type]
|
||||
actions=model.actions or [], # type: ignore[arg-type]
|
||||
context_summary=model.context_summary, # type: ignore[arg-type]
|
||||
outcome=Outcome(model.outcome.value),
|
||||
outcome_details=model.outcome_details or "", # type: ignore[arg-type]
|
||||
duration_seconds=model.duration_seconds, # type: ignore[arg-type]
|
||||
tokens_used=model.tokens_used, # type: ignore[arg-type]
|
||||
lessons_learned=model.lessons_learned or [], # type: ignore[arg-type]
|
||||
importance_score=model.importance_score, # type: ignore[arg-type]
|
||||
embedding=None, # Don't expose raw embedding
|
||||
occurred_at=model.occurred_at, # type: ignore[arg-type]
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
"""Abstract base class for episode retrieval strategies."""
|
||||
|
||||
@abstractmethod
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes based on the strategy."""
|
||||
...
|
||||
|
||||
|
||||
class RecencyRetriever(BaseRetriever):
|
||||
"""Retrieves episodes by recency (most recent first)."""
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
since: datetime | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve most recent episodes."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
query = (
|
||||
select(EpisodeModel)
|
||||
.where(EpisodeModel.project_id == project_id)
|
||||
.order_by(desc(EpisodeModel.occurred_at))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if since is not None:
|
||||
query = query.where(EpisodeModel.occurred_at >= since)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Get total count
|
||||
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
if since is not None:
|
||||
count_query = count_query.where(EpisodeModel.occurred_at >= since)
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = len(list(count_result.scalars().all()))
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_episode(m) for m in models],
|
||||
total_count=total_count,
|
||||
query="recency",
|
||||
retrieval_type=RetrievalStrategy.RECENCY.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"since": since.isoformat() if since else None},
|
||||
)
|
||||
|
||||
|
||||
class OutcomeRetriever(BaseRetriever):
|
||||
"""Retrieves episodes filtered by outcome."""
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
outcome: Outcome | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes by outcome."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
query = (
|
||||
select(EpisodeModel)
|
||||
.where(EpisodeModel.project_id == project_id)
|
||||
.order_by(desc(EpisodeModel.occurred_at))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if outcome is not None:
|
||||
db_outcome = EpisodeOutcome(outcome.value)
|
||||
query = query.where(EpisodeModel.outcome == db_outcome)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Get total count
|
||||
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
if outcome is not None:
|
||||
count_query = count_query.where(
|
||||
EpisodeModel.outcome == EpisodeOutcome(outcome.value)
|
||||
)
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = len(list(count_result.scalars().all()))
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_episode(m) for m in models],
|
||||
total_count=total_count,
|
||||
query=f"outcome:{outcome.value if outcome else 'all'}",
|
||||
retrieval_type=RetrievalStrategy.OUTCOME.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"outcome": outcome.value if outcome else None},
|
||||
)
|
||||
|
||||
|
||||
class TaskTypeRetriever(BaseRetriever):
|
||||
"""Retrieves episodes filtered by task type."""
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
task_type: str | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes by task type."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
query = (
|
||||
select(EpisodeModel)
|
||||
.where(EpisodeModel.project_id == project_id)
|
||||
.order_by(desc(EpisodeModel.occurred_at))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if task_type is not None:
|
||||
query = query.where(EpisodeModel.task_type == task_type)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Get total count
|
||||
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
|
||||
if task_type is not None:
|
||||
count_query = count_query.where(EpisodeModel.task_type == task_type)
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = len(list(count_result.scalars().all()))
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_episode(m) for m in models],
|
||||
total_count=total_count,
|
||||
query=f"task_type:{task_type or 'all'}",
|
||||
retrieval_type="task_type",
|
||||
latency_ms=latency_ms,
|
||||
metadata={"task_type": task_type},
|
||||
)
|
||||
|
||||
|
||||
class ImportanceRetriever(BaseRetriever):
|
||||
"""Retrieves episodes ranked by importance score."""
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
min_importance: float = 0.0,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes by importance."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
query = (
|
||||
select(EpisodeModel)
|
||||
.where(
|
||||
and_(
|
||||
EpisodeModel.project_id == project_id,
|
||||
EpisodeModel.importance_score >= min_importance,
|
||||
)
|
||||
)
|
||||
.order_by(desc(EpisodeModel.importance_score))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Get total count
|
||||
count_query = select(EpisodeModel).where(
|
||||
and_(
|
||||
EpisodeModel.project_id == project_id,
|
||||
EpisodeModel.importance_score >= min_importance,
|
||||
)
|
||||
)
|
||||
count_result = await session.execute(count_query)
|
||||
total_count = len(list(count_result.scalars().all()))
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_episode(m) for m in models],
|
||||
total_count=total_count,
|
||||
query=f"importance>={min_importance}",
|
||||
retrieval_type=RetrievalStrategy.IMPORTANCE.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"min_importance": min_importance},
|
||||
)
|
||||
|
||||
|
||||
class SemanticRetriever(BaseRetriever):
|
||||
"""Retrieves episodes by semantic similarity using vector search."""
|
||||
|
||||
def __init__(self, embedding_generator: Any | None = None) -> None:
|
||||
"""Initialize with optional embedding generator."""
|
||||
self._embedding_generator = embedding_generator
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
*,
|
||||
query_text: str | None = None,
|
||||
query_embedding: list[float] | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Retrieve episodes by semantic similarity."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# If no embedding provided, fall back to recency
|
||||
if query_embedding is None and query_text is None:
|
||||
logger.warning(
|
||||
"No query provided for semantic search, falling back to recency"
|
||||
)
|
||||
recency = RecencyRetriever()
|
||||
fallback_result = await recency.retrieve(
|
||||
session, project_id, limit, agent_instance_id=agent_instance_id
|
||||
)
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
return RetrievalResult(
|
||||
items=fallback_result.items,
|
||||
total_count=fallback_result.total_count,
|
||||
query="no_query",
|
||||
retrieval_type=RetrievalStrategy.SEMANTIC.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"fallback": "recency", "reason": "no_query"},
|
||||
)
|
||||
|
||||
# Generate embedding if needed
|
||||
embedding = query_embedding
|
||||
if embedding is None and query_text is not None:
|
||||
if self._embedding_generator is not None:
|
||||
embedding = await self._embedding_generator.generate(query_text)
|
||||
else:
|
||||
logger.warning("No embedding generator, falling back to recency")
|
||||
recency = RecencyRetriever()
|
||||
fallback_result = await recency.retrieve(
|
||||
session, project_id, limit, agent_instance_id=agent_instance_id
|
||||
)
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
return RetrievalResult(
|
||||
items=fallback_result.items,
|
||||
total_count=fallback_result.total_count,
|
||||
query=query_text,
|
||||
retrieval_type=RetrievalStrategy.SEMANTIC.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={
|
||||
"fallback": "recency",
|
||||
"reason": "no_embedding_generator",
|
||||
},
|
||||
)
|
||||
|
||||
# For now, use recency if vector search not available
|
||||
# TODO: Implement proper pgvector similarity search when integrated
|
||||
logger.debug("Vector search not yet implemented, using recency fallback")
|
||||
recency = RecencyRetriever()
|
||||
result = await recency.retrieve(
|
||||
session, project_id, limit, agent_instance_id=agent_instance_id
|
||||
)
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=result.items,
|
||||
total_count=result.total_count,
|
||||
query=query_text or "embedding",
|
||||
retrieval_type=RetrievalStrategy.SEMANTIC.value,
|
||||
latency_ms=latency_ms,
|
||||
metadata={"fallback": "recency"},
|
||||
)
|
||||
|
||||
|
||||
class EpisodeRetriever:
|
||||
"""
|
||||
Unified episode retrieval service.
|
||||
|
||||
Provides a single interface for all retrieval strategies.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""Initialize retriever with database session."""
|
||||
self._session = session
|
||||
self._retrievers: dict[RetrievalStrategy, BaseRetriever] = {
|
||||
RetrievalStrategy.RECENCY: RecencyRetriever(),
|
||||
RetrievalStrategy.OUTCOME: OutcomeRetriever(),
|
||||
RetrievalStrategy.IMPORTANCE: ImportanceRetriever(),
|
||||
RetrievalStrategy.SEMANTIC: SemanticRetriever(embedding_generator),
|
||||
}
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
project_id: UUID,
|
||||
strategy: RetrievalStrategy = RetrievalStrategy.RECENCY,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""
|
||||
Retrieve episodes using the specified strategy.
|
||||
|
||||
Args:
|
||||
project_id: Project to search within
|
||||
strategy: Retrieval strategy to use
|
||||
limit: Maximum number of episodes to return
|
||||
**kwargs: Strategy-specific parameters
|
||||
|
||||
Returns:
|
||||
RetrievalResult containing matching episodes
|
||||
"""
|
||||
retriever = self._retrievers.get(strategy)
|
||||
if retriever is None:
|
||||
raise ValueError(f"Unknown retrieval strategy: {strategy}")
|
||||
|
||||
return await retriever.retrieve(self._session, project_id, limit, **kwargs)
|
||||
|
||||
async def get_recent(
|
||||
self,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
since: datetime | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Get recent episodes."""
|
||||
return await self.retrieve(
|
||||
project_id,
|
||||
RetrievalStrategy.RECENCY,
|
||||
limit,
|
||||
since=since,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
async def get_by_outcome(
|
||||
self,
|
||||
project_id: UUID,
|
||||
outcome: Outcome,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Get episodes by outcome."""
|
||||
return await self.retrieve(
|
||||
project_id,
|
||||
RetrievalStrategy.OUTCOME,
|
||||
limit,
|
||||
outcome=outcome,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
async def get_by_task_type(
|
||||
self,
|
||||
project_id: UUID,
|
||||
task_type: str,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Get episodes by task type."""
|
||||
retriever = TaskTypeRetriever()
|
||||
return await retriever.retrieve(
|
||||
self._session,
|
||||
project_id,
|
||||
limit,
|
||||
task_type=task_type,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
async def get_important(
|
||||
self,
|
||||
project_id: UUID,
|
||||
limit: int = 10,
|
||||
min_importance: float = 0.7,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Get high-importance episodes."""
|
||||
return await self.retrieve(
|
||||
project_id,
|
||||
RetrievalStrategy.IMPORTANCE,
|
||||
limit,
|
||||
min_importance=min_importance,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
async def search_similar(
|
||||
self,
|
||||
project_id: UUID,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""Search for semantically similar episodes."""
|
||||
return await self.retrieve(
|
||||
project_id,
|
||||
RetrievalStrategy.SEMANTIC,
|
||||
limit,
|
||||
query_text=query,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
222
backend/app/services/memory/exceptions.py
Normal file
222
backend/app/services/memory/exceptions.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
Memory System Exceptions
|
||||
|
||||
Custom exception classes for the Agent Memory System.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class MemoryError(Exception):
|
||||
"""Base exception for all memory-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
memory_type: str | None = None,
|
||||
scope_type: str | None = None,
|
||||
scope_id: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.memory_type = memory_type
|
||||
self.scope_type = scope_type
|
||||
self.scope_id = scope_id
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class MemoryNotFoundError(MemoryError):
|
||||
"""Raised when a memory item is not found."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory not found",
|
||||
*,
|
||||
memory_id: UUID | str | None = None,
|
||||
key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.memory_id = memory_id
|
||||
self.key = key
|
||||
|
||||
|
||||
class MemoryCapacityError(MemoryError):
|
||||
"""Raised when memory capacity limits are exceeded."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory capacity exceeded",
|
||||
*,
|
||||
current_size: int = 0,
|
||||
max_size: int = 0,
|
||||
item_count: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.current_size = current_size
|
||||
self.max_size = max_size
|
||||
self.item_count = item_count
|
||||
|
||||
|
||||
class MemoryExpiredError(MemoryError):
|
||||
"""Raised when attempting to access expired memory."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory has expired",
|
||||
*,
|
||||
key: str | None = None,
|
||||
expired_at: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.key = key
|
||||
self.expired_at = expired_at
|
||||
|
||||
|
||||
class MemoryStorageError(MemoryError):
|
||||
"""Raised when memory storage operations fail."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory storage operation failed",
|
||||
*,
|
||||
operation: str | None = None,
|
||||
backend: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.operation = operation
|
||||
self.backend = backend
|
||||
|
||||
|
||||
class MemoryConnectionError(MemoryError):
|
||||
"""Raised when memory storage connection fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory connection failed",
|
||||
*,
|
||||
backend: str | None = None,
|
||||
host: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.backend = backend
|
||||
self.host = host
|
||||
|
||||
|
||||
class MemorySerializationError(MemoryError):
|
||||
"""Raised when memory serialization/deserialization fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory serialization failed",
|
||||
*,
|
||||
content_type: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.content_type = content_type
|
||||
|
||||
|
||||
class MemoryScopeError(MemoryError):
|
||||
"""Raised when memory scope operations fail."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory scope error",
|
||||
*,
|
||||
requested_scope: str | None = None,
|
||||
allowed_scopes: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.requested_scope = requested_scope
|
||||
self.allowed_scopes = allowed_scopes or []
|
||||
|
||||
|
||||
class MemoryConsolidationError(MemoryError):
|
||||
"""Raised when memory consolidation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory consolidation failed",
|
||||
*,
|
||||
source_type: str | None = None,
|
||||
target_type: str | None = None,
|
||||
items_processed: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.source_type = source_type
|
||||
self.target_type = target_type
|
||||
self.items_processed = items_processed
|
||||
|
||||
|
||||
class MemoryRetrievalError(MemoryError):
|
||||
"""Raised when memory retrieval fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory retrieval failed",
|
||||
*,
|
||||
query: str | None = None,
|
||||
retrieval_type: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.query = query
|
||||
self.retrieval_type = retrieval_type
|
||||
|
||||
|
||||
class EmbeddingError(MemoryError):
|
||||
"""Raised when embedding generation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Embedding generation failed",
|
||||
*,
|
||||
content_length: int = 0,
|
||||
model: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.content_length = content_length
|
||||
self.model = model
|
||||
|
||||
|
||||
class CheckpointError(MemoryError):
|
||||
"""Raised when checkpoint operations fail."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Checkpoint operation failed",
|
||||
*,
|
||||
checkpoint_id: str | None = None,
|
||||
operation: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.checkpoint_id = checkpoint_id
|
||||
self.operation = operation
|
||||
|
||||
|
||||
class MemoryConflictError(MemoryError):
|
||||
"""Raised when there's a conflict in memory (e.g., contradictory facts)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Memory conflict detected",
|
||||
*,
|
||||
conflicting_ids: list[str | UUID] | None = None,
|
||||
conflict_type: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(message, **kwargs)
|
||||
self.conflicting_ids = conflicting_ids or []
|
||||
self.conflict_type = conflict_type
|
||||
56
backend/app/services/memory/indexing/__init__.py
Normal file
56
backend/app/services/memory/indexing/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# app/services/memory/indexing/__init__.py
|
||||
"""
|
||||
Memory Indexing & Retrieval.
|
||||
|
||||
Provides vector embeddings and multiple index types for efficient memory search:
|
||||
- Vector index for semantic similarity
|
||||
- Temporal index for time-based queries
|
||||
- Entity index for entity lookups
|
||||
- Outcome index for success/failure filtering
|
||||
"""
|
||||
|
||||
from .index import (
|
||||
EntityIndex,
|
||||
EntityIndexEntry,
|
||||
IndexEntry,
|
||||
MemoryIndex,
|
||||
MemoryIndexer,
|
||||
OutcomeIndex,
|
||||
OutcomeIndexEntry,
|
||||
TemporalIndex,
|
||||
TemporalIndexEntry,
|
||||
VectorIndex,
|
||||
VectorIndexEntry,
|
||||
get_memory_indexer,
|
||||
)
|
||||
from .retrieval import (
|
||||
CacheEntry,
|
||||
RelevanceScorer,
|
||||
RetrievalCache,
|
||||
RetrievalEngine,
|
||||
RetrievalQuery,
|
||||
ScoredResult,
|
||||
get_retrieval_engine,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CacheEntry",
|
||||
"EntityIndex",
|
||||
"EntityIndexEntry",
|
||||
"IndexEntry",
|
||||
"MemoryIndex",
|
||||
"MemoryIndexer",
|
||||
"OutcomeIndex",
|
||||
"OutcomeIndexEntry",
|
||||
"RelevanceScorer",
|
||||
"RetrievalCache",
|
||||
"RetrievalEngine",
|
||||
"RetrievalQuery",
|
||||
"ScoredResult",
|
||||
"TemporalIndex",
|
||||
"TemporalIndexEntry",
|
||||
"VectorIndex",
|
||||
"VectorIndexEntry",
|
||||
"get_memory_indexer",
|
||||
"get_retrieval_engine",
|
||||
]
|
||||
858
backend/app/services/memory/indexing/index.py
Normal file
858
backend/app/services/memory/indexing/index.py
Normal file
@@ -0,0 +1,858 @@
|
||||
# app/services/memory/indexing/index.py
|
||||
"""
|
||||
Memory Indexing.
|
||||
|
||||
Provides multiple indexing strategies for efficient memory retrieval:
|
||||
- Vector embeddings for semantic search
|
||||
- Temporal index for time-based queries
|
||||
- Entity index for entity-based lookups
|
||||
- Outcome index for success/failure filtering
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", Episode, Fact, Procedure)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexEntry:
|
||||
"""A single entry in an index."""
|
||||
|
||||
memory_id: UUID
|
||||
memory_type: MemoryType
|
||||
indexed_at: datetime = field(default_factory=_utcnow)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorIndexEntry(IndexEntry):
|
||||
"""An entry with vector embedding."""
|
||||
|
||||
embedding: list[float] = field(default_factory=list)
|
||||
dimension: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set dimension from embedding."""
|
||||
if self.embedding:
|
||||
self.dimension = len(self.embedding)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TemporalIndexEntry(IndexEntry):
|
||||
"""An entry indexed by time."""
|
||||
|
||||
timestamp: datetime = field(default_factory=_utcnow)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityIndexEntry(IndexEntry):
|
||||
"""An entry indexed by entity."""
|
||||
|
||||
entity_type: str = ""
|
||||
entity_value: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutcomeIndexEntry(IndexEntry):
|
||||
"""An entry indexed by outcome."""
|
||||
|
||||
outcome: Outcome = Outcome.SUCCESS
|
||||
|
||||
|
||||
class MemoryIndex[T](ABC):
|
||||
"""Abstract base class for memory indices."""
|
||||
|
||||
@abstractmethod
|
||||
async def add(self, item: T) -> IndexEntry:
|
||||
"""Add an item to the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> list[IndexEntry]:
|
||||
"""Search the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
...
|
||||
|
||||
|
||||
class VectorIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Vector-based index using embeddings for semantic similarity search.
|
||||
|
||||
Uses cosine similarity for matching.
|
||||
"""
|
||||
|
||||
def __init__(self, dimension: int = 1536) -> None:
|
||||
"""
|
||||
Initialize the vector index.
|
||||
|
||||
Args:
|
||||
dimension: Embedding dimension (default 1536 for OpenAI)
|
||||
"""
|
||||
self._dimension = dimension
|
||||
self._entries: dict[UUID, VectorIndexEntry] = {}
|
||||
logger.info(f"Initialized VectorIndex with dimension={dimension}")
|
||||
|
||||
async def add(self, item: T) -> VectorIndexEntry:
|
||||
"""
|
||||
Add an item to the vector index.
|
||||
|
||||
Args:
|
||||
item: Memory item with embedding
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
embedding = getattr(item, "embedding", None) or []
|
||||
|
||||
entry = VectorIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
embedding=embedding,
|
||||
dimension=len(embedding),
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
logger.debug(f"Added {item.id} to vector index")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the vector index."""
|
||||
if memory_id in self._entries:
|
||||
del self._entries[memory_id]
|
||||
logger.debug(f"Removed {memory_id} from vector index")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
min_similarity: float = 0.0,
|
||||
**kwargs: Any,
|
||||
) -> list[VectorIndexEntry]:
|
||||
"""
|
||||
Search for similar items using vector similarity.
|
||||
|
||||
Args:
|
||||
query: Query embedding vector
|
||||
limit: Maximum results to return
|
||||
min_similarity: Minimum similarity threshold (0-1)
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries sorted by similarity
|
||||
"""
|
||||
if not isinstance(query, list) or not query:
|
||||
return []
|
||||
|
||||
results: list[tuple[float, VectorIndexEntry]] = []
|
||||
|
||||
for entry in self._entries.values():
|
||||
if not entry.embedding:
|
||||
continue
|
||||
|
||||
similarity = self._cosine_similarity(query, entry.embedding)
|
||||
if similarity >= min_similarity:
|
||||
results.append((similarity, entry))
|
||||
|
||||
# Sort by similarity descending
|
||||
results.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
if memory_type:
|
||||
results = [(s, e) for s, e in results if e.memory_type == memory_type]
|
||||
|
||||
# Store similarity in metadata for the returned entries
|
||||
# Use a copy of metadata to avoid mutating cached entries
|
||||
output = []
|
||||
for similarity, entry in results[:limit]:
|
||||
# Create a shallow copy of the entry with updated metadata
|
||||
entry_with_score = VectorIndexEntry(
|
||||
memory_id=entry.memory_id,
|
||||
memory_type=entry.memory_type,
|
||||
embedding=entry.embedding,
|
||||
metadata={**entry.metadata, "similarity": similarity},
|
||||
)
|
||||
output.append(entry_with_score)
|
||||
|
||||
logger.debug(f"Vector search returned {len(output)} results")
|
||||
return output
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
logger.info(f"Cleared {count} entries from vector index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
|
||||
"""Calculate cosine similarity between two vectors."""
|
||||
if len(a) != len(b) or len(a) == 0:
|
||||
return 0.0
|
||||
|
||||
dot_product = sum(x * y for x, y in zip(a, b, strict=True))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm_a * norm_b)
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
class TemporalIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Time-based index for efficient temporal queries.
|
||||
|
||||
Supports:
|
||||
- Range queries (between timestamps)
|
||||
- Recent items (within last N seconds/hours/days)
|
||||
- Oldest/newest sorting
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the temporal index."""
|
||||
self._entries: dict[UUID, TemporalIndexEntry] = {}
|
||||
# Sorted list for efficient range queries
|
||||
self._sorted_entries: list[tuple[datetime, UUID]] = []
|
||||
logger.info("Initialized TemporalIndex")
|
||||
|
||||
async def add(self, item: T) -> TemporalIndexEntry:
|
||||
"""
|
||||
Add an item to the temporal index.
|
||||
|
||||
Args:
|
||||
item: Memory item with timestamp
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
# Get timestamp from various possible fields
|
||||
timestamp = self._get_timestamp(item)
|
||||
|
||||
entry = TemporalIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
self._insert_sorted(timestamp, item.id)
|
||||
|
||||
logger.debug(f"Added {item.id} to temporal index at {timestamp}")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the temporal index."""
|
||||
if memory_id not in self._entries:
|
||||
return False
|
||||
|
||||
self._entries.pop(memory_id)
|
||||
self._sorted_entries = [
|
||||
(ts, mid) for ts, mid in self._sorted_entries if mid != memory_id
|
||||
]
|
||||
|
||||
logger.debug(f"Removed {memory_id} from temporal index")
|
||||
return True
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
recent_seconds: float | None = None,
|
||||
order: str = "desc",
|
||||
**kwargs: Any,
|
||||
) -> list[TemporalIndexEntry]:
|
||||
"""
|
||||
Search for items by time.
|
||||
|
||||
Args:
|
||||
query: Ignored for temporal search
|
||||
limit: Maximum results to return
|
||||
start_time: Start of time range
|
||||
end_time: End of time range
|
||||
recent_seconds: Get items from last N seconds
|
||||
order: Sort order ("asc" or "desc")
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries sorted by time
|
||||
"""
|
||||
if recent_seconds is not None:
|
||||
start_time = _utcnow() - timedelta(seconds=recent_seconds)
|
||||
end_time = _utcnow()
|
||||
|
||||
# Filter by time range
|
||||
results: list[TemporalIndexEntry] = []
|
||||
for entry in self._entries.values():
|
||||
if start_time and entry.timestamp < start_time:
|
||||
continue
|
||||
if end_time and entry.timestamp > end_time:
|
||||
continue
|
||||
results.append(entry)
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
if memory_type:
|
||||
results = [e for e in results if e.memory_type == memory_type]
|
||||
|
||||
# Sort by timestamp
|
||||
results.sort(key=lambda e: e.timestamp, reverse=(order == "desc"))
|
||||
|
||||
logger.debug(f"Temporal search returned {min(len(results), limit)} results")
|
||||
return results[:limit]
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
self._sorted_entries.clear()
|
||||
logger.info(f"Cleared {count} entries from temporal index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
def _insert_sorted(self, timestamp: datetime, memory_id: UUID) -> None:
|
||||
"""Insert entry maintaining sorted order."""
|
||||
# Binary search insert for efficiency
|
||||
low, high = 0, len(self._sorted_entries)
|
||||
while low < high:
|
||||
mid = (low + high) // 2
|
||||
if self._sorted_entries[mid][0] < timestamp:
|
||||
low = mid + 1
|
||||
else:
|
||||
high = mid
|
||||
self._sorted_entries.insert(low, (timestamp, memory_id))
|
||||
|
||||
def _get_timestamp(self, item: T) -> datetime:
|
||||
"""Get the relevant timestamp for an item."""
|
||||
if hasattr(item, "occurred_at"):
|
||||
return item.occurred_at
|
||||
if hasattr(item, "first_learned"):
|
||||
return item.first_learned
|
||||
if hasattr(item, "last_used") and item.last_used:
|
||||
return item.last_used
|
||||
if hasattr(item, "created_at"):
|
||||
return item.created_at
|
||||
return _utcnow()
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
class EntityIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Entity-based index for lookups by entities mentioned in memories.
|
||||
|
||||
Supports:
|
||||
- Single entity lookup
|
||||
- Multi-entity intersection
|
||||
- Entity type filtering
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the entity index."""
|
||||
# Main storage
|
||||
self._entries: dict[UUID, EntityIndexEntry] = {}
|
||||
# Inverted index: entity -> set of memory IDs
|
||||
self._entity_to_memories: dict[str, set[UUID]] = {}
|
||||
# Memory to entities mapping
|
||||
self._memory_to_entities: dict[UUID, set[str]] = {}
|
||||
logger.info("Initialized EntityIndex")
|
||||
|
||||
async def add(self, item: T) -> EntityIndexEntry:
|
||||
"""
|
||||
Add an item to the entity index.
|
||||
|
||||
Args:
|
||||
item: Memory item with entity information
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
entities = self._extract_entities(item)
|
||||
|
||||
# Create entry for the primary entity (or first one)
|
||||
primary_entity = entities[0] if entities else ("unknown", "unknown")
|
||||
|
||||
entry = EntityIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
entity_type=primary_entity[0],
|
||||
entity_value=primary_entity[1],
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
|
||||
# Update inverted indices
|
||||
entity_keys = {f"{etype}:{evalue}" for etype, evalue in entities}
|
||||
self._memory_to_entities[item.id] = entity_keys
|
||||
|
||||
for entity_key in entity_keys:
|
||||
if entity_key not in self._entity_to_memories:
|
||||
self._entity_to_memories[entity_key] = set()
|
||||
self._entity_to_memories[entity_key].add(item.id)
|
||||
|
||||
logger.debug(f"Added {item.id} to entity index with {len(entities)} entities")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the entity index."""
|
||||
if memory_id not in self._entries:
|
||||
return False
|
||||
|
||||
# Remove from inverted index
|
||||
if memory_id in self._memory_to_entities:
|
||||
for entity_key in self._memory_to_entities[memory_id]:
|
||||
if entity_key in self._entity_to_memories:
|
||||
self._entity_to_memories[entity_key].discard(memory_id)
|
||||
if not self._entity_to_memories[entity_key]:
|
||||
del self._entity_to_memories[entity_key]
|
||||
del self._memory_to_entities[memory_id]
|
||||
|
||||
del self._entries[memory_id]
|
||||
logger.debug(f"Removed {memory_id} from entity index")
|
||||
return True
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
entity_type: str | None = None,
|
||||
entity_value: str | None = None,
|
||||
entities: list[tuple[str, str]] | None = None,
|
||||
match_all: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> list[EntityIndexEntry]:
|
||||
"""
|
||||
Search for items by entity.
|
||||
|
||||
Args:
|
||||
query: Entity value to search (if entity_type not specified)
|
||||
limit: Maximum results to return
|
||||
entity_type: Type of entity to filter
|
||||
entity_value: Specific entity value
|
||||
entities: List of (type, value) tuples to match
|
||||
match_all: If True, require all entities to match
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries
|
||||
"""
|
||||
matching_ids: set[UUID] | None = None
|
||||
|
||||
# Handle single entity query
|
||||
if entity_type and entity_value:
|
||||
entities = [(entity_type, entity_value)]
|
||||
elif entity_value is None and isinstance(query, str):
|
||||
# Search across all entity types
|
||||
entity_value = query
|
||||
|
||||
if entities:
|
||||
for etype, evalue in entities:
|
||||
entity_key = f"{etype}:{evalue}"
|
||||
if entity_key in self._entity_to_memories:
|
||||
ids = self._entity_to_memories[entity_key]
|
||||
if matching_ids is None:
|
||||
matching_ids = ids.copy()
|
||||
elif match_all:
|
||||
matching_ids &= ids
|
||||
else:
|
||||
matching_ids |= ids
|
||||
elif match_all:
|
||||
# Required entity not found
|
||||
matching_ids = set()
|
||||
break
|
||||
elif entity_value:
|
||||
# Search for value across all types
|
||||
matching_ids = set()
|
||||
for entity_key, ids in self._entity_to_memories.items():
|
||||
if entity_value.lower() in entity_key.lower():
|
||||
matching_ids |= ids
|
||||
|
||||
if matching_ids is None:
|
||||
matching_ids = set(self._entries.keys())
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
results = []
|
||||
for mid in matching_ids:
|
||||
if mid in self._entries:
|
||||
entry = self._entries[mid]
|
||||
if memory_type and entry.memory_type != memory_type:
|
||||
continue
|
||||
results.append(entry)
|
||||
|
||||
logger.debug(f"Entity search returned {min(len(results), limit)} results")
|
||||
return results[:limit]
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
self._entity_to_memories.clear()
|
||||
self._memory_to_entities.clear()
|
||||
logger.info(f"Cleared {count} entries from entity index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
async def get_entities(self, memory_id: UUID) -> list[tuple[str, str]]:
|
||||
"""Get all entities for a memory item."""
|
||||
if memory_id not in self._memory_to_entities:
|
||||
return []
|
||||
|
||||
entities = []
|
||||
for entity_key in self._memory_to_entities[memory_id]:
|
||||
if ":" in entity_key:
|
||||
etype, evalue = entity_key.split(":", 1)
|
||||
entities.append((etype, evalue))
|
||||
return entities
|
||||
|
||||
def _extract_entities(self, item: T) -> list[tuple[str, str]]:
|
||||
"""Extract entities from a memory item."""
|
||||
entities: list[tuple[str, str]] = []
|
||||
|
||||
if isinstance(item, Episode):
|
||||
# Extract from task type and context
|
||||
entities.append(("task_type", item.task_type))
|
||||
if item.project_id:
|
||||
entities.append(("project", str(item.project_id)))
|
||||
if item.agent_instance_id:
|
||||
entities.append(("agent_instance", str(item.agent_instance_id)))
|
||||
if item.agent_type_id:
|
||||
entities.append(("agent_type", str(item.agent_type_id)))
|
||||
|
||||
elif isinstance(item, Fact):
|
||||
# Subject and object are entities
|
||||
entities.append(("subject", item.subject))
|
||||
entities.append(("object", item.object))
|
||||
if item.project_id:
|
||||
entities.append(("project", str(item.project_id)))
|
||||
|
||||
elif isinstance(item, Procedure):
|
||||
entities.append(("procedure", item.name))
|
||||
if item.project_id:
|
||||
entities.append(("project", str(item.project_id)))
|
||||
if item.agent_type_id:
|
||||
entities.append(("agent_type", str(item.agent_type_id)))
|
||||
|
||||
return entities
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
class OutcomeIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Outcome-based index for filtering by success/failure.
|
||||
|
||||
Primarily used for episodes and procedures.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the outcome index."""
|
||||
self._entries: dict[UUID, OutcomeIndexEntry] = {}
|
||||
# Inverted index by outcome
|
||||
self._outcome_to_memories: dict[Outcome, set[UUID]] = {
|
||||
Outcome.SUCCESS: set(),
|
||||
Outcome.FAILURE: set(),
|
||||
Outcome.PARTIAL: set(),
|
||||
}
|
||||
logger.info("Initialized OutcomeIndex")
|
||||
|
||||
async def add(self, item: T) -> OutcomeIndexEntry:
|
||||
"""
|
||||
Add an item to the outcome index.
|
||||
|
||||
Args:
|
||||
item: Memory item with outcome information
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
outcome = self._get_outcome(item)
|
||||
|
||||
entry = OutcomeIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
outcome=outcome,
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
self._outcome_to_memories[outcome].add(item.id)
|
||||
|
||||
logger.debug(f"Added {item.id} to outcome index with {outcome.value}")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the outcome index."""
|
||||
if memory_id not in self._entries:
|
||||
return False
|
||||
|
||||
entry = self._entries.pop(memory_id)
|
||||
self._outcome_to_memories[entry.outcome].discard(memory_id)
|
||||
|
||||
logger.debug(f"Removed {memory_id} from outcome index")
|
||||
return True
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
outcome: Outcome | None = None,
|
||||
outcomes: list[Outcome] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[OutcomeIndexEntry]:
|
||||
"""
|
||||
Search for items by outcome.
|
||||
|
||||
Args:
|
||||
query: Ignored for outcome search
|
||||
limit: Maximum results to return
|
||||
outcome: Single outcome to filter
|
||||
outcomes: Multiple outcomes to filter (OR)
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries
|
||||
"""
|
||||
if outcome:
|
||||
outcomes = [outcome]
|
||||
|
||||
if outcomes:
|
||||
matching_ids: set[UUID] = set()
|
||||
for o in outcomes:
|
||||
matching_ids |= self._outcome_to_memories.get(o, set())
|
||||
else:
|
||||
matching_ids = set(self._entries.keys())
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
results = []
|
||||
for mid in matching_ids:
|
||||
if mid in self._entries:
|
||||
entry = self._entries[mid]
|
||||
if memory_type and entry.memory_type != memory_type:
|
||||
continue
|
||||
results.append(entry)
|
||||
|
||||
logger.debug(f"Outcome search returned {min(len(results), limit)} results")
|
||||
return results[:limit]
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
for outcome in self._outcome_to_memories:
|
||||
self._outcome_to_memories[outcome].clear()
|
||||
logger.info(f"Cleared {count} entries from outcome index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
async def get_outcome_stats(self) -> dict[Outcome, int]:
|
||||
"""Get statistics on outcomes."""
|
||||
return {outcome: len(ids) for outcome, ids in self._outcome_to_memories.items()}
|
||||
|
||||
def _get_outcome(self, item: T) -> Outcome:
|
||||
"""Get the outcome for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return item.outcome
|
||||
elif isinstance(item, Procedure):
|
||||
# Derive from success rate
|
||||
if item.success_rate >= 0.8:
|
||||
return Outcome.SUCCESS
|
||||
elif item.success_rate <= 0.2:
|
||||
return Outcome.FAILURE
|
||||
return Outcome.PARTIAL
|
||||
return Outcome.SUCCESS
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryIndexer:
|
||||
"""
|
||||
Unified indexer that manages all index types.
|
||||
|
||||
Provides a single interface for indexing and searching across
|
||||
multiple index types.
|
||||
"""
|
||||
|
||||
vector_index: VectorIndex[Any] = field(default_factory=VectorIndex)
|
||||
temporal_index: TemporalIndex[Any] = field(default_factory=TemporalIndex)
|
||||
entity_index: EntityIndex[Any] = field(default_factory=EntityIndex)
|
||||
outcome_index: OutcomeIndex[Any] = field(default_factory=OutcomeIndex)
|
||||
|
||||
async def index(self, item: Episode | Fact | Procedure) -> dict[str, IndexEntry]:
|
||||
"""
|
||||
Index an item across all applicable indices.
|
||||
|
||||
Args:
|
||||
item: Memory item to index
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to entry
|
||||
"""
|
||||
results: dict[str, IndexEntry] = {}
|
||||
|
||||
# Vector index (if embedding present)
|
||||
if getattr(item, "embedding", None):
|
||||
results["vector"] = await self.vector_index.add(item)
|
||||
|
||||
# Temporal index
|
||||
results["temporal"] = await self.temporal_index.add(item)
|
||||
|
||||
# Entity index
|
||||
results["entity"] = await self.entity_index.add(item)
|
||||
|
||||
# Outcome index (for episodes and procedures)
|
||||
if isinstance(item, (Episode, Procedure)):
|
||||
results["outcome"] = await self.outcome_index.add(item)
|
||||
|
||||
logger.info(
|
||||
f"Indexed {item.id} across {len(results)} indices: {list(results.keys())}"
|
||||
)
|
||||
return results
|
||||
|
||||
async def remove(self, memory_id: UUID) -> dict[str, bool]:
|
||||
"""
|
||||
Remove an item from all indices.
|
||||
|
||||
Args:
|
||||
memory_id: ID of the memory to remove
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to removal success
|
||||
"""
|
||||
results = {
|
||||
"vector": await self.vector_index.remove(memory_id),
|
||||
"temporal": await self.temporal_index.remove(memory_id),
|
||||
"entity": await self.entity_index.remove(memory_id),
|
||||
"outcome": await self.outcome_index.remove(memory_id),
|
||||
}
|
||||
|
||||
removed_from = [k for k, v in results.items() if v]
|
||||
if removed_from:
|
||||
logger.info(f"Removed {memory_id} from indices: {removed_from}")
|
||||
|
||||
return results
|
||||
|
||||
async def clear_all(self) -> dict[str, int]:
|
||||
"""
|
||||
Clear all indices.
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to count cleared
|
||||
"""
|
||||
return {
|
||||
"vector": await self.vector_index.clear(),
|
||||
"temporal": await self.temporal_index.clear(),
|
||||
"entity": await self.entity_index.clear(),
|
||||
"outcome": await self.outcome_index.clear(),
|
||||
}
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get statistics for all indices.
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to entry count
|
||||
"""
|
||||
return {
|
||||
"vector": await self.vector_index.count(),
|
||||
"temporal": await self.temporal_index.count(),
|
||||
"entity": await self.entity_index.count(),
|
||||
"outcome": await self.outcome_index.count(),
|
||||
}
|
||||
|
||||
|
||||
# Singleton indexer instance
|
||||
_indexer: MemoryIndexer | None = None
|
||||
|
||||
|
||||
def get_memory_indexer() -> MemoryIndexer:
|
||||
"""Get the singleton memory indexer instance."""
|
||||
global _indexer
|
||||
if _indexer is None:
|
||||
_indexer = MemoryIndexer()
|
||||
return _indexer
|
||||
742
backend/app/services/memory/indexing/retrieval.py
Normal file
742
backend/app/services/memory/indexing/retrieval.py
Normal file
@@ -0,0 +1,742 @@
|
||||
# app/services/memory/indexing/retrieval.py
|
||||
"""
|
||||
Memory Retrieval Engine.
|
||||
|
||||
Provides hybrid retrieval capabilities combining:
|
||||
- Vector similarity search
|
||||
- Temporal filtering
|
||||
- Entity filtering
|
||||
- Outcome filtering
|
||||
- Relevance scoring
|
||||
- Result caching
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.types import (
|
||||
Episode,
|
||||
Fact,
|
||||
MemoryType,
|
||||
Outcome,
|
||||
Procedure,
|
||||
RetrievalResult,
|
||||
)
|
||||
|
||||
from .index import (
|
||||
MemoryIndexer,
|
||||
get_memory_indexer,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", Episode, Fact, Procedure)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalQuery:
|
||||
"""Query parameters for memory retrieval."""
|
||||
|
||||
# Text/semantic query
|
||||
query_text: str | None = None
|
||||
query_embedding: list[float] | None = None
|
||||
|
||||
# Temporal filters
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
recent_seconds: float | None = None
|
||||
|
||||
# Entity filters
|
||||
entities: list[tuple[str, str]] | None = None
|
||||
entity_match_all: bool = False
|
||||
|
||||
# Outcome filters
|
||||
outcomes: list[Outcome] | None = None
|
||||
|
||||
# Memory type filter
|
||||
memory_types: list[MemoryType] | None = None
|
||||
|
||||
# Result options
|
||||
limit: int = 10
|
||||
min_relevance: float = 0.0
|
||||
|
||||
# Retrieval mode
|
||||
use_vector: bool = True
|
||||
use_temporal: bool = True
|
||||
use_entity: bool = True
|
||||
use_outcome: bool = True
|
||||
|
||||
def to_cache_key(self) -> str:
|
||||
"""Generate a cache key for this query."""
|
||||
key_parts = [
|
||||
self.query_text or "",
|
||||
str(self.start_time),
|
||||
str(self.end_time),
|
||||
str(self.recent_seconds),
|
||||
str(self.entities),
|
||||
str(self.outcomes),
|
||||
str(self.memory_types),
|
||||
str(self.limit),
|
||||
str(self.min_relevance),
|
||||
]
|
||||
key_string = "|".join(key_parts)
|
||||
return hashlib.sha256(key_string.encode()).hexdigest()[:32]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoredResult:
|
||||
"""A retrieval result with relevance score."""
|
||||
|
||||
memory_id: UUID
|
||||
memory_type: MemoryType
|
||||
relevance_score: float
|
||||
score_breakdown: dict[str, float] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""A cached retrieval result."""
|
||||
|
||||
results: list[ScoredResult]
|
||||
created_at: datetime
|
||||
ttl_seconds: float
|
||||
query_key: str
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this cache entry has expired."""
|
||||
age = (_utcnow() - self.created_at).total_seconds()
|
||||
return age > self.ttl_seconds
|
||||
|
||||
|
||||
class RelevanceScorer:
|
||||
"""
|
||||
Calculates relevance scores for retrieved memories.
|
||||
|
||||
Combines multiple signals:
|
||||
- Vector similarity (if available)
|
||||
- Temporal recency
|
||||
- Entity match count
|
||||
- Outcome preference
|
||||
- Importance/confidence
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_weight: float = 0.4,
|
||||
recency_weight: float = 0.2,
|
||||
entity_weight: float = 0.2,
|
||||
outcome_weight: float = 0.1,
|
||||
importance_weight: float = 0.1,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the relevance scorer.
|
||||
|
||||
Args:
|
||||
vector_weight: Weight for vector similarity (0-1)
|
||||
recency_weight: Weight for temporal recency (0-1)
|
||||
entity_weight: Weight for entity matches (0-1)
|
||||
outcome_weight: Weight for outcome preference (0-1)
|
||||
importance_weight: Weight for importance score (0-1)
|
||||
"""
|
||||
total = (
|
||||
vector_weight
|
||||
+ recency_weight
|
||||
+ entity_weight
|
||||
+ outcome_weight
|
||||
+ importance_weight
|
||||
)
|
||||
# Normalize weights
|
||||
self.vector_weight = vector_weight / total
|
||||
self.recency_weight = recency_weight / total
|
||||
self.entity_weight = entity_weight / total
|
||||
self.outcome_weight = outcome_weight / total
|
||||
self.importance_weight = importance_weight / total
|
||||
|
||||
def score(
|
||||
self,
|
||||
memory_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
vector_similarity: float | None = None,
|
||||
timestamp: datetime | None = None,
|
||||
entity_match_count: int = 0,
|
||||
entity_total: int = 1,
|
||||
outcome: Outcome | None = None,
|
||||
importance: float = 0.5,
|
||||
preferred_outcomes: list[Outcome] | None = None,
|
||||
) -> ScoredResult:
|
||||
"""
|
||||
Calculate a relevance score for a memory.
|
||||
|
||||
Args:
|
||||
memory_id: ID of the memory
|
||||
memory_type: Type of memory
|
||||
vector_similarity: Similarity score from vector search (0-1)
|
||||
timestamp: Timestamp of the memory
|
||||
entity_match_count: Number of matching entities
|
||||
entity_total: Total entities in query
|
||||
outcome: Outcome of the memory
|
||||
importance: Importance score of the memory (0-1)
|
||||
preferred_outcomes: Outcomes to prefer
|
||||
|
||||
Returns:
|
||||
Scored result with breakdown
|
||||
"""
|
||||
breakdown: dict[str, float] = {}
|
||||
|
||||
# Vector similarity score
|
||||
if vector_similarity is not None:
|
||||
breakdown["vector"] = vector_similarity
|
||||
else:
|
||||
breakdown["vector"] = 0.5 # Neutral if no vector
|
||||
|
||||
# Recency score (exponential decay)
|
||||
if timestamp:
|
||||
age_hours = (_utcnow() - timestamp).total_seconds() / 3600
|
||||
# Decay with half-life of 24 hours
|
||||
breakdown["recency"] = 2 ** (-age_hours / 24)
|
||||
else:
|
||||
breakdown["recency"] = 0.5
|
||||
|
||||
# Entity match score
|
||||
if entity_total > 0:
|
||||
breakdown["entity"] = entity_match_count / entity_total
|
||||
else:
|
||||
breakdown["entity"] = 1.0 # No entity filter = full score
|
||||
|
||||
# Outcome score
|
||||
if preferred_outcomes and outcome:
|
||||
breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0
|
||||
else:
|
||||
breakdown["outcome"] = 0.5 # Neutral if no preference
|
||||
|
||||
# Importance score
|
||||
breakdown["importance"] = importance
|
||||
|
||||
# Calculate weighted sum
|
||||
total_score = (
|
||||
breakdown["vector"] * self.vector_weight
|
||||
+ breakdown["recency"] * self.recency_weight
|
||||
+ breakdown["entity"] * self.entity_weight
|
||||
+ breakdown["outcome"] * self.outcome_weight
|
||||
+ breakdown["importance"] * self.importance_weight
|
||||
)
|
||||
|
||||
return ScoredResult(
|
||||
memory_id=memory_id,
|
||||
memory_type=memory_type,
|
||||
relevance_score=total_score,
|
||||
score_breakdown=breakdown,
|
||||
)
|
||||
|
||||
|
||||
class RetrievalCache:
|
||||
"""
|
||||
In-memory cache for retrieval results.
|
||||
|
||||
Supports TTL-based expiration and LRU eviction with O(1) operations.
|
||||
Uses OrderedDict for efficient LRU tracking.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_entries: int = 1000,
|
||||
default_ttl_seconds: float = 300,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the cache.
|
||||
|
||||
Args:
|
||||
max_entries: Maximum cache entries
|
||||
default_ttl_seconds: Default TTL for entries
|
||||
"""
|
||||
# OrderedDict maintains insertion order; we use move_to_end for O(1) LRU
|
||||
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
|
||||
self._max_entries = max_entries
|
||||
self._default_ttl = default_ttl_seconds
|
||||
logger.info(
|
||||
f"Initialized RetrievalCache with max_entries={max_entries}, "
|
||||
f"ttl={default_ttl_seconds}s"
|
||||
)
|
||||
|
||||
def get(self, query_key: str) -> list[ScoredResult] | None:
|
||||
"""
|
||||
Get cached results for a query.
|
||||
|
||||
Args:
|
||||
query_key: Cache key for the query
|
||||
|
||||
Returns:
|
||||
Cached results or None if not found/expired
|
||||
"""
|
||||
if query_key not in self._cache:
|
||||
return None
|
||||
|
||||
entry = self._cache[query_key]
|
||||
if entry.is_expired():
|
||||
del self._cache[query_key]
|
||||
return None
|
||||
|
||||
# Update access order (LRU) - O(1) with OrderedDict
|
||||
self._cache.move_to_end(query_key)
|
||||
|
||||
logger.debug(f"Cache hit for {query_key}")
|
||||
return entry.results
|
||||
|
||||
def put(
|
||||
self,
|
||||
query_key: str,
|
||||
results: list[ScoredResult],
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache results for a query.
|
||||
|
||||
Args:
|
||||
query_key: Cache key for the query
|
||||
results: Results to cache
|
||||
ttl_seconds: TTL for this entry (or default)
|
||||
"""
|
||||
# Evict oldest entries if at capacity - O(1) with popitem(last=False)
|
||||
while len(self._cache) >= self._max_entries:
|
||||
self._cache.popitem(last=False)
|
||||
|
||||
entry = CacheEntry(
|
||||
results=results,
|
||||
created_at=_utcnow(),
|
||||
ttl_seconds=ttl_seconds or self._default_ttl,
|
||||
query_key=query_key,
|
||||
)
|
||||
|
||||
self._cache[query_key] = entry
|
||||
logger.debug(f"Cached {len(results)} results for {query_key}")
|
||||
|
||||
def invalidate(self, query_key: str) -> bool:
|
||||
"""
|
||||
Invalidate a specific cache entry.
|
||||
|
||||
Args:
|
||||
query_key: Cache key to invalidate
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
if query_key in self._cache:
|
||||
del self._cache[query_key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def invalidate_by_memory(self, memory_id: UUID) -> int:
|
||||
"""
|
||||
Invalidate all cache entries containing a specific memory.
|
||||
|
||||
Args:
|
||||
memory_id: Memory ID to invalidate
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
keys_to_remove = []
|
||||
for key, entry in self._cache.items():
|
||||
if any(r.memory_id == memory_id for r in entry.results):
|
||||
keys_to_remove.append(key)
|
||||
|
||||
for key in keys_to_remove:
|
||||
self.invalidate(key)
|
||||
|
||||
if keys_to_remove:
|
||||
logger.debug(
|
||||
f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}"
|
||||
)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def clear(self) -> int:
|
||||
"""
|
||||
Clear all cache entries.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
logger.info(f"Cleared {count} cache entries")
|
||||
return count
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
expired_count = sum(1 for e in self._cache.values() if e.is_expired())
|
||||
return {
|
||||
"total_entries": len(self._cache),
|
||||
"expired_entries": expired_count,
|
||||
"max_entries": self._max_entries,
|
||||
"default_ttl_seconds": self._default_ttl,
|
||||
}
|
||||
|
||||
|
||||
class RetrievalEngine:
|
||||
"""
|
||||
Hybrid retrieval engine for memory search.
|
||||
|
||||
Combines multiple index types for comprehensive retrieval:
|
||||
- Vector search for semantic similarity
|
||||
- Temporal index for time-based filtering
|
||||
- Entity index for entity-based lookups
|
||||
- Outcome index for success/failure filtering
|
||||
|
||||
Results are scored and ranked using relevance scoring.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
indexer: MemoryIndexer | None = None,
|
||||
scorer: RelevanceScorer | None = None,
|
||||
cache: RetrievalCache | None = None,
|
||||
enable_cache: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the retrieval engine.
|
||||
|
||||
Args:
|
||||
indexer: Memory indexer (defaults to singleton)
|
||||
scorer: Relevance scorer (defaults to new instance)
|
||||
cache: Retrieval cache (defaults to new instance)
|
||||
enable_cache: Whether to enable result caching
|
||||
"""
|
||||
self._indexer = indexer or get_memory_indexer()
|
||||
self._scorer = scorer or RelevanceScorer()
|
||||
self._cache = cache or RetrievalCache() if enable_cache else None
|
||||
self._enable_cache = enable_cache
|
||||
logger.info(f"Initialized RetrievalEngine with cache={enable_cache}")
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: RetrievalQuery,
|
||||
use_cache: bool = True,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve relevant memories using hybrid search.
|
||||
|
||||
Args:
|
||||
query: Retrieval query parameters
|
||||
use_cache: Whether to use cached results
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
start_time = _utcnow()
|
||||
|
||||
# Check cache
|
||||
cache_key = query.to_cache_key()
|
||||
if use_cache and self._cache:
|
||||
cached = self._cache.get(cache_key)
|
||||
if cached:
|
||||
latency = (_utcnow() - start_time).total_seconds() * 1000
|
||||
return RetrievalResult(
|
||||
items=cached,
|
||||
total_count=len(cached),
|
||||
query=query.query_text or "",
|
||||
retrieval_type="cached",
|
||||
latency_ms=latency,
|
||||
metadata={"cache_hit": True},
|
||||
)
|
||||
|
||||
# Collect candidates from each index
|
||||
candidates: dict[UUID, dict[str, Any]] = {}
|
||||
|
||||
# Vector search
|
||||
if query.use_vector and query.query_embedding:
|
||||
vector_results = await self._indexer.vector_index.search(
|
||||
query=query.query_embedding,
|
||||
limit=query.limit * 3, # Get more for filtering
|
||||
min_similarity=query.min_relevance,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for entry in vector_results:
|
||||
if entry.memory_id not in candidates:
|
||||
candidates[entry.memory_id] = {
|
||||
"memory_type": entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
candidates[entry.memory_id]["vector_similarity"] = entry.metadata.get(
|
||||
"similarity", 0.5
|
||||
)
|
||||
candidates[entry.memory_id]["sources"].append("vector")
|
||||
|
||||
# Temporal search
|
||||
if query.use_temporal and (
|
||||
query.start_time or query.end_time or query.recent_seconds
|
||||
):
|
||||
temporal_results = await self._indexer.temporal_index.search(
|
||||
query=None,
|
||||
limit=query.limit * 3,
|
||||
start_time=query.start_time,
|
||||
end_time=query.end_time,
|
||||
recent_seconds=query.recent_seconds,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for temporal_entry in temporal_results:
|
||||
if temporal_entry.memory_id not in candidates:
|
||||
candidates[temporal_entry.memory_id] = {
|
||||
"memory_type": temporal_entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
candidates[temporal_entry.memory_id]["timestamp"] = (
|
||||
temporal_entry.timestamp
|
||||
)
|
||||
candidates[temporal_entry.memory_id]["sources"].append("temporal")
|
||||
|
||||
# Entity search
|
||||
if query.use_entity and query.entities:
|
||||
entity_results = await self._indexer.entity_index.search(
|
||||
query=None,
|
||||
limit=query.limit * 3,
|
||||
entities=query.entities,
|
||||
match_all=query.entity_match_all,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for entity_entry in entity_results:
|
||||
if entity_entry.memory_id not in candidates:
|
||||
candidates[entity_entry.memory_id] = {
|
||||
"memory_type": entity_entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
# Count entity matches
|
||||
entity_count = candidates[entity_entry.memory_id].get(
|
||||
"entity_match_count", 0
|
||||
)
|
||||
candidates[entity_entry.memory_id]["entity_match_count"] = (
|
||||
entity_count + 1
|
||||
)
|
||||
candidates[entity_entry.memory_id]["sources"].append("entity")
|
||||
|
||||
# Outcome search
|
||||
if query.use_outcome and query.outcomes:
|
||||
outcome_results = await self._indexer.outcome_index.search(
|
||||
query=None,
|
||||
limit=query.limit * 3,
|
||||
outcomes=query.outcomes,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for outcome_entry in outcome_results:
|
||||
if outcome_entry.memory_id not in candidates:
|
||||
candidates[outcome_entry.memory_id] = {
|
||||
"memory_type": outcome_entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
candidates[outcome_entry.memory_id]["outcome"] = outcome_entry.outcome
|
||||
candidates[outcome_entry.memory_id]["sources"].append("outcome")
|
||||
|
||||
# Score and rank candidates
|
||||
scored_results: list[ScoredResult] = []
|
||||
entity_total = len(query.entities) if query.entities else 1
|
||||
|
||||
for memory_id, data in candidates.items():
|
||||
scored = self._scorer.score(
|
||||
memory_id=memory_id,
|
||||
memory_type=data["memory_type"],
|
||||
vector_similarity=data.get("vector_similarity"),
|
||||
timestamp=data.get("timestamp"),
|
||||
entity_match_count=data.get("entity_match_count", 0),
|
||||
entity_total=entity_total,
|
||||
outcome=data.get("outcome"),
|
||||
preferred_outcomes=query.outcomes,
|
||||
)
|
||||
scored.metadata["sources"] = data.get("sources", [])
|
||||
|
||||
# Filter by minimum relevance
|
||||
if scored.relevance_score >= query.min_relevance:
|
||||
scored_results.append(scored)
|
||||
|
||||
# Sort by relevance score
|
||||
scored_results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
|
||||
# Apply limit
|
||||
final_results = scored_results[: query.limit]
|
||||
|
||||
# Cache results
|
||||
if use_cache and self._cache and final_results:
|
||||
self._cache.put(cache_key, final_results)
|
||||
|
||||
latency = (_utcnow() - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Retrieved {len(final_results)} results from {len(candidates)} candidates "
|
||||
f"in {latency:.2f}ms"
|
||||
)
|
||||
|
||||
return RetrievalResult(
|
||||
items=final_results,
|
||||
total_count=len(candidates),
|
||||
query=query.query_text or "",
|
||||
retrieval_type="hybrid",
|
||||
latency_ms=latency,
|
||||
metadata={
|
||||
"cache_hit": False,
|
||||
"candidates_count": len(candidates),
|
||||
"filtered_count": len(scored_results),
|
||||
},
|
||||
)
|
||||
|
||||
async def retrieve_similar(
|
||||
self,
|
||||
embedding: list[float],
|
||||
limit: int = 10,
|
||||
min_similarity: float = 0.5,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve memories similar to a given embedding.
|
||||
|
||||
Args:
|
||||
embedding: Query embedding
|
||||
limit: Maximum results
|
||||
min_similarity: Minimum similarity threshold
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
query_embedding=embedding,
|
||||
limit=limit,
|
||||
min_relevance=min_similarity,
|
||||
memory_types=memory_types,
|
||||
use_temporal=False,
|
||||
use_entity=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
async def retrieve_recent(
|
||||
self,
|
||||
hours: float = 24,
|
||||
limit: int = 10,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve recent memories.
|
||||
|
||||
Args:
|
||||
hours: Number of hours to look back
|
||||
limit: Maximum results
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
recent_seconds=hours * 3600,
|
||||
limit=limit,
|
||||
memory_types=memory_types,
|
||||
use_vector=False,
|
||||
use_entity=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
async def retrieve_by_entity(
|
||||
self,
|
||||
entity_type: str,
|
||||
entity_value: str,
|
||||
limit: int = 10,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve memories by entity.
|
||||
|
||||
Args:
|
||||
entity_type: Type of entity
|
||||
entity_value: Entity value
|
||||
limit: Maximum results
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
entities=[(entity_type, entity_value)],
|
||||
limit=limit,
|
||||
memory_types=memory_types,
|
||||
use_vector=False,
|
||||
use_temporal=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
async def retrieve_successful(
|
||||
self,
|
||||
limit: int = 10,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve successful memories.
|
||||
|
||||
Args:
|
||||
limit: Maximum results
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
outcomes=[Outcome.SUCCESS],
|
||||
limit=limit,
|
||||
memory_types=memory_types,
|
||||
use_vector=False,
|
||||
use_temporal=False,
|
||||
use_entity=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
def invalidate_cache(self) -> int:
|
||||
"""
|
||||
Invalidate all cached results.
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if self._cache:
|
||||
return self._cache.clear()
|
||||
return 0
|
||||
|
||||
def invalidate_cache_for_memory(self, memory_id: UUID) -> int:
|
||||
"""
|
||||
Invalidate cache entries containing a specific memory.
|
||||
|
||||
Args:
|
||||
memory_id: Memory ID to invalidate
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if self._cache:
|
||||
return self._cache.invalidate_by_memory(memory_id)
|
||||
return 0
|
||||
|
||||
def get_cache_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
if self._cache:
|
||||
return self._cache.get_stats()
|
||||
return {"enabled": False}
|
||||
|
||||
|
||||
# Singleton retrieval engine instance
|
||||
_engine: RetrievalEngine | None = None
|
||||
|
||||
|
||||
def get_retrieval_engine() -> RetrievalEngine:
|
||||
"""Get the singleton retrieval engine instance."""
|
||||
global _engine
|
||||
if _engine is None:
|
||||
_engine = RetrievalEngine()
|
||||
return _engine
|
||||
19
backend/app/services/memory/integration/__init__.py
Normal file
19
backend/app/services/memory/integration/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# app/services/memory/integration/__init__.py
|
||||
"""
|
||||
Memory Integration Module.
|
||||
|
||||
Provides integration between the agent memory system and other Syndarix components:
|
||||
- Context Engine: Memory as context source
|
||||
- Agent Lifecycle: Spawn, pause, resume, terminate hooks
|
||||
"""
|
||||
|
||||
from .context_source import MemoryContextSource, get_memory_context_source
|
||||
from .lifecycle import AgentLifecycleManager, LifecycleHooks, get_lifecycle_manager
|
||||
|
||||
__all__ = [
|
||||
"AgentLifecycleManager",
|
||||
"LifecycleHooks",
|
||||
"MemoryContextSource",
|
||||
"get_lifecycle_manager",
|
||||
"get_memory_context_source",
|
||||
]
|
||||
399
backend/app/services/memory/integration/context_source.py
Normal file
399
backend/app/services/memory/integration/context_source.py
Normal file
@@ -0,0 +1,399 @@
|
||||
# app/services/memory/integration/context_source.py
|
||||
"""
|
||||
Memory Context Source.
|
||||
|
||||
Provides agent memory as a context source for the Context Engine.
|
||||
Retrieves relevant memories based on query and converts them to MemoryContext objects.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.context.types.memory import MemoryContext
|
||||
from app.services.memory.episodic import EpisodicMemory
|
||||
from app.services.memory.procedural import ProceduralMemory
|
||||
from app.services.memory.semantic import SemanticMemory
|
||||
from app.services.memory.working import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryFetchConfig:
|
||||
"""Configuration for memory fetching."""
|
||||
|
||||
# Limits per memory type
|
||||
working_limit: int = 10
|
||||
episodic_limit: int = 10
|
||||
semantic_limit: int = 15
|
||||
procedural_limit: int = 5
|
||||
|
||||
# Time ranges
|
||||
episodic_days_back: int = 30
|
||||
min_relevance: float = 0.3
|
||||
|
||||
# Which memory types to include
|
||||
include_working: bool = True
|
||||
include_episodic: bool = True
|
||||
include_semantic: bool = True
|
||||
include_procedural: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryFetchResult:
|
||||
"""Result of memory fetch operation."""
|
||||
|
||||
contexts: list[MemoryContext]
|
||||
by_type: dict[str, int]
|
||||
fetch_time_ms: float
|
||||
query: str
|
||||
|
||||
|
||||
class MemoryContextSource:
|
||||
"""
|
||||
Source for memory context in the Context Engine.
|
||||
|
||||
This service retrieves relevant memories based on a query and
|
||||
converts them to MemoryContext objects for context assembly.
|
||||
It coordinates between all memory types (working, episodic,
|
||||
semantic, procedural) to provide a comprehensive memory context.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the memory context source.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic search
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
|
||||
# Lazy-initialized memory services
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
self._semantic: SemanticMemory | None = None
|
||||
self._procedural: ProceduralMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
async def _get_semantic(self) -> SemanticMemory:
|
||||
"""Get or create semantic memory service."""
|
||||
if self._semantic is None:
|
||||
self._semantic = await SemanticMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._semantic
|
||||
|
||||
async def _get_procedural(self) -> ProceduralMemory:
|
||||
"""Get or create procedural memory service."""
|
||||
if self._procedural is None:
|
||||
self._procedural = await ProceduralMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._procedural
|
||||
|
||||
async def fetch_context(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
session_id: str | None = None,
|
||||
config: MemoryFetchConfig | None = None,
|
||||
) -> MemoryFetchResult:
|
||||
"""
|
||||
Fetch relevant memories as context.
|
||||
|
||||
This is the main entry point for the Context Engine integration.
|
||||
It searches across all memory types and returns relevant memories
|
||||
as MemoryContext objects.
|
||||
|
||||
Args:
|
||||
query: Search query for finding relevant memories
|
||||
project_id: Project scope
|
||||
agent_instance_id: Optional agent instance scope
|
||||
agent_type_id: Optional agent type scope (for procedural)
|
||||
session_id: Optional session ID (for working memory)
|
||||
config: Optional fetch configuration
|
||||
|
||||
Returns:
|
||||
MemoryFetchResult with contexts and metadata
|
||||
"""
|
||||
config = config or MemoryFetchConfig()
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
contexts: list[MemoryContext] = []
|
||||
by_type: dict[str, int] = {
|
||||
"working": 0,
|
||||
"episodic": 0,
|
||||
"semantic": 0,
|
||||
"procedural": 0,
|
||||
}
|
||||
|
||||
# Fetch from working memory (session-scoped)
|
||||
if config.include_working and session_id:
|
||||
try:
|
||||
working_contexts = await self._fetch_working(
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
limit=config.working_limit,
|
||||
)
|
||||
contexts.extend(working_contexts)
|
||||
by_type["working"] = len(working_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch working memory: {e}")
|
||||
|
||||
# Fetch from episodic memory
|
||||
if config.include_episodic:
|
||||
try:
|
||||
episodic_contexts = await self._fetch_episodic(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
limit=config.episodic_limit,
|
||||
days_back=config.episodic_days_back,
|
||||
)
|
||||
contexts.extend(episodic_contexts)
|
||||
by_type["episodic"] = len(episodic_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch episodic memory: {e}")
|
||||
|
||||
# Fetch from semantic memory
|
||||
if config.include_semantic:
|
||||
try:
|
||||
semantic_contexts = await self._fetch_semantic(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
limit=config.semantic_limit,
|
||||
min_relevance=config.min_relevance,
|
||||
)
|
||||
contexts.extend(semantic_contexts)
|
||||
by_type["semantic"] = len(semantic_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch semantic memory: {e}")
|
||||
|
||||
# Fetch from procedural memory
|
||||
if config.include_procedural:
|
||||
try:
|
||||
procedural_contexts = await self._fetch_procedural(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=config.procedural_limit,
|
||||
)
|
||||
contexts.extend(procedural_contexts)
|
||||
by_type["procedural"] = len(procedural_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch procedural memory: {e}")
|
||||
|
||||
# Sort by relevance
|
||||
contexts.sort(key=lambda c: c.relevance_score, reverse=True)
|
||||
|
||||
fetch_time = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.debug(
|
||||
f"Fetched {len(contexts)} memory contexts for query '{query[:50]}...' "
|
||||
f"in {fetch_time:.1f}ms"
|
||||
)
|
||||
|
||||
return MemoryFetchResult(
|
||||
contexts=contexts,
|
||||
by_type=by_type,
|
||||
fetch_time_ms=fetch_time,
|
||||
query=query,
|
||||
)
|
||||
|
||||
async def _fetch_working(
|
||||
self,
|
||||
query: str,
|
||||
session_id: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None,
|
||||
limit: int,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from working memory."""
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
|
||||
)
|
||||
|
||||
contexts: list[MemoryContext] = []
|
||||
all_keys = await working.list_keys()
|
||||
|
||||
# Filter keys by query (simple substring match)
|
||||
query_lower = query.lower()
|
||||
matched_keys = [k for k in all_keys if query_lower in k.lower()]
|
||||
|
||||
# If no query match, include all keys (working memory is always relevant)
|
||||
if not matched_keys and query:
|
||||
matched_keys = all_keys
|
||||
|
||||
for key in matched_keys[:limit]:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
contexts.append(
|
||||
MemoryContext.from_working_memory(
|
||||
key=key,
|
||||
value=value,
|
||||
source=f"working:{session_id}",
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
return contexts
|
||||
|
||||
async def _fetch_episodic(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None,
|
||||
limit: int,
|
||||
days_back: int,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from episodic memory."""
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
# Search for similar episodes
|
||||
episodes = await episodic.search_similar(
|
||||
project_id=project_id,
|
||||
query=query,
|
||||
limit=limit,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
# Also get recent episodes if we didn't find enough
|
||||
if len(episodes) < limit // 2:
|
||||
since = datetime.now(UTC) - timedelta(days=days_back)
|
||||
recent = await episodic.get_recent(
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
since=since,
|
||||
)
|
||||
# Deduplicate by ID
|
||||
existing_ids = {e.id for e in episodes}
|
||||
for ep in recent:
|
||||
if ep.id not in existing_ids:
|
||||
episodes.append(ep)
|
||||
if len(episodes) >= limit:
|
||||
break
|
||||
|
||||
return [
|
||||
MemoryContext.from_episodic_memory(ep, query=query)
|
||||
for ep in episodes[:limit]
|
||||
]
|
||||
|
||||
async def _fetch_semantic(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
limit: int,
|
||||
min_relevance: float,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from semantic memory."""
|
||||
semantic = await self._get_semantic()
|
||||
|
||||
facts = await semantic.search_facts(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
min_confidence=min_relevance,
|
||||
)
|
||||
|
||||
return [MemoryContext.from_semantic_memory(fact, query=query) for fact in facts]
|
||||
|
||||
async def _fetch_procedural(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None,
|
||||
limit: int,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from procedural memory."""
|
||||
procedural = await self._get_procedural()
|
||||
|
||||
procedures = await procedural.find_matching(
|
||||
context=query,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
MemoryContext.from_procedural_memory(proc, query=query)
|
||||
for proc in procedures
|
||||
]
|
||||
|
||||
async def fetch_all_working(
|
||||
self,
|
||||
session_id: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[MemoryContext]:
|
||||
"""
|
||||
Fetch all working memory for a session.
|
||||
|
||||
Useful for including entire session state in context.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
project_id: Project scope
|
||||
agent_instance_id: Optional agent instance scope
|
||||
|
||||
Returns:
|
||||
List of MemoryContext for all working memory items
|
||||
"""
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
|
||||
)
|
||||
|
||||
contexts: list[MemoryContext] = []
|
||||
all_keys = await working.list_keys()
|
||||
|
||||
for key in all_keys:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
contexts.append(
|
||||
MemoryContext.from_working_memory(
|
||||
key=key,
|
||||
value=value,
|
||||
source=f"working:{session_id}",
|
||||
)
|
||||
)
|
||||
|
||||
return contexts
|
||||
|
||||
|
||||
# Factory function
|
||||
async def get_memory_context_source(
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> MemoryContextSource:
|
||||
"""Create a memory context source instance."""
|
||||
return MemoryContextSource(
|
||||
session=session,
|
||||
embedding_generator=embedding_generator,
|
||||
)
|
||||
635
backend/app/services/memory/integration/lifecycle.py
Normal file
635
backend/app/services/memory/integration/lifecycle.py
Normal file
@@ -0,0 +1,635 @@
|
||||
# app/services/memory/integration/lifecycle.py
|
||||
"""
|
||||
Agent Lifecycle Hooks for Memory System.
|
||||
|
||||
Provides memory management hooks for agent lifecycle events:
|
||||
- spawn: Initialize working memory for new agent instance
|
||||
- pause: Checkpoint working memory state
|
||||
- resume: Restore working memory from checkpoint
|
||||
- terminate: Consolidate session to episodic memory
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.memory.episodic import EpisodicMemory
|
||||
from app.services.memory.types import EpisodeCreate, Outcome
|
||||
from app.services.memory.working import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LifecycleEvent:
|
||||
"""Event data for lifecycle hooks."""
|
||||
|
||||
event_type: str # spawn, pause, resume, terminate
|
||||
project_id: UUID
|
||||
agent_instance_id: UUID
|
||||
agent_type_id: UUID | None = None
|
||||
session_id: str | None = None
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LifecycleResult:
|
||||
"""Result of a lifecycle operation."""
|
||||
|
||||
success: bool
|
||||
event_type: str
|
||||
message: str | None = None
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
duration_ms: float = 0.0
|
||||
|
||||
|
||||
# Type alias for lifecycle hooks
|
||||
LifecycleHook = Callable[[LifecycleEvent], Coroutine[Any, Any, None]]
|
||||
|
||||
|
||||
class LifecycleHooks:
|
||||
"""
|
||||
Collection of lifecycle hooks.
|
||||
|
||||
Allows registration of custom hooks for lifecycle events.
|
||||
Hooks are called after the core memory operations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize lifecycle hooks."""
|
||||
self._spawn_hooks: list[LifecycleHook] = []
|
||||
self._pause_hooks: list[LifecycleHook] = []
|
||||
self._resume_hooks: list[LifecycleHook] = []
|
||||
self._terminate_hooks: list[LifecycleHook] = []
|
||||
|
||||
def on_spawn(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a spawn hook."""
|
||||
self._spawn_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
def on_pause(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a pause hook."""
|
||||
self._pause_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
def on_resume(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a resume hook."""
|
||||
self._resume_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
def on_terminate(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a terminate hook."""
|
||||
self._terminate_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
async def run_spawn_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all spawn hooks."""
|
||||
for hook in self._spawn_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Spawn hook failed: {e}")
|
||||
|
||||
async def run_pause_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all pause hooks."""
|
||||
for hook in self._pause_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Pause hook failed: {e}")
|
||||
|
||||
async def run_resume_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all resume hooks."""
|
||||
for hook in self._resume_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Resume hook failed: {e}")
|
||||
|
||||
async def run_terminate_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all terminate hooks."""
|
||||
for hook in self._terminate_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Terminate hook failed: {e}")
|
||||
|
||||
|
||||
class AgentLifecycleManager:
|
||||
"""
|
||||
Manager for agent lifecycle and memory integration.
|
||||
|
||||
Handles memory operations during agent lifecycle events:
|
||||
- spawn: Creates new working memory for the session
|
||||
- pause: Saves working memory state to checkpoint
|
||||
- resume: Restores working memory from checkpoint
|
||||
- terminate: Consolidates working memory to episodic memory
|
||||
"""
|
||||
|
||||
# Key prefix for checkpoint storage
|
||||
CHECKPOINT_PREFIX = "__checkpoint__"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
hooks: LifecycleHooks | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the lifecycle manager.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
hooks: Optional lifecycle hooks
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._hooks = hooks or LifecycleHooks()
|
||||
|
||||
# Lazy-initialized services
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
@property
|
||||
def hooks(self) -> LifecycleHooks:
|
||||
"""Get the lifecycle hooks."""
|
||||
return self._hooks
|
||||
|
||||
async def spawn(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
agent_type_id: UUID | None = None,
|
||||
initial_state: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent spawn - initialize working memory.
|
||||
|
||||
Creates a new working memory instance for the agent session
|
||||
and optionally populates it with initial state.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID for working memory
|
||||
agent_type_id: Optional agent type ID
|
||||
initial_state: Optional initial state to populate
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with spawn outcome
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
# Create working memory for the session
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Populate initial state if provided
|
||||
items_set = 0
|
||||
if initial_state:
|
||||
for key, value in initial_state.items():
|
||||
await working.set(key, value)
|
||||
items_set += 1
|
||||
|
||||
# Create and run event hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="spawn",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
await self._hooks.run_spawn_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} spawned with session {session_id}, "
|
||||
f"initial state: {items_set} items"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="spawn",
|
||||
message="Agent spawned successfully",
|
||||
data={
|
||||
"session_id": session_id,
|
||||
"initial_items": items_set,
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Spawn failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="spawn",
|
||||
message=f"Spawn failed: {e}",
|
||||
)
|
||||
|
||||
async def pause(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
checkpoint_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent pause - checkpoint working memory.
|
||||
|
||||
Saves the current working memory state to a checkpoint
|
||||
that can be restored later with resume().
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
checkpoint_id: Optional checkpoint identifier
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with checkpoint data
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
checkpoint_id = checkpoint_id or f"checkpoint_{int(start_time.timestamp())}"
|
||||
|
||||
try:
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Get all current state
|
||||
all_keys = await working.list_keys()
|
||||
# Filter out checkpoint keys
|
||||
state_keys = [
|
||||
k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)
|
||||
]
|
||||
|
||||
state: dict[str, Any] = {}
|
||||
for key in state_keys:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
state[key] = value
|
||||
|
||||
# Store checkpoint
|
||||
checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
await working.set(
|
||||
checkpoint_key,
|
||||
{
|
||||
"state": state,
|
||||
"timestamp": start_time.isoformat(),
|
||||
"keys_count": len(state),
|
||||
},
|
||||
ttl_seconds=86400 * 7, # Keep checkpoint for 7 days
|
||||
)
|
||||
|
||||
# Run hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="pause",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
|
||||
)
|
||||
await self._hooks.run_pause_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} paused, checkpoint {checkpoint_id} "
|
||||
f"saved with {len(state)} items"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="pause",
|
||||
message="Agent paused successfully",
|
||||
data={
|
||||
"checkpoint_id": checkpoint_id,
|
||||
"items_saved": len(state),
|
||||
"timestamp": start_time.isoformat(),
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pause failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="pause",
|
||||
message=f"Pause failed: {e}",
|
||||
)
|
||||
|
||||
async def resume(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
checkpoint_id: str,
|
||||
clear_current: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent resume - restore from checkpoint.
|
||||
|
||||
Restores working memory state from a previously saved checkpoint.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
checkpoint_id: Checkpoint to restore from
|
||||
clear_current: Whether to clear current state before restoring
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with restore outcome
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Get checkpoint
|
||||
checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
checkpoint = await working.get(checkpoint_key)
|
||||
|
||||
if checkpoint is None:
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="resume",
|
||||
message=f"Checkpoint '{checkpoint_id}' not found",
|
||||
)
|
||||
|
||||
# Clear current state if requested
|
||||
if clear_current:
|
||||
all_keys = await working.list_keys()
|
||||
for key in all_keys:
|
||||
if not key.startswith(self.CHECKPOINT_PREFIX):
|
||||
await working.delete(key)
|
||||
|
||||
# Restore state from checkpoint
|
||||
state = checkpoint.get("state", {})
|
||||
items_restored = 0
|
||||
for key, value in state.items():
|
||||
await working.set(key, value)
|
||||
items_restored += 1
|
||||
|
||||
# Run hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="resume",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
|
||||
)
|
||||
await self._hooks.run_resume_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} resumed from checkpoint {checkpoint_id}, "
|
||||
f"restored {items_restored} items"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="resume",
|
||||
message="Agent resumed successfully",
|
||||
data={
|
||||
"checkpoint_id": checkpoint_id,
|
||||
"items_restored": items_restored,
|
||||
"checkpoint_timestamp": checkpoint.get("timestamp"),
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Resume failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="resume",
|
||||
message=f"Resume failed: {e}",
|
||||
)
|
||||
|
||||
async def terminate(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
task_description: str | None = None,
|
||||
outcome: Outcome = Outcome.SUCCESS,
|
||||
lessons_learned: list[str] | None = None,
|
||||
consolidate_to_episodic: bool = True,
|
||||
cleanup_working: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent termination - consolidate to episodic memory.
|
||||
|
||||
Consolidates the session's working memory into an episodic memory
|
||||
entry, then optionally cleans up the working memory.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
task_description: Description of what was accomplished
|
||||
outcome: Task outcome (SUCCESS, FAILURE, PARTIAL)
|
||||
lessons_learned: Optional list of lessons learned
|
||||
consolidate_to_episodic: Whether to create episodic entry
|
||||
cleanup_working: Whether to clear working memory
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with termination outcome
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Gather session state for consolidation
|
||||
all_keys = await working.list_keys()
|
||||
state_keys = [
|
||||
k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)
|
||||
]
|
||||
|
||||
session_state: dict[str, Any] = {}
|
||||
for key in state_keys:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
session_state[key] = value
|
||||
|
||||
episode_id: str | None = None
|
||||
|
||||
# Consolidate to episodic memory
|
||||
if consolidate_to_episodic:
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
description = task_description or f"Session {session_id} completed"
|
||||
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
task_type="session_completion",
|
||||
task_description=description[:500],
|
||||
outcome=outcome,
|
||||
outcome_details=f"Session terminated with {len(session_state)} state items",
|
||||
actions=[
|
||||
{
|
||||
"type": "session_terminate",
|
||||
"state_keys": list(session_state.keys()),
|
||||
"outcome": outcome.value,
|
||||
}
|
||||
],
|
||||
context_summary=str(session_state)[:1000] if session_state else "",
|
||||
lessons_learned=lessons_learned or [],
|
||||
duration_seconds=0.0, # Unknown at this point
|
||||
tokens_used=0,
|
||||
importance_score=0.6, # Moderate importance for session ends
|
||||
)
|
||||
|
||||
episode = await episodic.record_episode(episode_data)
|
||||
episode_id = str(episode.id)
|
||||
|
||||
# Clean up working memory
|
||||
items_cleared = 0
|
||||
if cleanup_working:
|
||||
for key in all_keys:
|
||||
await working.delete(key)
|
||||
items_cleared += 1
|
||||
|
||||
# Run hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="terminate",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
metadata={**(metadata or {}), "episode_id": episode_id},
|
||||
)
|
||||
await self._hooks.run_terminate_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} terminated, session {session_id} "
|
||||
f"consolidated to episode {episode_id}"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="terminate",
|
||||
message="Agent terminated successfully",
|
||||
data={
|
||||
"episode_id": episode_id,
|
||||
"state_items_consolidated": len(session_state),
|
||||
"items_cleared": items_cleared,
|
||||
"outcome": outcome.value,
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Terminate failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="terminate",
|
||||
message=f"Terminate failed: {e}",
|
||||
)
|
||||
|
||||
async def list_checkpoints(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
List available checkpoints for a session.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
List of checkpoint metadata dicts
|
||||
"""
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
all_keys = await working.list_keys()
|
||||
checkpoints: list[dict[str, Any]] = []
|
||||
|
||||
for key in all_keys:
|
||||
if key.startswith(self.CHECKPOINT_PREFIX):
|
||||
checkpoint_id = key[len(self.CHECKPOINT_PREFIX) :]
|
||||
checkpoint = await working.get(key)
|
||||
if checkpoint:
|
||||
checkpoints.append(
|
||||
{
|
||||
"checkpoint_id": checkpoint_id,
|
||||
"timestamp": checkpoint.get("timestamp"),
|
||||
"keys_count": checkpoint.get("keys_count", 0),
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
checkpoints.sort(
|
||||
key=lambda c: c.get("timestamp", ""),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
return checkpoints
|
||||
|
||||
|
||||
# Factory function
|
||||
async def get_lifecycle_manager(
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
hooks: LifecycleHooks | None = None,
|
||||
) -> AgentLifecycleManager:
|
||||
"""Create a lifecycle manager instance."""
|
||||
return AgentLifecycleManager(
|
||||
session=session,
|
||||
embedding_generator=embedding_generator,
|
||||
hooks=hooks,
|
||||
)
|
||||
606
backend/app/services/memory/manager.py
Normal file
606
backend/app/services/memory/manager.py
Normal file
@@ -0,0 +1,606 @@
|
||||
"""
|
||||
Memory Manager
|
||||
|
||||
Facade for the Agent Memory System providing unified access
|
||||
to all memory types and operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from .config import MemorySettings, get_memory_settings
|
||||
from .types import (
|
||||
Episode,
|
||||
EpisodeCreate,
|
||||
Fact,
|
||||
FactCreate,
|
||||
MemoryStats,
|
||||
MemoryType,
|
||||
Outcome,
|
||||
Procedure,
|
||||
ProcedureCreate,
|
||||
RetrievalResult,
|
||||
ScopeContext,
|
||||
ScopeLevel,
|
||||
TaskState,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""
|
||||
Unified facade for the Agent Memory System.
|
||||
|
||||
Provides a single entry point for all memory operations across
|
||||
working, episodic, semantic, and procedural memory types.
|
||||
|
||||
Usage:
|
||||
manager = MemoryManager.create()
|
||||
|
||||
# Working memory
|
||||
await manager.set_working("key", {"data": "value"})
|
||||
value = await manager.get_working("key")
|
||||
|
||||
# Episodic memory
|
||||
episode = await manager.record_episode(episode_data)
|
||||
similar = await manager.search_episodes("query")
|
||||
|
||||
# Semantic memory
|
||||
fact = await manager.store_fact(fact_data)
|
||||
facts = await manager.search_facts("query")
|
||||
|
||||
# Procedural memory
|
||||
procedure = await manager.record_procedure(procedure_data)
|
||||
procedures = await manager.find_procedures("context")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: MemorySettings,
|
||||
scope: ScopeContext,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the MemoryManager.
|
||||
|
||||
Args:
|
||||
settings: Memory configuration settings
|
||||
scope: The scope context for this manager instance
|
||||
"""
|
||||
self._settings = settings
|
||||
self._scope = scope
|
||||
self._initialized = False
|
||||
|
||||
# These will be initialized when the respective sub-modules are implemented
|
||||
self._working_memory: Any | None = None
|
||||
self._episodic_memory: Any | None = None
|
||||
self._semantic_memory: Any | None = None
|
||||
self._procedural_memory: Any | None = None
|
||||
|
||||
logger.debug(
|
||||
"MemoryManager created for scope %s:%s",
|
||||
scope.scope_type.value,
|
||||
scope.scope_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
scope_type: ScopeLevel = ScopeLevel.SESSION,
|
||||
scope_id: str = "default",
|
||||
parent_scope: ScopeContext | None = None,
|
||||
settings: MemorySettings | None = None,
|
||||
) -> "MemoryManager":
|
||||
"""
|
||||
Create a new MemoryManager instance.
|
||||
|
||||
Args:
|
||||
scope_type: The scope level for this manager
|
||||
scope_id: The scope identifier
|
||||
parent_scope: Optional parent scope for inheritance
|
||||
settings: Optional custom settings (uses global if not provided)
|
||||
|
||||
Returns:
|
||||
A new MemoryManager instance
|
||||
"""
|
||||
if settings is None:
|
||||
settings = get_memory_settings()
|
||||
|
||||
scope = ScopeContext(
|
||||
scope_type=scope_type,
|
||||
scope_id=scope_id,
|
||||
parent=parent_scope,
|
||||
)
|
||||
|
||||
return cls(settings=settings, scope=scope)
|
||||
|
||||
@classmethod
|
||||
def for_session(
|
||||
cls,
|
||||
session_id: str,
|
||||
agent_instance_id: UUID | None = None,
|
||||
project_id: UUID | None = None,
|
||||
) -> "MemoryManager":
|
||||
"""
|
||||
Create a MemoryManager for a specific session.
|
||||
|
||||
Builds the appropriate scope hierarchy based on provided IDs.
|
||||
|
||||
Args:
|
||||
session_id: The session identifier
|
||||
agent_instance_id: Optional agent instance ID
|
||||
project_id: Optional project ID
|
||||
|
||||
Returns:
|
||||
A MemoryManager configured for the session scope
|
||||
"""
|
||||
settings = get_memory_settings()
|
||||
|
||||
# Build scope hierarchy
|
||||
parent: ScopeContext | None = None
|
||||
|
||||
if project_id:
|
||||
parent = ScopeContext(
|
||||
scope_type=ScopeLevel.PROJECT,
|
||||
scope_id=str(project_id),
|
||||
parent=ScopeContext(
|
||||
scope_type=ScopeLevel.GLOBAL,
|
||||
scope_id="global",
|
||||
),
|
||||
)
|
||||
|
||||
if agent_instance_id:
|
||||
parent = ScopeContext(
|
||||
scope_type=ScopeLevel.AGENT_INSTANCE,
|
||||
scope_id=str(agent_instance_id),
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
scope = ScopeContext(
|
||||
scope_type=ScopeLevel.SESSION,
|
||||
scope_id=session_id,
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
return cls(settings=settings, scope=scope)
|
||||
|
||||
@property
|
||||
def scope(self) -> ScopeContext:
|
||||
"""Get the current scope context."""
|
||||
return self._scope
|
||||
|
||||
@property
|
||||
def settings(self) -> MemorySettings:
|
||||
"""Get the memory settings."""
|
||||
return self._settings
|
||||
|
||||
# =========================================================================
|
||||
# Working Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def set_working(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Set a value in working memory.
|
||||
|
||||
Args:
|
||||
key: The key to store the value under
|
||||
value: The value to store (must be JSON serializable)
|
||||
ttl_seconds: Optional TTL (uses default if not provided)
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("set_working called for key=%s (not yet implemented)", key)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def get_working(
|
||||
self,
|
||||
key: str,
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Get a value from working memory.
|
||||
|
||||
Args:
|
||||
key: The key to retrieve
|
||||
default: Default value if key not found
|
||||
|
||||
Returns:
|
||||
The stored value or default
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("get_working called for key=%s (not yet implemented)", key)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def delete_working(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a value from working memory.
|
||||
|
||||
Args:
|
||||
key: The key to delete
|
||||
|
||||
Returns:
|
||||
True if the key was deleted, False if not found
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("delete_working called for key=%s (not yet implemented)", key)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def set_task_state(self, state: TaskState) -> None:
|
||||
"""
|
||||
Set the current task state in working memory.
|
||||
|
||||
Args:
|
||||
state: The task state to store
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug(
|
||||
"set_task_state called for task=%s (not yet implemented)",
|
||||
state.task_id,
|
||||
)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def get_task_state(self) -> TaskState | None:
|
||||
"""
|
||||
Get the current task state from working memory.
|
||||
|
||||
Returns:
|
||||
The current task state or None
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("get_task_state called (not yet implemented)")
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def create_checkpoint(self) -> str:
|
||||
"""
|
||||
Create a checkpoint of the current working memory state.
|
||||
|
||||
Returns:
|
||||
The checkpoint ID
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug("create_checkpoint called (not yet implemented)")
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
async def restore_checkpoint(self, checkpoint_id: str) -> None:
|
||||
"""
|
||||
Restore working memory from a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: The checkpoint to restore from
|
||||
"""
|
||||
# Placeholder - will be implemented in #89
|
||||
logger.debug(
|
||||
"restore_checkpoint called for id=%s (not yet implemented)",
|
||||
checkpoint_id,
|
||||
)
|
||||
raise NotImplementedError("Working memory not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Episodic Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def record_episode(self, episode: EpisodeCreate) -> Episode:
|
||||
"""
|
||||
Record a new episode in episodic memory.
|
||||
|
||||
Args:
|
||||
episode: The episode data to record
|
||||
|
||||
Returns:
|
||||
The created episode with ID
|
||||
"""
|
||||
# Placeholder - will be implemented in #90
|
||||
logger.debug(
|
||||
"record_episode called for task=%s (not yet implemented)",
|
||||
episode.task_type,
|
||||
)
|
||||
raise NotImplementedError("Episodic memory not yet implemented")
|
||||
|
||||
async def search_episodes(
|
||||
self,
|
||||
query: str,
|
||||
limit: int | None = None,
|
||||
) -> RetrievalResult[Episode]:
|
||||
"""
|
||||
Search for similar episodes.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
limit: Maximum results to return
|
||||
|
||||
Returns:
|
||||
Retrieval result with matching episodes
|
||||
"""
|
||||
# Placeholder - will be implemented in #90
|
||||
logger.debug(
|
||||
"search_episodes called for query=%s (not yet implemented)",
|
||||
query[:50],
|
||||
)
|
||||
raise NotImplementedError("Episodic memory not yet implemented")
|
||||
|
||||
async def get_recent_episodes(
|
||||
self,
|
||||
limit: int = 10,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get the most recent episodes.
|
||||
|
||||
Args:
|
||||
limit: Maximum episodes to return
|
||||
|
||||
Returns:
|
||||
List of recent episodes
|
||||
"""
|
||||
# Placeholder - will be implemented in #90
|
||||
logger.debug("get_recent_episodes called (not yet implemented)")
|
||||
raise NotImplementedError("Episodic memory not yet implemented")
|
||||
|
||||
async def get_episodes_by_outcome(
|
||||
self,
|
||||
outcome: Outcome,
|
||||
limit: int = 10,
|
||||
) -> list[Episode]:
|
||||
"""
|
||||
Get episodes by outcome.
|
||||
|
||||
Args:
|
||||
outcome: The outcome to filter by
|
||||
limit: Maximum episodes to return
|
||||
|
||||
Returns:
|
||||
List of episodes with the specified outcome
|
||||
"""
|
||||
# Placeholder - will be implemented in #90
|
||||
logger.debug(
|
||||
"get_episodes_by_outcome called for outcome=%s (not yet implemented)",
|
||||
outcome.value,
|
||||
)
|
||||
raise NotImplementedError("Episodic memory not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Semantic Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def store_fact(self, fact: FactCreate) -> Fact:
|
||||
"""
|
||||
Store a new fact in semantic memory.
|
||||
|
||||
Args:
|
||||
fact: The fact data to store
|
||||
|
||||
Returns:
|
||||
The created fact with ID
|
||||
"""
|
||||
# Placeholder - will be implemented in #91
|
||||
logger.debug(
|
||||
"store_fact called for %s %s %s (not yet implemented)",
|
||||
fact.subject,
|
||||
fact.predicate,
|
||||
fact.object,
|
||||
)
|
||||
raise NotImplementedError("Semantic memory not yet implemented")
|
||||
|
||||
async def search_facts(
|
||||
self,
|
||||
query: str,
|
||||
limit: int | None = None,
|
||||
) -> RetrievalResult[Fact]:
|
||||
"""
|
||||
Search for facts matching a query.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
limit: Maximum results to return
|
||||
|
||||
Returns:
|
||||
Retrieval result with matching facts
|
||||
"""
|
||||
# Placeholder - will be implemented in #91
|
||||
logger.debug(
|
||||
"search_facts called for query=%s (not yet implemented)",
|
||||
query[:50],
|
||||
)
|
||||
raise NotImplementedError("Semantic memory not yet implemented")
|
||||
|
||||
async def get_facts_by_entity(
|
||||
self,
|
||||
entity: str,
|
||||
limit: int = 20,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Get facts related to an entity.
|
||||
|
||||
Args:
|
||||
entity: The entity to search for
|
||||
limit: Maximum facts to return
|
||||
|
||||
Returns:
|
||||
List of facts mentioning the entity
|
||||
"""
|
||||
# Placeholder - will be implemented in #91
|
||||
logger.debug(
|
||||
"get_facts_by_entity called for entity=%s (not yet implemented)",
|
||||
entity,
|
||||
)
|
||||
raise NotImplementedError("Semantic memory not yet implemented")
|
||||
|
||||
async def reinforce_fact(self, fact_id: UUID) -> Fact:
|
||||
"""
|
||||
Reinforce a fact (increase confidence from repeated learning).
|
||||
|
||||
Args:
|
||||
fact_id: The fact to reinforce
|
||||
|
||||
Returns:
|
||||
The updated fact
|
||||
"""
|
||||
# Placeholder - will be implemented in #91
|
||||
logger.debug(
|
||||
"reinforce_fact called for id=%s (not yet implemented)",
|
||||
fact_id,
|
||||
)
|
||||
raise NotImplementedError("Semantic memory not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Procedural Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def record_procedure(self, procedure: ProcedureCreate) -> Procedure:
|
||||
"""
|
||||
Record a new procedure.
|
||||
|
||||
Args:
|
||||
procedure: The procedure data to record
|
||||
|
||||
Returns:
|
||||
The created procedure with ID
|
||||
"""
|
||||
# Placeholder - will be implemented in #92
|
||||
logger.debug(
|
||||
"record_procedure called for name=%s (not yet implemented)",
|
||||
procedure.name,
|
||||
)
|
||||
raise NotImplementedError("Procedural memory not yet implemented")
|
||||
|
||||
async def find_procedures(
|
||||
self,
|
||||
context: str,
|
||||
limit: int = 5,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Find procedures matching the current context.
|
||||
|
||||
Args:
|
||||
context: The context to match against
|
||||
limit: Maximum procedures to return
|
||||
|
||||
Returns:
|
||||
List of matching procedures sorted by success rate
|
||||
"""
|
||||
# Placeholder - will be implemented in #92
|
||||
logger.debug(
|
||||
"find_procedures called for context=%s (not yet implemented)",
|
||||
context[:50],
|
||||
)
|
||||
raise NotImplementedError("Procedural memory not yet implemented")
|
||||
|
||||
async def record_procedure_outcome(
|
||||
self,
|
||||
procedure_id: UUID,
|
||||
success: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Record the outcome of using a procedure.
|
||||
|
||||
Args:
|
||||
procedure_id: The procedure that was used
|
||||
success: Whether the procedure succeeded
|
||||
"""
|
||||
# Placeholder - will be implemented in #92
|
||||
logger.debug(
|
||||
"record_procedure_outcome called for id=%s success=%s (not yet implemented)",
|
||||
procedure_id,
|
||||
success,
|
||||
)
|
||||
raise NotImplementedError("Procedural memory not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Cross-Memory Operations
|
||||
# =========================================================================
|
||||
|
||||
async def recall(
|
||||
self,
|
||||
query: str,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
limit: int = 10,
|
||||
) -> dict[MemoryType, list[Any]]:
|
||||
"""
|
||||
Recall memories across multiple memory types.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
memory_types: Memory types to search (all if not specified)
|
||||
limit: Maximum results per type
|
||||
|
||||
Returns:
|
||||
Dictionary mapping memory types to results
|
||||
"""
|
||||
# Placeholder - will be implemented in #97 (Component Integration)
|
||||
logger.debug("recall called for query=%s (not yet implemented)", query[:50])
|
||||
raise NotImplementedError("Cross-memory recall not yet implemented")
|
||||
|
||||
async def get_stats(
|
||||
self,
|
||||
memory_type: MemoryType | None = None,
|
||||
) -> list[MemoryStats]:
|
||||
"""
|
||||
Get memory statistics.
|
||||
|
||||
Args:
|
||||
memory_type: Specific type or all if not specified
|
||||
|
||||
Returns:
|
||||
List of statistics for requested memory types
|
||||
"""
|
||||
# Placeholder - will be implemented in #100 (Metrics & Observability)
|
||||
logger.debug("get_stats called (not yet implemented)")
|
||||
raise NotImplementedError("Memory stats not yet implemented")
|
||||
|
||||
# =========================================================================
|
||||
# Lifecycle Operations
|
||||
# =========================================================================
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Initialize the memory manager and its backends.
|
||||
|
||||
Should be called before using the manager.
|
||||
"""
|
||||
if self._initialized:
|
||||
logger.debug("MemoryManager already initialized")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Initializing MemoryManager for scope %s:%s",
|
||||
self._scope.scope_type.value,
|
||||
self._scope.scope_id,
|
||||
)
|
||||
|
||||
# TODO: Initialize backends when implemented
|
||||
|
||||
self._initialized = True
|
||||
logger.info("MemoryManager initialized successfully")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the memory manager and release resources.
|
||||
|
||||
Should be called when done using the manager.
|
||||
"""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Closing MemoryManager for scope %s:%s",
|
||||
self._scope.scope_type.value,
|
||||
self._scope.scope_id,
|
||||
)
|
||||
|
||||
# TODO: Close backends when implemented
|
||||
|
||||
self._initialized = False
|
||||
logger.info("MemoryManager closed successfully")
|
||||
|
||||
async def __aenter__(self) -> "MemoryManager":
|
||||
"""Async context manager entry."""
|
||||
await self.initialize()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Async context manager exit."""
|
||||
await self.close()
|
||||
40
backend/app/services/memory/mcp/__init__.py
Normal file
40
backend/app/services/memory/mcp/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# app/services/memory/mcp/__init__.py
|
||||
"""
|
||||
MCP Tools for Agent Memory System.
|
||||
|
||||
Exposes memory operations as MCP-compatible tools that agents can invoke:
|
||||
- remember: Store data in memory
|
||||
- recall: Retrieve from memory
|
||||
- forget: Remove from memory
|
||||
- reflect: Analyze patterns
|
||||
- get_memory_stats: Usage statistics
|
||||
- search_procedures: Find relevant procedures
|
||||
- record_outcome: Record task success/failure
|
||||
"""
|
||||
|
||||
from .service import MemoryToolService, get_memory_tool_service
|
||||
from .tools import (
|
||||
MEMORY_TOOL_DEFINITIONS,
|
||||
ForgetArgs,
|
||||
GetMemoryStatsArgs,
|
||||
MemoryToolDefinition,
|
||||
RecallArgs,
|
||||
RecordOutcomeArgs,
|
||||
ReflectArgs,
|
||||
RememberArgs,
|
||||
SearchProceduresArgs,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MEMORY_TOOL_DEFINITIONS",
|
||||
"ForgetArgs",
|
||||
"GetMemoryStatsArgs",
|
||||
"MemoryToolDefinition",
|
||||
"MemoryToolService",
|
||||
"RecallArgs",
|
||||
"RecordOutcomeArgs",
|
||||
"ReflectArgs",
|
||||
"RememberArgs",
|
||||
"SearchProceduresArgs",
|
||||
"get_memory_tool_service",
|
||||
]
|
||||
1086
backend/app/services/memory/mcp/service.py
Normal file
1086
backend/app/services/memory/mcp/service.py
Normal file
File diff suppressed because it is too large
Load Diff
485
backend/app/services/memory/mcp/tools.py
Normal file
485
backend/app/services/memory/mcp/tools.py
Normal file
@@ -0,0 +1,485 @@
|
||||
# app/services/memory/mcp/tools.py
|
||||
"""
|
||||
MCP Tool Definitions for Agent Memory System.
|
||||
|
||||
Defines the schema and metadata for memory-related MCP tools.
|
||||
These tools are invoked by AI agents to interact with the memory system.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# OutcomeType alias - uses core Outcome enum from types module for consistency
|
||||
from app.services.memory.types import Outcome as OutcomeType
|
||||
|
||||
|
||||
class MemoryType(str, Enum):
|
||||
"""Types of memory for storage operations."""
|
||||
|
||||
WORKING = "working"
|
||||
EPISODIC = "episodic"
|
||||
SEMANTIC = "semantic"
|
||||
PROCEDURAL = "procedural"
|
||||
|
||||
|
||||
class AnalysisType(str, Enum):
|
||||
"""Types of pattern analysis for the reflect tool."""
|
||||
|
||||
RECENT_PATTERNS = "recent_patterns"
|
||||
SUCCESS_FACTORS = "success_factors"
|
||||
FAILURE_PATTERNS = "failure_patterns"
|
||||
COMMON_PROCEDURES = "common_procedures"
|
||||
LEARNING_PROGRESS = "learning_progress"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Argument Schemas (Pydantic models for validation)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class RememberArgs(BaseModel):
|
||||
"""Arguments for the 'remember' tool."""
|
||||
|
||||
memory_type: MemoryType = Field(
|
||||
...,
|
||||
description="Type of memory to store in: working, episodic, semantic, or procedural",
|
||||
)
|
||||
content: str = Field(
|
||||
...,
|
||||
description="The content to remember. Can be text, facts, or procedure steps.",
|
||||
min_length=1,
|
||||
max_length=10000,
|
||||
)
|
||||
key: str | None = Field(
|
||||
None,
|
||||
description="Optional key for working memory entries. Required for working memory type.",
|
||||
max_length=256,
|
||||
)
|
||||
importance: float = Field(
|
||||
0.5,
|
||||
description="Importance score from 0.0 (low) to 1.0 (critical)",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
ttl_seconds: int | None = Field(
|
||||
None,
|
||||
description="Time-to-live in seconds for working memory. None for permanent storage.",
|
||||
ge=1,
|
||||
le=86400 * 30, # Max 30 days
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional metadata to store with the memory",
|
||||
)
|
||||
# For semantic memory (facts)
|
||||
subject: str | None = Field(
|
||||
None,
|
||||
description="Subject of the fact (for semantic memory)",
|
||||
max_length=256,
|
||||
)
|
||||
predicate: str | None = Field(
|
||||
None,
|
||||
description="Predicate/relationship (for semantic memory)",
|
||||
max_length=256,
|
||||
)
|
||||
object_value: str | None = Field(
|
||||
None,
|
||||
description="Object of the fact (for semantic memory)",
|
||||
max_length=1000,
|
||||
)
|
||||
# For procedural memory
|
||||
trigger: str | None = Field(
|
||||
None,
|
||||
description="Trigger condition for the procedure (for procedural memory)",
|
||||
max_length=500,
|
||||
)
|
||||
steps: list[dict[str, Any]] | None = Field(
|
||||
None,
|
||||
description="Procedure steps as a list of action dictionaries",
|
||||
)
|
||||
|
||||
|
||||
class RecallArgs(BaseModel):
|
||||
"""Arguments for the 'recall' tool."""
|
||||
|
||||
query: str = Field(
|
||||
...,
|
||||
description="Search query to find relevant memories",
|
||||
min_length=1,
|
||||
max_length=1000,
|
||||
)
|
||||
memory_types: list[MemoryType] = Field(
|
||||
default_factory=lambda: [MemoryType.EPISODIC, MemoryType.SEMANTIC],
|
||||
description="Types of memory to search in",
|
||||
)
|
||||
limit: int = Field(
|
||||
10,
|
||||
description="Maximum number of results to return",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
min_relevance: float = Field(
|
||||
0.0,
|
||||
description="Minimum relevance score (0.0-1.0) for results",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
filters: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional filters (e.g., outcome, task_type, date range)",
|
||||
)
|
||||
include_context: bool = Field(
|
||||
True,
|
||||
description="Whether to include surrounding context in results",
|
||||
)
|
||||
|
||||
|
||||
class ForgetArgs(BaseModel):
|
||||
"""Arguments for the 'forget' tool."""
|
||||
|
||||
memory_type: MemoryType = Field(
|
||||
...,
|
||||
description="Type of memory to remove from",
|
||||
)
|
||||
key: str | None = Field(
|
||||
None,
|
||||
description="Key to remove (for working memory)",
|
||||
max_length=256,
|
||||
)
|
||||
memory_id: str | None = Field(
|
||||
None,
|
||||
description="Specific memory ID to remove (for episodic/semantic/procedural)",
|
||||
)
|
||||
pattern: str | None = Field(
|
||||
None,
|
||||
description="Pattern to match for bulk removal (use with caution)",
|
||||
max_length=500,
|
||||
)
|
||||
confirm_bulk: bool = Field(
|
||||
False,
|
||||
description="Must be True to confirm bulk deletion when using pattern",
|
||||
)
|
||||
|
||||
|
||||
class ReflectArgs(BaseModel):
|
||||
"""Arguments for the 'reflect' tool."""
|
||||
|
||||
analysis_type: AnalysisType = Field(
|
||||
...,
|
||||
description="Type of pattern analysis to perform",
|
||||
)
|
||||
scope: str | None = Field(
|
||||
None,
|
||||
description="Optional scope to limit analysis (e.g., task_type, time range)",
|
||||
max_length=500,
|
||||
)
|
||||
depth: int = Field(
|
||||
3,
|
||||
description="Depth of analysis (1=surface, 5=deep)",
|
||||
ge=1,
|
||||
le=5,
|
||||
)
|
||||
include_examples: bool = Field(
|
||||
True,
|
||||
description="Whether to include example memories in the analysis",
|
||||
)
|
||||
max_items: int = Field(
|
||||
10,
|
||||
description="Maximum number of patterns/examples to analyze",
|
||||
ge=1,
|
||||
le=50,
|
||||
)
|
||||
|
||||
|
||||
class GetMemoryStatsArgs(BaseModel):
|
||||
"""Arguments for the 'get_memory_stats' tool."""
|
||||
|
||||
include_breakdown: bool = Field(
|
||||
True,
|
||||
description="Include breakdown by memory type",
|
||||
)
|
||||
include_recent_activity: bool = Field(
|
||||
True,
|
||||
description="Include recent memory activity summary",
|
||||
)
|
||||
time_range_days: int = Field(
|
||||
7,
|
||||
description="Time range for activity analysis in days",
|
||||
ge=1,
|
||||
le=90,
|
||||
)
|
||||
|
||||
|
||||
class SearchProceduresArgs(BaseModel):
|
||||
"""Arguments for the 'search_procedures' tool."""
|
||||
|
||||
trigger: str = Field(
|
||||
...,
|
||||
description="Trigger or situation to find procedures for",
|
||||
min_length=1,
|
||||
max_length=500,
|
||||
)
|
||||
task_type: str | None = Field(
|
||||
None,
|
||||
description="Optional task type to filter procedures",
|
||||
max_length=100,
|
||||
)
|
||||
min_success_rate: float = Field(
|
||||
0.5,
|
||||
description="Minimum success rate (0.0-1.0) for returned procedures",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
limit: int = Field(
|
||||
5,
|
||||
description="Maximum number of procedures to return",
|
||||
ge=1,
|
||||
le=20,
|
||||
)
|
||||
include_steps: bool = Field(
|
||||
True,
|
||||
description="Whether to include detailed steps in the response",
|
||||
)
|
||||
|
||||
|
||||
class RecordOutcomeArgs(BaseModel):
|
||||
"""Arguments for the 'record_outcome' tool."""
|
||||
|
||||
task_type: str = Field(
|
||||
...,
|
||||
description="Type of task that was executed",
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
)
|
||||
outcome: OutcomeType = Field(
|
||||
...,
|
||||
description="Outcome of the task execution",
|
||||
)
|
||||
procedure_id: str | None = Field(
|
||||
None,
|
||||
description="ID of the procedure that was followed (if any)",
|
||||
)
|
||||
context: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Context in which the task was executed",
|
||||
)
|
||||
lessons_learned: str | None = Field(
|
||||
None,
|
||||
description="What was learned from this execution",
|
||||
max_length=2000,
|
||||
)
|
||||
duration_seconds: float | None = Field(
|
||||
None,
|
||||
description="How long the task took to execute",
|
||||
ge=0.0,
|
||||
)
|
||||
error_details: str | None = Field(
|
||||
None,
|
||||
description="Details about any errors encountered (for failures)",
|
||||
max_length=2000,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Definition Structure
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryToolDefinition:
|
||||
"""Definition of an MCP tool for the memory system."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
args_schema: type[BaseModel]
|
||||
input_schema: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Generate input schema from Pydantic model."""
|
||||
if not self.input_schema:
|
||||
self.input_schema = self.args_schema.model_json_schema()
|
||||
|
||||
def to_mcp_format(self) -> dict[str, Any]:
|
||||
"""Convert to MCP tool format."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"inputSchema": self.input_schema,
|
||||
}
|
||||
|
||||
def validate_args(self, args: dict[str, Any]) -> BaseModel:
|
||||
"""Validate and parse arguments."""
|
||||
return self.args_schema.model_validate(args)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Definitions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
REMEMBER_TOOL = MemoryToolDefinition(
|
||||
name="remember",
|
||||
description="""Store information in the agent's memory system.
|
||||
|
||||
Use this tool to:
|
||||
- Store temporary data in working memory (key-value with optional TTL)
|
||||
- Record important events in episodic memory (automatically done on session end)
|
||||
- Store facts/knowledge in semantic memory (subject-predicate-object triples)
|
||||
- Save procedures in procedural memory (trigger conditions and steps)
|
||||
|
||||
Examples:
|
||||
- Working memory: {"memory_type": "working", "key": "current_task", "content": "Implementing auth", "ttl_seconds": 3600}
|
||||
- Semantic fact: {"memory_type": "semantic", "subject": "User", "predicate": "prefers", "object_value": "dark mode", "content": "User preference noted"}
|
||||
- Procedure: {"memory_type": "procedural", "trigger": "When creating a new file", "steps": [{"action": "check_exists"}, {"action": "create"}], "content": "File creation procedure"}
|
||||
""",
|
||||
args_schema=RememberArgs,
|
||||
)
|
||||
|
||||
|
||||
RECALL_TOOL = MemoryToolDefinition(
|
||||
name="recall",
|
||||
description="""Retrieve information from the agent's memory system.
|
||||
|
||||
Use this tool to:
|
||||
- Search for relevant past experiences (episodic)
|
||||
- Look up known facts and knowledge (semantic)
|
||||
- Find applicable procedures for current task (procedural)
|
||||
- Get current session state (working)
|
||||
|
||||
The query supports semantic search - describe what you're looking for in natural language.
|
||||
|
||||
Examples:
|
||||
- {"query": "How did I handle authentication errors before?", "memory_types": ["episodic"]}
|
||||
- {"query": "What are the user's preferences?", "memory_types": ["semantic"], "limit": 5}
|
||||
- {"query": "database connection", "memory_types": ["episodic", "semantic", "procedural"], "filters": {"outcome": "success"}}
|
||||
""",
|
||||
args_schema=RecallArgs,
|
||||
)
|
||||
|
||||
|
||||
FORGET_TOOL = MemoryToolDefinition(
|
||||
name="forget",
|
||||
description="""Remove information from the agent's memory system.
|
||||
|
||||
Use this tool to:
|
||||
- Clear temporary working memory entries
|
||||
- Remove specific memories by ID
|
||||
- Bulk remove memories matching a pattern (requires confirmation)
|
||||
|
||||
WARNING: Deletion is permanent. Use with caution.
|
||||
|
||||
Examples:
|
||||
- Working memory: {"memory_type": "working", "key": "temp_calculation"}
|
||||
- Specific memory: {"memory_type": "episodic", "memory_id": "ep-123"}
|
||||
- Bulk (requires confirm): {"memory_type": "working", "pattern": "cache_*", "confirm_bulk": true}
|
||||
""",
|
||||
args_schema=ForgetArgs,
|
||||
)
|
||||
|
||||
|
||||
REFLECT_TOOL = MemoryToolDefinition(
|
||||
name="reflect",
|
||||
description="""Analyze patterns in the agent's memory to gain insights.
|
||||
|
||||
Use this tool to:
|
||||
- Identify patterns in recent work
|
||||
- Understand what leads to success/failure
|
||||
- Learn from past experiences
|
||||
- Track learning progress over time
|
||||
|
||||
Analysis types:
|
||||
- recent_patterns: What patterns appear in recent work
|
||||
- success_factors: What conditions lead to success
|
||||
- failure_patterns: What causes failures and how to avoid them
|
||||
- common_procedures: Most frequently used procedures
|
||||
- learning_progress: How knowledge has grown over time
|
||||
|
||||
Examples:
|
||||
- {"analysis_type": "success_factors", "scope": "code_review", "depth": 3}
|
||||
- {"analysis_type": "failure_patterns", "include_examples": true, "max_items": 5}
|
||||
""",
|
||||
args_schema=ReflectArgs,
|
||||
)
|
||||
|
||||
|
||||
GET_MEMORY_STATS_TOOL = MemoryToolDefinition(
|
||||
name="get_memory_stats",
|
||||
description="""Get statistics about the agent's memory usage.
|
||||
|
||||
Returns information about:
|
||||
- Total memories stored by type
|
||||
- Storage utilization
|
||||
- Recent activity summary
|
||||
- Memory health indicators
|
||||
|
||||
Use this to understand memory capacity and usage patterns.
|
||||
|
||||
Examples:
|
||||
- {"include_breakdown": true, "include_recent_activity": true}
|
||||
- {"time_range_days": 30, "include_breakdown": true}
|
||||
""",
|
||||
args_schema=GetMemoryStatsArgs,
|
||||
)
|
||||
|
||||
|
||||
SEARCH_PROCEDURES_TOOL = MemoryToolDefinition(
|
||||
name="search_procedures",
|
||||
description="""Find relevant procedures for a given situation.
|
||||
|
||||
Use this tool when you need to:
|
||||
- Find the best way to handle a situation
|
||||
- Look up proven approaches to problems
|
||||
- Get step-by-step guidance for tasks
|
||||
|
||||
Returns procedures ranked by relevance and success rate.
|
||||
|
||||
Examples:
|
||||
- {"trigger": "Deploying to production", "min_success_rate": 0.8}
|
||||
- {"trigger": "Handling merge conflicts", "task_type": "git_operations", "limit": 3}
|
||||
""",
|
||||
args_schema=SearchProceduresArgs,
|
||||
)
|
||||
|
||||
|
||||
RECORD_OUTCOME_TOOL = MemoryToolDefinition(
|
||||
name="record_outcome",
|
||||
description="""Record the outcome of a task execution.
|
||||
|
||||
Use this tool after completing a task to:
|
||||
- Update procedure success/failure rates
|
||||
- Store lessons learned for future reference
|
||||
- Improve procedure recommendations
|
||||
|
||||
This helps the memory system learn from experience.
|
||||
|
||||
Examples:
|
||||
- {"task_type": "code_review", "outcome": "success", "lessons_learned": "Breaking changes caught early"}
|
||||
- {"task_type": "deployment", "outcome": "failure", "error_details": "Database migration timeout", "lessons_learned": "Need to test migrations locally first"}
|
||||
""",
|
||||
args_schema=RecordOutcomeArgs,
|
||||
)
|
||||
|
||||
|
||||
# All tool definitions in a dictionary for easy lookup
|
||||
MEMORY_TOOL_DEFINITIONS: dict[str, MemoryToolDefinition] = {
|
||||
"remember": REMEMBER_TOOL,
|
||||
"recall": RECALL_TOOL,
|
||||
"forget": FORGET_TOOL,
|
||||
"reflect": REFLECT_TOOL,
|
||||
"get_memory_stats": GET_MEMORY_STATS_TOOL,
|
||||
"search_procedures": SEARCH_PROCEDURES_TOOL,
|
||||
"record_outcome": RECORD_OUTCOME_TOOL,
|
||||
}
|
||||
|
||||
|
||||
def get_all_tool_schemas() -> list[dict[str, Any]]:
|
||||
"""Get MCP-formatted schemas for all memory tools."""
|
||||
return [tool.to_mcp_format() for tool in MEMORY_TOOL_DEFINITIONS.values()]
|
||||
|
||||
|
||||
def get_tool_definition(name: str) -> MemoryToolDefinition | None:
|
||||
"""Get a specific tool definition by name."""
|
||||
return MEMORY_TOOL_DEFINITIONS.get(name)
|
||||
18
backend/app/services/memory/metrics/__init__.py
Normal file
18
backend/app/services/memory/metrics/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# app/services/memory/metrics/__init__.py
|
||||
"""Memory Metrics module."""
|
||||
|
||||
from .collector import (
|
||||
MemoryMetrics,
|
||||
get_memory_metrics,
|
||||
record_memory_operation,
|
||||
record_retrieval,
|
||||
reset_memory_metrics,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MemoryMetrics",
|
||||
"get_memory_metrics",
|
||||
"record_memory_operation",
|
||||
"record_retrieval",
|
||||
"reset_memory_metrics",
|
||||
]
|
||||
542
backend/app/services/memory/metrics/collector.py
Normal file
542
backend/app/services/memory/metrics/collector.py
Normal file
@@ -0,0 +1,542 @@
|
||||
# app/services/memory/metrics/collector.py
|
||||
"""
|
||||
Memory Metrics Collector
|
||||
|
||||
Collects and exposes metrics for the memory system.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import Counter, defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetricType(str, Enum):
|
||||
"""Types of metrics."""
|
||||
|
||||
COUNTER = "counter"
|
||||
GAUGE = "gauge"
|
||||
HISTOGRAM = "histogram"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetricValue:
|
||||
"""A single metric value."""
|
||||
|
||||
name: str
|
||||
metric_type: MetricType
|
||||
value: float
|
||||
labels: dict[str, str] = field(default_factory=dict)
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
|
||||
@dataclass
|
||||
class HistogramBucket:
|
||||
"""Histogram bucket for distribution metrics."""
|
||||
|
||||
le: float # Less than or equal
|
||||
count: int = 0
|
||||
|
||||
|
||||
class MemoryMetrics:
|
||||
"""
|
||||
Collects memory system metrics.
|
||||
|
||||
Metrics tracked:
|
||||
- Memory operations (get/set/delete by type and scope)
|
||||
- Retrieval operations and latencies
|
||||
- Memory item counts by type
|
||||
- Consolidation operations and durations
|
||||
- Cache hit/miss rates
|
||||
- Procedure success rates
|
||||
- Embedding operations
|
||||
"""
|
||||
|
||||
# Maximum samples to keep in histogram (circular buffer)
|
||||
MAX_HISTOGRAM_SAMPLES = 10000
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize MemoryMetrics."""
|
||||
self._counters: dict[str, Counter[str]] = defaultdict(Counter)
|
||||
self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
|
||||
# Use deque with maxlen for bounded memory (circular buffer)
|
||||
self._histograms: dict[str, deque[float]] = defaultdict(
|
||||
lambda: deque(maxlen=self.MAX_HISTOGRAM_SAMPLES)
|
||||
)
|
||||
self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Initialize histogram buckets
|
||||
self._init_histogram_buckets()
|
||||
|
||||
def _init_histogram_buckets(self) -> None:
|
||||
"""Initialize histogram buckets for latency metrics."""
|
||||
# Fast operations (working memory)
|
||||
fast_buckets = [0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, float("inf")]
|
||||
|
||||
# Normal operations (retrieval)
|
||||
normal_buckets = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, float("inf")]
|
||||
|
||||
# Slow operations (consolidation)
|
||||
slow_buckets = [0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, float("inf")]
|
||||
|
||||
self._histogram_buckets["memory_working_latency_seconds"] = [
|
||||
HistogramBucket(le=b) for b in fast_buckets
|
||||
]
|
||||
self._histogram_buckets["memory_retrieval_latency_seconds"] = [
|
||||
HistogramBucket(le=b) for b in normal_buckets
|
||||
]
|
||||
self._histogram_buckets["memory_consolidation_duration_seconds"] = [
|
||||
HistogramBucket(le=b) for b in slow_buckets
|
||||
]
|
||||
self._histogram_buckets["memory_embedding_latency_seconds"] = [
|
||||
HistogramBucket(le=b) for b in normal_buckets
|
||||
]
|
||||
|
||||
# Counter methods - Operations
|
||||
|
||||
async def inc_operations(
|
||||
self,
|
||||
operation: str,
|
||||
memory_type: str,
|
||||
scope: str | None = None,
|
||||
success: bool = True,
|
||||
) -> None:
|
||||
"""Increment memory operation counter."""
|
||||
async with self._lock:
|
||||
labels = f"operation={operation},memory_type={memory_type}"
|
||||
if scope:
|
||||
labels += f",scope={scope}"
|
||||
labels += f",success={str(success).lower()}"
|
||||
self._counters["memory_operations_total"][labels] += 1
|
||||
|
||||
async def inc_retrieval(
|
||||
self,
|
||||
memory_type: str,
|
||||
strategy: str,
|
||||
results_count: int,
|
||||
) -> None:
|
||||
"""Increment retrieval counter."""
|
||||
async with self._lock:
|
||||
labels = f"memory_type={memory_type},strategy={strategy}"
|
||||
self._counters["memory_retrievals_total"][labels] += 1
|
||||
|
||||
# Track result counts as a separate metric
|
||||
self._counters["memory_retrieval_results_total"][labels] += results_count
|
||||
|
||||
async def inc_cache_hit(self, cache_type: str) -> None:
|
||||
"""Increment cache hit counter."""
|
||||
async with self._lock:
|
||||
labels = f"cache_type={cache_type}"
|
||||
self._counters["memory_cache_hits_total"][labels] += 1
|
||||
|
||||
async def inc_cache_miss(self, cache_type: str) -> None:
|
||||
"""Increment cache miss counter."""
|
||||
async with self._lock:
|
||||
labels = f"cache_type={cache_type}"
|
||||
self._counters["memory_cache_misses_total"][labels] += 1
|
||||
|
||||
async def inc_consolidation(
|
||||
self,
|
||||
consolidation_type: str,
|
||||
success: bool = True,
|
||||
) -> None:
|
||||
"""Increment consolidation counter."""
|
||||
async with self._lock:
|
||||
labels = f"type={consolidation_type},success={str(success).lower()}"
|
||||
self._counters["memory_consolidations_total"][labels] += 1
|
||||
|
||||
async def inc_procedure_execution(
|
||||
self,
|
||||
procedure_id: str | None = None,
|
||||
success: bool = True,
|
||||
) -> None:
|
||||
"""Increment procedure execution counter."""
|
||||
async with self._lock:
|
||||
labels = f"success={str(success).lower()}"
|
||||
self._counters["memory_procedure_executions_total"][labels] += 1
|
||||
|
||||
async def inc_embeddings_generated(self, memory_type: str) -> None:
|
||||
"""Increment embeddings generated counter."""
|
||||
async with self._lock:
|
||||
labels = f"memory_type={memory_type}"
|
||||
self._counters["memory_embeddings_generated_total"][labels] += 1
|
||||
|
||||
async def inc_fact_reinforcements(self) -> None:
|
||||
"""Increment fact reinforcement counter."""
|
||||
async with self._lock:
|
||||
self._counters["memory_fact_reinforcements_total"][""] += 1
|
||||
|
||||
async def inc_episodes_recorded(self, outcome: str) -> None:
|
||||
"""Increment episodes recorded counter."""
|
||||
async with self._lock:
|
||||
labels = f"outcome={outcome}"
|
||||
self._counters["memory_episodes_recorded_total"][labels] += 1
|
||||
|
||||
async def inc_anomalies_detected(self, anomaly_type: str) -> None:
|
||||
"""Increment anomaly detection counter."""
|
||||
async with self._lock:
|
||||
labels = f"anomaly_type={anomaly_type}"
|
||||
self._counters["memory_anomalies_detected_total"][labels] += 1
|
||||
|
||||
async def inc_patterns_detected(self, pattern_type: str) -> None:
|
||||
"""Increment pattern detection counter."""
|
||||
async with self._lock:
|
||||
labels = f"pattern_type={pattern_type}"
|
||||
self._counters["memory_patterns_detected_total"][labels] += 1
|
||||
|
||||
async def inc_insights_generated(self, insight_type: str) -> None:
|
||||
"""Increment insight generation counter."""
|
||||
async with self._lock:
|
||||
labels = f"insight_type={insight_type}"
|
||||
self._counters["memory_insights_generated_total"][labels] += 1
|
||||
|
||||
# Gauge methods
|
||||
|
||||
async def set_memory_items_count(
|
||||
self,
|
||||
memory_type: str,
|
||||
scope: str,
|
||||
count: int,
|
||||
) -> None:
|
||||
"""Set memory item count gauge."""
|
||||
async with self._lock:
|
||||
labels = f"memory_type={memory_type},scope={scope}"
|
||||
self._gauges["memory_items_count"][labels] = float(count)
|
||||
|
||||
async def set_memory_size_bytes(
|
||||
self,
|
||||
memory_type: str,
|
||||
scope: str,
|
||||
size_bytes: int,
|
||||
) -> None:
|
||||
"""Set memory size gauge in bytes."""
|
||||
async with self._lock:
|
||||
labels = f"memory_type={memory_type},scope={scope}"
|
||||
self._gauges["memory_size_bytes"][labels] = float(size_bytes)
|
||||
|
||||
async def set_cache_size(self, cache_type: str, size: int) -> None:
|
||||
"""Set cache size gauge."""
|
||||
async with self._lock:
|
||||
labels = f"cache_type={cache_type}"
|
||||
self._gauges["memory_cache_size"][labels] = float(size)
|
||||
|
||||
async def set_procedure_success_rate(
|
||||
self,
|
||||
procedure_name: str,
|
||||
rate: float,
|
||||
) -> None:
|
||||
"""Set procedure success rate gauge (0-1)."""
|
||||
async with self._lock:
|
||||
labels = f"procedure_name={procedure_name}"
|
||||
self._gauges["memory_procedure_success_rate"][labels] = rate
|
||||
|
||||
async def set_active_sessions(self, count: int) -> None:
|
||||
"""Set active working memory sessions gauge."""
|
||||
async with self._lock:
|
||||
self._gauges["memory_active_sessions"][""] = float(count)
|
||||
|
||||
async def set_pending_consolidations(self, count: int) -> None:
|
||||
"""Set pending consolidations gauge."""
|
||||
async with self._lock:
|
||||
self._gauges["memory_pending_consolidations"][""] = float(count)
|
||||
|
||||
# Histogram methods
|
||||
|
||||
async def observe_working_latency(self, latency_seconds: float) -> None:
|
||||
"""Observe working memory operation latency."""
|
||||
async with self._lock:
|
||||
self._observe_histogram("memory_working_latency_seconds", latency_seconds)
|
||||
|
||||
async def observe_retrieval_latency(self, latency_seconds: float) -> None:
|
||||
"""Observe retrieval latency."""
|
||||
async with self._lock:
|
||||
self._observe_histogram("memory_retrieval_latency_seconds", latency_seconds)
|
||||
|
||||
async def observe_consolidation_duration(self, duration_seconds: float) -> None:
|
||||
"""Observe consolidation duration."""
|
||||
async with self._lock:
|
||||
self._observe_histogram(
|
||||
"memory_consolidation_duration_seconds", duration_seconds
|
||||
)
|
||||
|
||||
async def observe_embedding_latency(self, latency_seconds: float) -> None:
|
||||
"""Observe embedding generation latency."""
|
||||
async with self._lock:
|
||||
self._observe_histogram("memory_embedding_latency_seconds", latency_seconds)
|
||||
|
||||
def _observe_histogram(self, name: str, value: float) -> None:
|
||||
"""Record a value in a histogram."""
|
||||
self._histograms[name].append(value)
|
||||
|
||||
# Update buckets
|
||||
if name in self._histogram_buckets:
|
||||
for bucket in self._histogram_buckets[name]:
|
||||
if value <= bucket.le:
|
||||
bucket.count += 1
|
||||
|
||||
# Export methods
|
||||
|
||||
async def get_all_metrics(self) -> list[MetricValue]:
|
||||
"""Get all metrics as MetricValue objects."""
|
||||
metrics: list[MetricValue] = []
|
||||
|
||||
async with self._lock:
|
||||
# Export counters
|
||||
for name, counter in self._counters.items():
|
||||
for labels_str, value in counter.items():
|
||||
labels = self._parse_labels(labels_str)
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=name,
|
||||
metric_type=MetricType.COUNTER,
|
||||
value=float(value),
|
||||
labels=labels,
|
||||
)
|
||||
)
|
||||
|
||||
# Export gauges
|
||||
for name, gauge_dict in self._gauges.items():
|
||||
for labels_str, gauge_value in gauge_dict.items():
|
||||
gauge_labels = self._parse_labels(labels_str)
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=name,
|
||||
metric_type=MetricType.GAUGE,
|
||||
value=gauge_value,
|
||||
labels=gauge_labels,
|
||||
)
|
||||
)
|
||||
|
||||
# Export histogram summaries
|
||||
for name, values in self._histograms.items():
|
||||
if values:
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=f"{name}_count",
|
||||
metric_type=MetricType.COUNTER,
|
||||
value=float(len(values)),
|
||||
)
|
||||
)
|
||||
metrics.append(
|
||||
MetricValue(
|
||||
name=f"{name}_sum",
|
||||
metric_type=MetricType.COUNTER,
|
||||
value=sum(values),
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
async def get_prometheus_format(self) -> str:
|
||||
"""Export metrics in Prometheus text format."""
|
||||
lines: list[str] = []
|
||||
|
||||
async with self._lock:
|
||||
# Export counters
|
||||
for name, counter in self._counters.items():
|
||||
lines.append(f"# TYPE {name} counter")
|
||||
for labels_str, value in counter.items():
|
||||
if labels_str:
|
||||
lines.append(f"{name}{{{labels_str}}} {value}")
|
||||
else:
|
||||
lines.append(f"{name} {value}")
|
||||
|
||||
# Export gauges
|
||||
for name, gauge_dict in self._gauges.items():
|
||||
lines.append(f"# TYPE {name} gauge")
|
||||
for labels_str, gauge_value in gauge_dict.items():
|
||||
if labels_str:
|
||||
lines.append(f"{name}{{{labels_str}}} {gauge_value}")
|
||||
else:
|
||||
lines.append(f"{name} {gauge_value}")
|
||||
|
||||
# Export histograms
|
||||
for name, buckets in self._histogram_buckets.items():
|
||||
lines.append(f"# TYPE {name} histogram")
|
||||
for bucket in buckets:
|
||||
le_str = "+Inf" if bucket.le == float("inf") else str(bucket.le)
|
||||
lines.append(f'{name}_bucket{{le="{le_str}"}} {bucket.count}')
|
||||
|
||||
if name in self._histograms:
|
||||
values = self._histograms[name]
|
||||
lines.append(f"{name}_count {len(values)}")
|
||||
lines.append(f"{name}_sum {sum(values)}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def get_summary(self) -> dict[str, Any]:
|
||||
"""Get a summary of key metrics."""
|
||||
async with self._lock:
|
||||
total_operations = sum(self._counters["memory_operations_total"].values())
|
||||
successful_operations = sum(
|
||||
v
|
||||
for k, v in self._counters["memory_operations_total"].items()
|
||||
if "success=true" in k
|
||||
)
|
||||
|
||||
total_retrievals = sum(self._counters["memory_retrievals_total"].values())
|
||||
|
||||
total_cache_hits = sum(self._counters["memory_cache_hits_total"].values())
|
||||
total_cache_misses = sum(
|
||||
self._counters["memory_cache_misses_total"].values()
|
||||
)
|
||||
cache_hit_rate = (
|
||||
total_cache_hits / (total_cache_hits + total_cache_misses)
|
||||
if (total_cache_hits + total_cache_misses) > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
total_consolidations = sum(
|
||||
self._counters["memory_consolidations_total"].values()
|
||||
)
|
||||
|
||||
total_episodes = sum(
|
||||
self._counters["memory_episodes_recorded_total"].values()
|
||||
)
|
||||
|
||||
# Calculate average latencies
|
||||
retrieval_latencies = list(
|
||||
self._histograms.get("memory_retrieval_latency_seconds", deque())
|
||||
)
|
||||
avg_retrieval_latency = (
|
||||
sum(retrieval_latencies) / len(retrieval_latencies)
|
||||
if retrieval_latencies
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_operations": total_operations,
|
||||
"successful_operations": successful_operations,
|
||||
"operation_success_rate": (
|
||||
successful_operations / total_operations
|
||||
if total_operations > 0
|
||||
else 1.0
|
||||
),
|
||||
"total_retrievals": total_retrievals,
|
||||
"cache_hit_rate": cache_hit_rate,
|
||||
"total_consolidations": total_consolidations,
|
||||
"total_episodes_recorded": total_episodes,
|
||||
"avg_retrieval_latency_ms": avg_retrieval_latency * 1000,
|
||||
"patterns_detected": sum(
|
||||
self._counters["memory_patterns_detected_total"].values()
|
||||
),
|
||||
"insights_generated": sum(
|
||||
self._counters["memory_insights_generated_total"].values()
|
||||
),
|
||||
"anomalies_detected": sum(
|
||||
self._counters["memory_anomalies_detected_total"].values()
|
||||
),
|
||||
"active_sessions": self._gauges.get("memory_active_sessions", {}).get(
|
||||
"", 0
|
||||
),
|
||||
"pending_consolidations": self._gauges.get(
|
||||
"memory_pending_consolidations", {}
|
||||
).get("", 0),
|
||||
}
|
||||
|
||||
async def get_cache_stats(self) -> dict[str, Any]:
|
||||
"""Get detailed cache statistics."""
|
||||
async with self._lock:
|
||||
stats: dict[str, Any] = {}
|
||||
|
||||
# Get hits/misses by cache type
|
||||
for labels_str, hits in self._counters["memory_cache_hits_total"].items():
|
||||
cache_type = self._parse_labels(labels_str).get("cache_type", "unknown")
|
||||
if cache_type not in stats:
|
||||
stats[cache_type] = {"hits": 0, "misses": 0}
|
||||
stats[cache_type]["hits"] = hits
|
||||
|
||||
for labels_str, misses in self._counters[
|
||||
"memory_cache_misses_total"
|
||||
].items():
|
||||
cache_type = self._parse_labels(labels_str).get("cache_type", "unknown")
|
||||
if cache_type not in stats:
|
||||
stats[cache_type] = {"hits": 0, "misses": 0}
|
||||
stats[cache_type]["misses"] = misses
|
||||
|
||||
# Calculate hit rates
|
||||
for data in stats.values():
|
||||
total = data["hits"] + data["misses"]
|
||||
data["hit_rate"] = data["hits"] / total if total > 0 else 0.0
|
||||
data["total"] = total
|
||||
|
||||
return stats
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset all metrics."""
|
||||
async with self._lock:
|
||||
self._counters.clear()
|
||||
self._gauges.clear()
|
||||
self._histograms.clear()
|
||||
self._init_histogram_buckets()
|
||||
|
||||
def _parse_labels(self, labels_str: str) -> dict[str, str]:
|
||||
"""Parse labels string into dictionary."""
|
||||
if not labels_str:
|
||||
return {}
|
||||
|
||||
labels = {}
|
||||
for pair in labels_str.split(","):
|
||||
if "=" in pair:
|
||||
key, value = pair.split("=", 1)
|
||||
labels[key.strip()] = value.strip()
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_metrics: MemoryMetrics | None = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_memory_metrics() -> MemoryMetrics:
|
||||
"""Get the singleton MemoryMetrics instance."""
|
||||
global _metrics
|
||||
|
||||
async with _lock:
|
||||
if _metrics is None:
|
||||
_metrics = MemoryMetrics()
|
||||
return _metrics
|
||||
|
||||
|
||||
async def reset_memory_metrics() -> None:
|
||||
"""Reset the singleton instance (for testing)."""
|
||||
global _metrics
|
||||
async with _lock:
|
||||
_metrics = None
|
||||
|
||||
|
||||
# Convenience functions
|
||||
|
||||
|
||||
async def record_memory_operation(
|
||||
operation: str,
|
||||
memory_type: str,
|
||||
scope: str | None = None,
|
||||
success: bool = True,
|
||||
latency_ms: float | None = None,
|
||||
) -> None:
|
||||
"""Record a memory operation."""
|
||||
metrics = await get_memory_metrics()
|
||||
await metrics.inc_operations(operation, memory_type, scope, success)
|
||||
|
||||
if latency_ms is not None and memory_type == "working":
|
||||
await metrics.observe_working_latency(latency_ms / 1000)
|
||||
|
||||
|
||||
async def record_retrieval(
|
||||
memory_type: str,
|
||||
strategy: str,
|
||||
results_count: int,
|
||||
latency_ms: float,
|
||||
) -> None:
|
||||
"""Record a retrieval operation."""
|
||||
metrics = await get_memory_metrics()
|
||||
await metrics.inc_retrieval(memory_type, strategy, results_count)
|
||||
await metrics.observe_retrieval_latency(latency_ms / 1000)
|
||||
22
backend/app/services/memory/procedural/__init__.py
Normal file
22
backend/app/services/memory/procedural/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# app/services/memory/procedural/__init__.py
|
||||
"""
|
||||
Procedural Memory
|
||||
|
||||
Learned skills and procedures from successful task patterns.
|
||||
"""
|
||||
|
||||
from .matching import (
|
||||
MatchContext,
|
||||
MatchResult,
|
||||
ProcedureMatcher,
|
||||
get_procedure_matcher,
|
||||
)
|
||||
from .memory import ProceduralMemory
|
||||
|
||||
__all__ = [
|
||||
"MatchContext",
|
||||
"MatchResult",
|
||||
"ProceduralMemory",
|
||||
"ProcedureMatcher",
|
||||
"get_procedure_matcher",
|
||||
]
|
||||
291
backend/app/services/memory/procedural/matching.py
Normal file
291
backend/app/services/memory/procedural/matching.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# app/services/memory/procedural/matching.py
|
||||
"""
|
||||
Procedure Matching.
|
||||
|
||||
Provides utilities for matching procedures to contexts,
|
||||
ranking procedures by relevance, and suggesting procedures.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from app.services.memory.types import Procedure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchResult:
|
||||
"""Result of a procedure match."""
|
||||
|
||||
procedure: Procedure
|
||||
score: float
|
||||
matched_terms: list[str] = field(default_factory=list)
|
||||
match_type: str = "keyword" # keyword, semantic, pattern
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"procedure_id": str(self.procedure.id),
|
||||
"procedure_name": self.procedure.name,
|
||||
"score": self.score,
|
||||
"matched_terms": self.matched_terms,
|
||||
"match_type": self.match_type,
|
||||
"success_rate": self.procedure.success_rate,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchContext:
|
||||
"""Context for procedure matching."""
|
||||
|
||||
query: str
|
||||
task_type: str | None = None
|
||||
project_id: Any | None = None
|
||||
agent_type_id: Any | None = None
|
||||
max_results: int = 5
|
||||
min_score: float = 0.3
|
||||
require_success_rate: float | None = None
|
||||
|
||||
|
||||
class ProcedureMatcher:
|
||||
"""
|
||||
Matches procedures to contexts using multiple strategies.
|
||||
|
||||
Matching strategies:
|
||||
- Keyword matching on trigger pattern and name
|
||||
- Pattern-based matching using regex
|
||||
- Success rate weighting
|
||||
|
||||
In production, this would be augmented with vector similarity search.
|
||||
"""
|
||||
|
||||
# Common task-related keywords for boosting
|
||||
TASK_KEYWORDS: ClassVar[set[str]] = {
|
||||
"create",
|
||||
"update",
|
||||
"delete",
|
||||
"fix",
|
||||
"implement",
|
||||
"add",
|
||||
"remove",
|
||||
"refactor",
|
||||
"test",
|
||||
"deploy",
|
||||
"configure",
|
||||
"setup",
|
||||
"build",
|
||||
"debug",
|
||||
"optimize",
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the matcher."""
|
||||
self._compiled_patterns: dict[str, re.Pattern[str]] = {}
|
||||
|
||||
def match(
|
||||
self,
|
||||
procedures: list[Procedure],
|
||||
context: MatchContext,
|
||||
) -> list[MatchResult]:
|
||||
"""
|
||||
Match procedures against a context.
|
||||
|
||||
Args:
|
||||
procedures: List of procedures to match
|
||||
context: Matching context
|
||||
|
||||
Returns:
|
||||
List of match results, sorted by score (highest first)
|
||||
"""
|
||||
results: list[MatchResult] = []
|
||||
|
||||
query_terms = self._extract_terms(context.query)
|
||||
query_lower = context.query.lower()
|
||||
|
||||
for procedure in procedures:
|
||||
score, matched = self._calculate_match_score(
|
||||
procedure=procedure,
|
||||
query_terms=query_terms,
|
||||
query_lower=query_lower,
|
||||
context=context,
|
||||
)
|
||||
|
||||
if score >= context.min_score:
|
||||
# Apply success rate boost
|
||||
if context.require_success_rate is not None:
|
||||
if procedure.success_rate < context.require_success_rate:
|
||||
continue
|
||||
|
||||
# Boost score based on success rate
|
||||
success_boost = procedure.success_rate * 0.2
|
||||
final_score = min(1.0, score + success_boost)
|
||||
|
||||
results.append(
|
||||
MatchResult(
|
||||
procedure=procedure,
|
||||
score=final_score,
|
||||
matched_terms=matched,
|
||||
match_type="keyword",
|
||||
)
|
||||
)
|
||||
|
||||
# Sort by score descending
|
||||
results.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
return results[: context.max_results]
|
||||
|
||||
def _extract_terms(self, text: str) -> list[str]:
|
||||
"""Extract searchable terms from text."""
|
||||
# Remove special characters and split
|
||||
clean = re.sub(r"[^\w\s-]", " ", text.lower())
|
||||
terms = clean.split()
|
||||
|
||||
# Filter out very short terms
|
||||
return [t for t in terms if len(t) >= 2]
|
||||
|
||||
def _calculate_match_score(
|
||||
self,
|
||||
procedure: Procedure,
|
||||
query_terms: list[str],
|
||||
query_lower: str,
|
||||
context: MatchContext,
|
||||
) -> tuple[float, list[str]]:
|
||||
"""
|
||||
Calculate match score between procedure and query.
|
||||
|
||||
Returns:
|
||||
Tuple of (score, matched_terms)
|
||||
"""
|
||||
score = 0.0
|
||||
matched: list[str] = []
|
||||
|
||||
trigger_lower = procedure.trigger_pattern.lower()
|
||||
name_lower = procedure.name.lower()
|
||||
|
||||
# Exact name match - high score
|
||||
if name_lower in query_lower or query_lower in name_lower:
|
||||
score += 0.5
|
||||
matched.append(f"name:{procedure.name}")
|
||||
|
||||
# Trigger pattern match
|
||||
if trigger_lower in query_lower or query_lower in trigger_lower:
|
||||
score += 0.4
|
||||
matched.append(f"trigger:{procedure.trigger_pattern[:30]}")
|
||||
|
||||
# Term-by-term matching
|
||||
for term in query_terms:
|
||||
if term in trigger_lower:
|
||||
score += 0.1
|
||||
matched.append(term)
|
||||
elif term in name_lower:
|
||||
score += 0.08
|
||||
matched.append(term)
|
||||
|
||||
# Boost for task keywords
|
||||
if term in self.TASK_KEYWORDS:
|
||||
if term in trigger_lower or term in name_lower:
|
||||
score += 0.05
|
||||
|
||||
# Task type match if provided
|
||||
if context.task_type:
|
||||
task_type_lower = context.task_type.lower()
|
||||
if task_type_lower in trigger_lower or task_type_lower in name_lower:
|
||||
score += 0.3
|
||||
matched.append(f"task_type:{context.task_type}")
|
||||
|
||||
# Regex pattern matching on trigger
|
||||
try:
|
||||
pattern = self._get_or_compile_pattern(trigger_lower)
|
||||
if pattern and pattern.search(query_lower):
|
||||
score += 0.25
|
||||
matched.append("pattern_match")
|
||||
except re.error:
|
||||
pass # Invalid regex, skip pattern matching
|
||||
|
||||
return min(1.0, score), matched
|
||||
|
||||
def _get_or_compile_pattern(self, pattern: str) -> re.Pattern[str] | None:
|
||||
"""Get or compile a regex pattern with caching."""
|
||||
if pattern in self._compiled_patterns:
|
||||
return self._compiled_patterns[pattern]
|
||||
|
||||
# Only compile if it looks like a regex pattern
|
||||
if not any(c in pattern for c in r"\.*+?[]{}|()^$"):
|
||||
return None
|
||||
|
||||
try:
|
||||
compiled = re.compile(pattern, re.IGNORECASE)
|
||||
self._compiled_patterns[pattern] = compiled
|
||||
return compiled
|
||||
except re.error:
|
||||
return None
|
||||
|
||||
def rank_by_relevance(
|
||||
self,
|
||||
procedures: list[Procedure],
|
||||
task_type: str,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Rank procedures by relevance to a task type.
|
||||
|
||||
Args:
|
||||
procedures: Procedures to rank
|
||||
task_type: Task type for relevance
|
||||
|
||||
Returns:
|
||||
Procedures sorted by relevance
|
||||
"""
|
||||
context = MatchContext(
|
||||
query=task_type,
|
||||
task_type=task_type,
|
||||
min_score=0.0,
|
||||
max_results=len(procedures),
|
||||
)
|
||||
|
||||
results = self.match(procedures, context)
|
||||
return [r.procedure for r in results]
|
||||
|
||||
def suggest_procedures(
|
||||
self,
|
||||
procedures: list[Procedure],
|
||||
query: str,
|
||||
min_success_rate: float = 0.5,
|
||||
max_suggestions: int = 3,
|
||||
) -> list[MatchResult]:
|
||||
"""
|
||||
Suggest the best procedures for a query.
|
||||
|
||||
Only suggests procedures with sufficient success rate.
|
||||
|
||||
Args:
|
||||
procedures: Available procedures
|
||||
query: Query/context
|
||||
min_success_rate: Minimum success rate to suggest
|
||||
max_suggestions: Maximum suggestions
|
||||
|
||||
Returns:
|
||||
List of procedure suggestions
|
||||
"""
|
||||
context = MatchContext(
|
||||
query=query,
|
||||
max_results=max_suggestions,
|
||||
min_score=0.2,
|
||||
require_success_rate=min_success_rate,
|
||||
)
|
||||
|
||||
return self.match(procedures, context)
|
||||
|
||||
|
||||
# Singleton matcher instance
|
||||
_matcher: ProcedureMatcher | None = None
|
||||
|
||||
|
||||
def get_procedure_matcher() -> ProcedureMatcher:
|
||||
"""Get the singleton procedure matcher instance."""
|
||||
global _matcher
|
||||
if _matcher is None:
|
||||
_matcher = ProcedureMatcher()
|
||||
return _matcher
|
||||
749
backend/app/services/memory/procedural/memory.py
Normal file
749
backend/app/services/memory/procedural/memory.py
Normal file
@@ -0,0 +1,749 @@
|
||||
# app/services/memory/procedural/memory.py
|
||||
"""
|
||||
Procedural Memory Implementation.
|
||||
|
||||
Provides storage and retrieval for learned procedures (skills)
|
||||
derived from successful task execution patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, desc, or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.procedure import Procedure as ProcedureModel
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.types import Procedure, ProcedureCreate, RetrievalResult, Step
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _escape_like_pattern(pattern: str) -> str:
|
||||
"""
|
||||
Escape SQL LIKE/ILIKE special characters to prevent pattern injection.
|
||||
|
||||
Characters escaped:
|
||||
- % (matches zero or more characters)
|
||||
- _ (matches exactly one character)
|
||||
- \\ (escape character itself)
|
||||
|
||||
Args:
|
||||
pattern: Raw search pattern from user input
|
||||
|
||||
Returns:
|
||||
Escaped pattern safe for use in LIKE/ILIKE queries
|
||||
"""
|
||||
# Escape backslash first, then the wildcards
|
||||
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
|
||||
|
||||
def _model_to_procedure(model: ProcedureModel) -> Procedure:
|
||||
"""Convert SQLAlchemy model to Procedure dataclass."""
|
||||
return Procedure(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
|
||||
name=model.name, # type: ignore[arg-type]
|
||||
trigger_pattern=model.trigger_pattern, # type: ignore[arg-type]
|
||||
steps=model.steps or [], # type: ignore[arg-type]
|
||||
success_count=model.success_count, # type: ignore[arg-type]
|
||||
failure_count=model.failure_count, # type: ignore[arg-type]
|
||||
last_used=model.last_used, # type: ignore[arg-type]
|
||||
embedding=None, # Don't expose raw embedding
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class ProceduralMemory:
|
||||
"""
|
||||
Procedural Memory Service.
|
||||
|
||||
Provides procedure storage and retrieval:
|
||||
- Record procedures from successful task patterns
|
||||
- Find matching procedures by trigger pattern
|
||||
- Track success/failure rates
|
||||
- Get best procedure for a task type
|
||||
- Update procedure steps
|
||||
|
||||
Performance target: <50ms P95 for matching
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize procedural memory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic matching
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._settings = get_memory_settings()
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> "ProceduralMemory":
|
||||
"""
|
||||
Factory method to create ProceduralMemory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
|
||||
Returns:
|
||||
Configured ProceduralMemory instance
|
||||
"""
|
||||
return cls(session=session, embedding_generator=embedding_generator)
|
||||
|
||||
# =========================================================================
|
||||
# Procedure Recording
|
||||
# =========================================================================
|
||||
|
||||
async def record_procedure(self, procedure: ProcedureCreate) -> Procedure:
|
||||
"""
|
||||
Record a new procedure or update an existing one.
|
||||
|
||||
If a procedure with the same name exists in the same scope,
|
||||
its steps will be updated and success count incremented.
|
||||
|
||||
Args:
|
||||
procedure: Procedure data to record
|
||||
|
||||
Returns:
|
||||
The created or updated procedure
|
||||
"""
|
||||
# Check for existing procedure with same name
|
||||
existing = await self._find_existing_procedure(
|
||||
project_id=procedure.project_id,
|
||||
agent_type_id=procedure.agent_type_id,
|
||||
name=procedure.name,
|
||||
)
|
||||
|
||||
if existing is not None:
|
||||
# Update existing procedure
|
||||
return await self._update_existing_procedure(
|
||||
existing=existing,
|
||||
new_steps=procedure.steps,
|
||||
new_trigger=procedure.trigger_pattern,
|
||||
)
|
||||
|
||||
# Create new procedure
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Generate embedding if possible
|
||||
embedding = None
|
||||
if self._embedding_generator is not None:
|
||||
embedding_text = self._create_embedding_text(procedure)
|
||||
embedding = await self._embedding_generator.generate(embedding_text)
|
||||
|
||||
model = ProcedureModel(
|
||||
project_id=procedure.project_id,
|
||||
agent_type_id=procedure.agent_type_id,
|
||||
name=procedure.name,
|
||||
trigger_pattern=procedure.trigger_pattern,
|
||||
steps=procedure.steps,
|
||||
success_count=1, # New procedures start with 1 success (they worked)
|
||||
failure_count=0,
|
||||
last_used=now,
|
||||
embedding=embedding,
|
||||
)
|
||||
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
|
||||
logger.info(
|
||||
f"Recorded new procedure: {procedure.name} with {len(procedure.steps)} steps"
|
||||
)
|
||||
|
||||
return _model_to_procedure(model)
|
||||
|
||||
async def _find_existing_procedure(
|
||||
self,
|
||||
project_id: UUID | None,
|
||||
agent_type_id: UUID | None,
|
||||
name: str,
|
||||
) -> ProcedureModel | None:
|
||||
"""Find an existing procedure with the same name in the same scope."""
|
||||
query = select(ProcedureModel).where(ProcedureModel.name == name)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(ProcedureModel.project_id == project_id)
|
||||
else:
|
||||
query = query.where(ProcedureModel.project_id.is_(None))
|
||||
|
||||
if agent_type_id is not None:
|
||||
query = query.where(ProcedureModel.agent_type_id == agent_type_id)
|
||||
else:
|
||||
query = query.where(ProcedureModel.agent_type_id.is_(None))
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _update_existing_procedure(
|
||||
self,
|
||||
existing: ProcedureModel,
|
||||
new_steps: list[dict[str, Any]],
|
||||
new_trigger: str,
|
||||
) -> Procedure:
|
||||
"""Update an existing procedure with new steps."""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Merge steps intelligently - keep existing order, add new steps
|
||||
merged_steps = self._merge_steps(
|
||||
existing.steps or [], # type: ignore[arg-type]
|
||||
new_steps,
|
||||
)
|
||||
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == existing.id)
|
||||
.values(
|
||||
steps=merged_steps,
|
||||
trigger_pattern=new_trigger,
|
||||
success_count=ProcedureModel.success_count + 1,
|
||||
last_used=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Updated existing procedure: {existing.name}")
|
||||
|
||||
return _model_to_procedure(updated_model)
|
||||
|
||||
def _merge_steps(
|
||||
self,
|
||||
existing_steps: list[dict[str, Any]],
|
||||
new_steps: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Merge steps from a new execution with existing steps."""
|
||||
if not existing_steps:
|
||||
return new_steps
|
||||
if not new_steps:
|
||||
return existing_steps
|
||||
|
||||
# For now, use the new steps if they differ significantly
|
||||
# In production, this could use more sophisticated merging
|
||||
if len(new_steps) != len(existing_steps):
|
||||
# If structure changed, prefer newer steps
|
||||
return new_steps
|
||||
|
||||
# Merge step-by-step, preferring new data where available
|
||||
merged = []
|
||||
for i, new_step in enumerate(new_steps):
|
||||
if i < len(existing_steps):
|
||||
# Merge with existing step
|
||||
step = {**existing_steps[i], **new_step}
|
||||
else:
|
||||
step = new_step
|
||||
merged.append(step)
|
||||
|
||||
return merged
|
||||
|
||||
def _create_embedding_text(self, procedure: ProcedureCreate) -> str:
|
||||
"""Create text for embedding from procedure data."""
|
||||
steps_text = " ".join(step.get("action", "") for step in procedure.steps)
|
||||
return f"{procedure.name} {procedure.trigger_pattern} {steps_text}"
|
||||
|
||||
# =========================================================================
|
||||
# Procedure Retrieval
|
||||
# =========================================================================
|
||||
|
||||
async def find_matching(
|
||||
self,
|
||||
context: str,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
limit: int = 5,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Find procedures matching the given context.
|
||||
|
||||
Args:
|
||||
context: Context/trigger to match against
|
||||
project_id: Optional project to search within
|
||||
agent_type_id: Optional agent type filter
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of matching procedures
|
||||
"""
|
||||
result = await self._find_matching_with_metadata(
|
||||
context=context,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=limit,
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def _find_matching_with_metadata(
|
||||
self,
|
||||
context: str,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
limit: int = 5,
|
||||
) -> RetrievalResult[Procedure]:
|
||||
"""Find matching procedures with full result metadata."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Build base query - prioritize by success rate
|
||||
stmt = (
|
||||
select(ProcedureModel)
|
||||
.order_by(
|
||||
desc(
|
||||
ProcedureModel.success_count
|
||||
/ (ProcedureModel.success_count + ProcedureModel.failure_count + 1)
|
||||
),
|
||||
desc(ProcedureModel.last_used),
|
||||
)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
# Apply scope filters
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
# Text-based matching on trigger pattern and name
|
||||
# TODO: Implement proper vector similarity search when pgvector is integrated
|
||||
search_terms = context.lower().split()[:5] # Limit to 5 terms
|
||||
if search_terms:
|
||||
conditions = []
|
||||
for term in search_terms:
|
||||
# Escape SQL wildcards to prevent pattern injection
|
||||
escaped_term = _escape_like_pattern(term)
|
||||
term_pattern = f"%{escaped_term}%"
|
||||
conditions.append(
|
||||
or_(
|
||||
ProcedureModel.trigger_pattern.ilike(term_pattern),
|
||||
ProcedureModel.name.ilike(term_pattern),
|
||||
)
|
||||
)
|
||||
if conditions:
|
||||
stmt = stmt.where(or_(*conditions))
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_procedure(m) for m in models],
|
||||
total_count=len(models),
|
||||
query=context,
|
||||
retrieval_type="procedural",
|
||||
latency_ms=latency_ms,
|
||||
metadata={"project_id": str(project_id) if project_id else None},
|
||||
)
|
||||
|
||||
async def get_best_procedure(
|
||||
self,
|
||||
task_type: str,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
min_success_rate: float = 0.5,
|
||||
min_uses: int = 1,
|
||||
) -> Procedure | None:
|
||||
"""
|
||||
Get the best procedure for a given task type.
|
||||
|
||||
Returns the procedure with the highest success rate that
|
||||
meets the minimum thresholds.
|
||||
|
||||
Args:
|
||||
task_type: Task type to find procedure for
|
||||
project_id: Optional project scope
|
||||
agent_type_id: Optional agent type scope
|
||||
min_success_rate: Minimum required success rate
|
||||
min_uses: Minimum number of uses required
|
||||
|
||||
Returns:
|
||||
Best matching procedure or None
|
||||
"""
|
||||
# Escape SQL wildcards to prevent pattern injection
|
||||
escaped_task_type = _escape_like_pattern(task_type)
|
||||
task_type_pattern = f"%{escaped_task_type}%"
|
||||
|
||||
# Build query for procedures matching task type
|
||||
stmt = (
|
||||
select(ProcedureModel)
|
||||
.where(
|
||||
and_(
|
||||
(ProcedureModel.success_count + ProcedureModel.failure_count)
|
||||
>= min_uses,
|
||||
or_(
|
||||
ProcedureModel.trigger_pattern.ilike(task_type_pattern),
|
||||
ProcedureModel.name.ilike(task_type_pattern),
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(
|
||||
desc(
|
||||
ProcedureModel.success_count
|
||||
/ (ProcedureModel.success_count + ProcedureModel.failure_count + 1)
|
||||
),
|
||||
desc(ProcedureModel.last_used),
|
||||
)
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
# Apply scope filters
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Filter by success rate in Python (SQLAlchemy division in WHERE is complex)
|
||||
for model in models:
|
||||
success = float(model.success_count)
|
||||
failure = float(model.failure_count)
|
||||
total = success + failure
|
||||
if total > 0 and (success / total) >= min_success_rate:
|
||||
logger.debug(
|
||||
f"Found best procedure for '{task_type}': {model.name} "
|
||||
f"(success_rate={success / total:.2%})"
|
||||
)
|
||||
return _model_to_procedure(model)
|
||||
|
||||
return None
|
||||
|
||||
async def get_by_id(self, procedure_id: UUID) -> Procedure | None:
|
||||
"""Get a procedure by ID."""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
return _model_to_procedure(model) if model else None
|
||||
|
||||
# =========================================================================
|
||||
# Outcome Recording
|
||||
# =========================================================================
|
||||
|
||||
async def record_outcome(
|
||||
self,
|
||||
procedure_id: UUID,
|
||||
success: bool,
|
||||
) -> Procedure:
|
||||
"""
|
||||
Record the outcome of using a procedure.
|
||||
|
||||
Updates the success or failure count and last_used timestamp.
|
||||
|
||||
Args:
|
||||
procedure_id: Procedure that was used
|
||||
success: Whether the procedure succeeded
|
||||
|
||||
Returns:
|
||||
Updated procedure
|
||||
|
||||
Raises:
|
||||
ValueError: If procedure not found
|
||||
"""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"Procedure not found: {procedure_id}")
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
if success:
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == procedure_id)
|
||||
.values(
|
||||
success_count=ProcedureModel.success_count + 1,
|
||||
last_used=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
else:
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == procedure_id)
|
||||
.values(
|
||||
failure_count=ProcedureModel.failure_count + 1,
|
||||
last_used=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
outcome = "success" if success else "failure"
|
||||
logger.info(
|
||||
f"Recorded {outcome} for procedure {procedure_id}: "
|
||||
f"success_rate={updated_model.success_rate:.2%}"
|
||||
)
|
||||
|
||||
return _model_to_procedure(updated_model)
|
||||
|
||||
# =========================================================================
|
||||
# Step Management
|
||||
# =========================================================================
|
||||
|
||||
async def update_steps(
|
||||
self,
|
||||
procedure_id: UUID,
|
||||
steps: list[Step],
|
||||
) -> Procedure:
|
||||
"""
|
||||
Update the steps of a procedure.
|
||||
|
||||
Args:
|
||||
procedure_id: Procedure to update
|
||||
steps: New steps
|
||||
|
||||
Returns:
|
||||
Updated procedure
|
||||
|
||||
Raises:
|
||||
ValueError: If procedure not found
|
||||
"""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"Procedure not found: {procedure_id}")
|
||||
|
||||
# Convert Step objects to dictionaries
|
||||
steps_dict = [
|
||||
{
|
||||
"order": step.order,
|
||||
"action": step.action,
|
||||
"parameters": step.parameters,
|
||||
"expected_outcome": step.expected_outcome,
|
||||
"fallback_action": step.fallback_action,
|
||||
}
|
||||
for step in steps
|
||||
]
|
||||
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == procedure_id)
|
||||
.values(
|
||||
steps=steps_dict,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Updated steps for procedure {procedure_id}: {len(steps)} steps")
|
||||
|
||||
return _model_to_procedure(updated_model)
|
||||
|
||||
# =========================================================================
|
||||
# Statistics & Management
|
||||
# =========================================================================
|
||||
|
||||
async def get_stats(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about procedural memory.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to get stats for
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics
|
||||
"""
|
||||
query = select(ProcedureModel)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
if not models:
|
||||
return {
|
||||
"total_procedures": 0,
|
||||
"avg_success_rate": 0.0,
|
||||
"avg_steps_count": 0.0,
|
||||
"total_uses": 0,
|
||||
"high_success_count": 0,
|
||||
"low_success_count": 0,
|
||||
}
|
||||
|
||||
success_rates = [m.success_rate for m in models]
|
||||
step_counts = [len(m.steps or []) for m in models]
|
||||
total_uses = sum(m.total_uses for m in models)
|
||||
|
||||
return {
|
||||
"total_procedures": len(models),
|
||||
"avg_success_rate": sum(success_rates) / len(success_rates),
|
||||
"avg_steps_count": sum(step_counts) / len(step_counts),
|
||||
"total_uses": total_uses,
|
||||
"high_success_count": sum(1 for r in success_rates if r >= 0.8),
|
||||
"low_success_count": sum(1 for r in success_rates if r < 0.5),
|
||||
}
|
||||
|
||||
async def count(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Count procedures in scope.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to count for
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
Number of procedures
|
||||
"""
|
||||
query = select(ProcedureModel)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return len(list(result.scalars().all()))
|
||||
|
||||
async def delete(self, procedure_id: UUID) -> bool:
|
||||
"""
|
||||
Delete a procedure.
|
||||
|
||||
Args:
|
||||
procedure_id: Procedure to delete
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return False
|
||||
|
||||
await self._session.delete(model)
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Deleted procedure {procedure_id}")
|
||||
return True
|
||||
|
||||
async def get_procedures_by_success_rate(
|
||||
self,
|
||||
min_rate: float = 0.0,
|
||||
max_rate: float = 1.0,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Get procedures within a success rate range.
|
||||
|
||||
Args:
|
||||
min_rate: Minimum success rate
|
||||
max_rate: Maximum success rate
|
||||
project_id: Optional project scope
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of procedures
|
||||
"""
|
||||
query = (
|
||||
select(ProcedureModel)
|
||||
.order_by(desc(ProcedureModel.last_used))
|
||||
.limit(limit * 2) # Fetch more since we filter in Python
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Filter by success rate in Python
|
||||
filtered = [m for m in models if min_rate <= m.success_rate <= max_rate][:limit]
|
||||
|
||||
return [_model_to_procedure(m) for m in filtered]
|
||||
38
backend/app/services/memory/reflection/__init__.py
Normal file
38
backend/app/services/memory/reflection/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# app/services/memory/reflection/__init__.py
|
||||
"""
|
||||
Memory Reflection Layer.
|
||||
|
||||
Analyzes patterns in agent experiences to generate actionable insights.
|
||||
"""
|
||||
|
||||
from .service import (
|
||||
MemoryReflection,
|
||||
ReflectionConfig,
|
||||
get_memory_reflection,
|
||||
)
|
||||
from .types import (
|
||||
Anomaly,
|
||||
AnomalyType,
|
||||
Factor,
|
||||
FactorType,
|
||||
Insight,
|
||||
InsightType,
|
||||
Pattern,
|
||||
PatternType,
|
||||
TimeRange,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Anomaly",
|
||||
"AnomalyType",
|
||||
"Factor",
|
||||
"FactorType",
|
||||
"Insight",
|
||||
"InsightType",
|
||||
"MemoryReflection",
|
||||
"Pattern",
|
||||
"PatternType",
|
||||
"ReflectionConfig",
|
||||
"TimeRange",
|
||||
"get_memory_reflection",
|
||||
]
|
||||
1451
backend/app/services/memory/reflection/service.py
Normal file
1451
backend/app/services/memory/reflection/service.py
Normal file
File diff suppressed because it is too large
Load Diff
304
backend/app/services/memory/reflection/types.py
Normal file
304
backend/app/services/memory/reflection/types.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# app/services/memory/reflection/types.py
|
||||
"""
|
||||
Memory Reflection Types.
|
||||
|
||||
Type definitions for pattern detection, anomaly detection, and insights.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class PatternType(str, Enum):
|
||||
"""Types of patterns detected in episodic memory."""
|
||||
|
||||
RECURRING_SUCCESS = "recurring_success"
|
||||
RECURRING_FAILURE = "recurring_failure"
|
||||
ACTION_SEQUENCE = "action_sequence"
|
||||
CONTEXT_CORRELATION = "context_correlation"
|
||||
TEMPORAL = "temporal"
|
||||
EFFICIENCY = "efficiency"
|
||||
|
||||
|
||||
class FactorType(str, Enum):
|
||||
"""Types of factors contributing to outcomes."""
|
||||
|
||||
ACTION = "action"
|
||||
CONTEXT = "context"
|
||||
TIMING = "timing"
|
||||
RESOURCE = "resource"
|
||||
PRECEDING_STATE = "preceding_state"
|
||||
|
||||
|
||||
class AnomalyType(str, Enum):
|
||||
"""Types of anomalies detected."""
|
||||
|
||||
UNUSUAL_DURATION = "unusual_duration"
|
||||
UNEXPECTED_OUTCOME = "unexpected_outcome"
|
||||
UNUSUAL_TOKEN_USAGE = "unusual_token_usage"
|
||||
UNUSUAL_FAILURE_RATE = "unusual_failure_rate"
|
||||
UNUSUAL_ACTION_PATTERN = "unusual_action_pattern"
|
||||
|
||||
|
||||
class InsightType(str, Enum):
|
||||
"""Types of insights generated."""
|
||||
|
||||
OPTIMIZATION = "optimization"
|
||||
WARNING = "warning"
|
||||
LEARNING = "learning"
|
||||
RECOMMENDATION = "recommendation"
|
||||
TREND = "trend"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeRange:
|
||||
"""Time range for reflection analysis."""
|
||||
|
||||
start: datetime
|
||||
end: datetime
|
||||
|
||||
@classmethod
|
||||
def last_hours(cls, hours: int = 24) -> "TimeRange":
|
||||
"""Create time range for last N hours."""
|
||||
end = _utcnow()
|
||||
start = datetime(
|
||||
end.year, end.month, end.day, end.hour, end.minute, end.second, tzinfo=UTC
|
||||
) - __import__("datetime").timedelta(hours=hours)
|
||||
return cls(start=start, end=end)
|
||||
|
||||
@classmethod
|
||||
def last_days(cls, days: int = 7) -> "TimeRange":
|
||||
"""Create time range for last N days."""
|
||||
from datetime import timedelta
|
||||
|
||||
end = _utcnow()
|
||||
start = end - timedelta(days=days)
|
||||
return cls(start=start, end=end)
|
||||
|
||||
@property
|
||||
def duration_hours(self) -> float:
|
||||
"""Get duration in hours."""
|
||||
return (self.end - self.start).total_seconds() / 3600
|
||||
|
||||
@property
|
||||
def duration_days(self) -> float:
|
||||
"""Get duration in days."""
|
||||
return (self.end - self.start).total_seconds() / 86400
|
||||
|
||||
|
||||
@dataclass
|
||||
class Pattern:
|
||||
"""A detected pattern in episodic memory."""
|
||||
|
||||
id: UUID
|
||||
pattern_type: PatternType
|
||||
name: str
|
||||
description: str
|
||||
confidence: float
|
||||
occurrence_count: int
|
||||
episode_ids: list[UUID]
|
||||
first_seen: datetime
|
||||
last_seen: datetime
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def frequency(self) -> float:
|
||||
"""Calculate pattern frequency per day."""
|
||||
duration_days = (self.last_seen - self.first_seen).total_seconds() / 86400
|
||||
if duration_days < 1:
|
||||
duration_days = 1
|
||||
return self.occurrence_count / duration_days
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"pattern_type": self.pattern_type.value,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"confidence": self.confidence,
|
||||
"occurrence_count": self.occurrence_count,
|
||||
"episode_ids": [str(eid) for eid in self.episode_ids],
|
||||
"first_seen": self.first_seen.isoformat(),
|
||||
"last_seen": self.last_seen.isoformat(),
|
||||
"frequency": self.frequency,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Factor:
|
||||
"""A factor contributing to success or failure."""
|
||||
|
||||
id: UUID
|
||||
factor_type: FactorType
|
||||
name: str
|
||||
description: str
|
||||
impact_score: float
|
||||
correlation: float
|
||||
sample_size: int
|
||||
positive_examples: list[UUID]
|
||||
negative_examples: list[UUID]
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def net_impact(self) -> float:
|
||||
"""Calculate net impact considering sample size."""
|
||||
# Weight impact by sample confidence
|
||||
confidence_weight = min(1.0, self.sample_size / 20)
|
||||
return self.impact_score * self.correlation * confidence_weight
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"factor_type": self.factor_type.value,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"impact_score": self.impact_score,
|
||||
"correlation": self.correlation,
|
||||
"sample_size": self.sample_size,
|
||||
"positive_examples": [str(eid) for eid in self.positive_examples],
|
||||
"negative_examples": [str(eid) for eid in self.negative_examples],
|
||||
"net_impact": self.net_impact,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Anomaly:
|
||||
"""An anomaly detected in memory patterns."""
|
||||
|
||||
id: UUID
|
||||
anomaly_type: AnomalyType
|
||||
description: str
|
||||
severity: float
|
||||
episode_ids: list[UUID]
|
||||
detected_at: datetime
|
||||
baseline_value: float
|
||||
observed_value: float
|
||||
deviation_factor: float
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def is_critical(self) -> bool:
|
||||
"""Check if anomaly is critical (severity > 0.8)."""
|
||||
return self.severity > 0.8
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"anomaly_type": self.anomaly_type.value,
|
||||
"description": self.description,
|
||||
"severity": self.severity,
|
||||
"episode_ids": [str(eid) for eid in self.episode_ids],
|
||||
"detected_at": self.detected_at.isoformat(),
|
||||
"baseline_value": self.baseline_value,
|
||||
"observed_value": self.observed_value,
|
||||
"deviation_factor": self.deviation_factor,
|
||||
"is_critical": self.is_critical,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Insight:
|
||||
"""An actionable insight generated from reflection."""
|
||||
|
||||
id: UUID
|
||||
insight_type: InsightType
|
||||
title: str
|
||||
description: str
|
||||
priority: float
|
||||
confidence: float
|
||||
source_patterns: list[UUID]
|
||||
source_factors: list[UUID]
|
||||
source_anomalies: list[UUID]
|
||||
recommended_actions: list[str]
|
||||
generated_at: datetime
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def actionable_score(self) -> float:
|
||||
"""Calculate how actionable this insight is."""
|
||||
action_weight = min(1.0, len(self.recommended_actions) / 3)
|
||||
return self.priority * self.confidence * action_weight
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"insight_type": self.insight_type.value,
|
||||
"title": self.title,
|
||||
"description": self.description,
|
||||
"priority": self.priority,
|
||||
"confidence": self.confidence,
|
||||
"source_patterns": [str(pid) for pid in self.source_patterns],
|
||||
"source_factors": [str(fid) for fid in self.source_factors],
|
||||
"source_anomalies": [str(aid) for aid in self.source_anomalies],
|
||||
"recommended_actions": self.recommended_actions,
|
||||
"generated_at": self.generated_at.isoformat(),
|
||||
"actionable_score": self.actionable_score,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReflectionResult:
|
||||
"""Result of a reflection operation."""
|
||||
|
||||
patterns: list[Pattern]
|
||||
factors: list[Factor]
|
||||
anomalies: list[Anomaly]
|
||||
insights: list[Insight]
|
||||
time_range: TimeRange
|
||||
episodes_analyzed: int
|
||||
analysis_duration_seconds: float
|
||||
generated_at: datetime = field(default_factory=_utcnow)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"patterns": [p.to_dict() for p in self.patterns],
|
||||
"factors": [f.to_dict() for f in self.factors],
|
||||
"anomalies": [a.to_dict() for a in self.anomalies],
|
||||
"insights": [i.to_dict() for i in self.insights],
|
||||
"time_range": {
|
||||
"start": self.time_range.start.isoformat(),
|
||||
"end": self.time_range.end.isoformat(),
|
||||
"duration_hours": self.time_range.duration_hours,
|
||||
},
|
||||
"episodes_analyzed": self.episodes_analyzed,
|
||||
"analysis_duration_seconds": self.analysis_duration_seconds,
|
||||
"generated_at": self.generated_at.isoformat(),
|
||||
}
|
||||
|
||||
@property
|
||||
def summary(self) -> str:
|
||||
"""Generate a summary of the reflection results."""
|
||||
lines = [
|
||||
f"Reflection Analysis ({self.time_range.duration_days:.1f} days)",
|
||||
f"Episodes analyzed: {self.episodes_analyzed}",
|
||||
"",
|
||||
f"Patterns detected: {len(self.patterns)}",
|
||||
f"Success/failure factors: {len(self.factors)}",
|
||||
f"Anomalies found: {len(self.anomalies)}",
|
||||
f"Insights generated: {len(self.insights)}",
|
||||
]
|
||||
|
||||
if self.insights:
|
||||
lines.append("")
|
||||
lines.append("Top insights:")
|
||||
for insight in sorted(self.insights, key=lambda i: -i.priority)[:3]:
|
||||
lines.append(f" - [{insight.insight_type.value}] {insight.title}")
|
||||
|
||||
return "\n".join(lines)
|
||||
33
backend/app/services/memory/scoping/__init__.py
Normal file
33
backend/app/services/memory/scoping/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# app/services/memory/scoping/__init__.py
|
||||
"""
|
||||
Memory Scoping
|
||||
|
||||
Hierarchical scoping for memory with inheritance:
|
||||
Global -> Project -> Agent Type -> Agent Instance -> Session
|
||||
"""
|
||||
|
||||
from .resolver import (
|
||||
ResolutionOptions,
|
||||
ResolutionResult,
|
||||
ScopeFilter,
|
||||
ScopeResolver,
|
||||
get_scope_resolver,
|
||||
)
|
||||
from .scope import (
|
||||
ScopeInfo,
|
||||
ScopeManager,
|
||||
ScopePolicy,
|
||||
get_scope_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ResolutionOptions",
|
||||
"ResolutionResult",
|
||||
"ScopeFilter",
|
||||
"ScopeInfo",
|
||||
"ScopeManager",
|
||||
"ScopePolicy",
|
||||
"ScopeResolver",
|
||||
"get_scope_manager",
|
||||
"get_scope_resolver",
|
||||
]
|
||||
390
backend/app/services/memory/scoping/resolver.py
Normal file
390
backend/app/services/memory/scoping/resolver.py
Normal file
@@ -0,0 +1,390 @@
|
||||
# app/services/memory/scoping/resolver.py
|
||||
"""
|
||||
Scope Resolution.
|
||||
|
||||
Provides utilities for resolving memory queries across scope hierarchies,
|
||||
implementing inheritance and aggregation of memories from parent scopes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from app.services.memory.types import ScopeContext, ScopeLevel
|
||||
|
||||
from .scope import ScopeManager, get_scope_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolutionResult[T]:
|
||||
"""Result of a scope resolution."""
|
||||
|
||||
items: list[T]
|
||||
sources: list[ScopeContext]
|
||||
total_from_each: dict[str, int] = field(default_factory=dict)
|
||||
inherited_count: int = 0
|
||||
own_count: int = 0
|
||||
|
||||
@property
|
||||
def total_count(self) -> int:
|
||||
"""Get total items from all sources."""
|
||||
return len(self.items)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolutionOptions:
|
||||
"""Options for scope resolution."""
|
||||
|
||||
include_inherited: bool = True
|
||||
max_inheritance_depth: int = 5
|
||||
limit_per_scope: int = 100
|
||||
total_limit: int = 500
|
||||
deduplicate: bool = True
|
||||
deduplicate_key: str | None = None # Field to use for deduplication
|
||||
|
||||
|
||||
class ScopeResolver:
|
||||
"""
|
||||
Resolves memory queries across scope hierarchies.
|
||||
|
||||
Features:
|
||||
- Traverse scope hierarchy for inherited memories
|
||||
- Aggregate results from multiple scope levels
|
||||
- Apply access control policies
|
||||
- Support deduplication across scopes
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: ScopeManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the resolver.
|
||||
|
||||
Args:
|
||||
manager: Scope manager to use (defaults to singleton)
|
||||
"""
|
||||
self._manager = manager or get_scope_manager()
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
fetcher: Callable[[ScopeContext, int], list[T]],
|
||||
options: ResolutionOptions | None = None,
|
||||
) -> ResolutionResult[T]:
|
||||
"""
|
||||
Resolve memories for a scope, including inherited memories.
|
||||
|
||||
Args:
|
||||
scope: Starting scope
|
||||
fetcher: Function to fetch items for a scope (scope, limit) -> items
|
||||
options: Resolution options
|
||||
|
||||
Returns:
|
||||
Resolution result with items from all scopes
|
||||
"""
|
||||
opts = options or ResolutionOptions()
|
||||
|
||||
all_items: list[T] = []
|
||||
sources: list[ScopeContext] = []
|
||||
counts: dict[str, int] = {}
|
||||
seen_keys: set[Any] = set()
|
||||
|
||||
# Collect scopes to query (starting from current, going up to ancestors)
|
||||
scopes_to_query = self._collect_queryable_scopes(
|
||||
scope=scope,
|
||||
max_depth=opts.max_inheritance_depth if opts.include_inherited else 0,
|
||||
)
|
||||
|
||||
own_count = 0
|
||||
inherited_count = 0
|
||||
remaining_limit = opts.total_limit
|
||||
|
||||
for i, query_scope in enumerate(scopes_to_query):
|
||||
if remaining_limit <= 0:
|
||||
break
|
||||
|
||||
# Check access policy
|
||||
policy = self._manager.get_policy(query_scope)
|
||||
if not policy.allows_read():
|
||||
continue
|
||||
if i > 0 and not policy.allows_inherit():
|
||||
continue
|
||||
|
||||
# Fetch items for this scope
|
||||
scope_limit = min(opts.limit_per_scope, remaining_limit)
|
||||
items = fetcher(query_scope, scope_limit)
|
||||
|
||||
# Apply deduplication
|
||||
if opts.deduplicate and opts.deduplicate_key:
|
||||
items = self._deduplicate(items, opts.deduplicate_key, seen_keys)
|
||||
|
||||
if items:
|
||||
all_items.extend(items)
|
||||
sources.append(query_scope)
|
||||
key = f"{query_scope.scope_type.value}:{query_scope.scope_id}"
|
||||
counts[key] = len(items)
|
||||
|
||||
if i == 0:
|
||||
own_count = len(items)
|
||||
else:
|
||||
inherited_count += len(items)
|
||||
|
||||
remaining_limit -= len(items)
|
||||
|
||||
logger.debug(
|
||||
f"Resolved {len(all_items)} items from {len(sources)} scopes "
|
||||
f"(own={own_count}, inherited={inherited_count})"
|
||||
)
|
||||
|
||||
return ResolutionResult(
|
||||
items=all_items[: opts.total_limit],
|
||||
sources=sources,
|
||||
total_from_each=counts,
|
||||
own_count=own_count,
|
||||
inherited_count=inherited_count,
|
||||
)
|
||||
|
||||
def _collect_queryable_scopes(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
max_depth: int,
|
||||
) -> list[ScopeContext]:
|
||||
"""Collect scopes to query, from current to ancestors."""
|
||||
scopes: list[ScopeContext] = [scope]
|
||||
|
||||
if max_depth <= 0:
|
||||
return scopes
|
||||
|
||||
current = scope.parent
|
||||
depth = 0
|
||||
|
||||
while current is not None and depth < max_depth:
|
||||
scopes.append(current)
|
||||
current = current.parent
|
||||
depth += 1
|
||||
|
||||
return scopes
|
||||
|
||||
def _deduplicate(
|
||||
self,
|
||||
items: list[T],
|
||||
key_field: str,
|
||||
seen_keys: set[Any],
|
||||
) -> list[T]:
|
||||
"""Remove duplicate items based on a key field."""
|
||||
unique: list[T] = []
|
||||
|
||||
for item in items:
|
||||
key = getattr(item, key_field, None)
|
||||
if key is None:
|
||||
# If no key, include the item
|
||||
unique.append(item)
|
||||
elif key not in seen_keys:
|
||||
seen_keys.add(key)
|
||||
unique.append(item)
|
||||
|
||||
return unique
|
||||
|
||||
def get_visible_scopes(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
) -> list[ScopeContext]:
|
||||
"""
|
||||
Get all scopes visible from a given scope.
|
||||
|
||||
A scope can see itself and all its ancestors (if inheritance allowed).
|
||||
|
||||
Args:
|
||||
scope: Starting scope
|
||||
|
||||
Returns:
|
||||
List of visible scopes (from most specific to most general)
|
||||
"""
|
||||
visible = [scope]
|
||||
|
||||
current = scope.parent
|
||||
while current is not None:
|
||||
policy = self._manager.get_policy(current)
|
||||
if policy.allows_inherit():
|
||||
visible.append(current)
|
||||
else:
|
||||
break # Stop at first non-inheritable scope
|
||||
current = current.parent
|
||||
|
||||
return visible
|
||||
|
||||
def find_write_scope(
|
||||
self,
|
||||
target_level: ScopeLevel,
|
||||
scope: ScopeContext,
|
||||
) -> ScopeContext | None:
|
||||
"""
|
||||
Find the appropriate scope for writing at a target level.
|
||||
|
||||
Walks up the hierarchy to find a scope at the target level
|
||||
that allows writing.
|
||||
|
||||
Args:
|
||||
target_level: Desired scope level
|
||||
scope: Starting scope
|
||||
|
||||
Returns:
|
||||
Scope to write to, or None if not found/not allowed
|
||||
"""
|
||||
# First check if current scope is at target level
|
||||
if scope.scope_type == target_level:
|
||||
policy = self._manager.get_policy(scope)
|
||||
return scope if policy.allows_write() else None
|
||||
|
||||
# Check ancestors
|
||||
current = scope.parent
|
||||
while current is not None:
|
||||
if current.scope_type == target_level:
|
||||
policy = self._manager.get_policy(current)
|
||||
return current if policy.allows_write() else None
|
||||
current = current.parent
|
||||
|
||||
return None
|
||||
|
||||
def resolve_scope_from_memory(
|
||||
self,
|
||||
memory_type: str,
|
||||
project_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
agent_instance_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> tuple[ScopeContext, ScopeLevel]:
|
||||
"""
|
||||
Resolve the appropriate scope for a memory operation.
|
||||
|
||||
Different memory types have different scope requirements:
|
||||
- working: Session or Agent Instance
|
||||
- episodic: Agent Instance or Project
|
||||
- semantic: Project or Global
|
||||
- procedural: Agent Type or Project
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory
|
||||
project_id: Project ID
|
||||
agent_type_id: Agent type ID
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
Tuple of (scope context, recommended level)
|
||||
"""
|
||||
# Build full scope chain
|
||||
scope = self._manager.create_scope_from_ids(
|
||||
project_id=project_id if project_id else None, # type: ignore[arg-type]
|
||||
agent_type_id=agent_type_id if agent_type_id else None, # type: ignore[arg-type]
|
||||
agent_instance_id=agent_instance_id if agent_instance_id else None, # type: ignore[arg-type]
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Determine recommended level based on memory type
|
||||
recommended = self._get_recommended_level(memory_type)
|
||||
|
||||
return scope, recommended
|
||||
|
||||
def _get_recommended_level(self, memory_type: str) -> ScopeLevel:
|
||||
"""Get recommended scope level for a memory type."""
|
||||
recommendations = {
|
||||
"working": ScopeLevel.SESSION,
|
||||
"episodic": ScopeLevel.AGENT_INSTANCE,
|
||||
"semantic": ScopeLevel.PROJECT,
|
||||
"procedural": ScopeLevel.AGENT_TYPE,
|
||||
}
|
||||
return recommendations.get(memory_type, ScopeLevel.PROJECT)
|
||||
|
||||
def validate_write_access(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
memory_type: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that writing is allowed for the given scope and memory type.
|
||||
|
||||
Args:
|
||||
scope: Scope to validate
|
||||
memory_type: Type of memory to write
|
||||
|
||||
Returns:
|
||||
True if write is allowed
|
||||
"""
|
||||
policy = self._manager.get_policy(scope)
|
||||
|
||||
if not policy.allows_write():
|
||||
return False
|
||||
|
||||
if not policy.allows_memory_type(memory_type):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_scope_chain(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
) -> list[tuple[ScopeLevel, str]]:
|
||||
"""
|
||||
Get the scope chain as a list of (level, id) tuples.
|
||||
|
||||
Args:
|
||||
scope: Scope to get chain for
|
||||
|
||||
Returns:
|
||||
List of (level, id) tuples from root to leaf
|
||||
"""
|
||||
chain: list[tuple[ScopeLevel, str]] = []
|
||||
|
||||
# Get full hierarchy
|
||||
hierarchy = scope.get_hierarchy()
|
||||
for ctx in hierarchy:
|
||||
chain.append((ctx.scope_type, ctx.scope_id))
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScopeFilter:
|
||||
"""Filter for querying across scopes."""
|
||||
|
||||
scope_types: list[ScopeLevel] | None = None
|
||||
project_ids: list[str] | None = None
|
||||
agent_type_ids: list[str] | None = None
|
||||
include_global: bool = True
|
||||
|
||||
def matches(self, scope: ScopeContext) -> bool:
|
||||
"""Check if a scope matches this filter."""
|
||||
if self.scope_types and scope.scope_type not in self.scope_types:
|
||||
return False
|
||||
|
||||
if scope.scope_type == ScopeLevel.GLOBAL:
|
||||
return self.include_global
|
||||
|
||||
if scope.scope_type == ScopeLevel.PROJECT:
|
||||
if self.project_ids and scope.scope_id not in self.project_ids:
|
||||
return False
|
||||
|
||||
if scope.scope_type == ScopeLevel.AGENT_TYPE:
|
||||
if self.agent_type_ids and scope.scope_id not in self.agent_type_ids:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Singleton resolver instance
|
||||
_resolver: ScopeResolver | None = None
|
||||
|
||||
|
||||
def get_scope_resolver() -> ScopeResolver:
|
||||
"""Get the singleton scope resolver instance."""
|
||||
global _resolver
|
||||
if _resolver is None:
|
||||
_resolver = ScopeResolver()
|
||||
return _resolver
|
||||
472
backend/app/services/memory/scoping/scope.py
Normal file
472
backend/app/services/memory/scoping/scope.py
Normal file
@@ -0,0 +1,472 @@
|
||||
# app/services/memory/scoping/scope.py
|
||||
"""
|
||||
Scope Management.
|
||||
|
||||
Provides utilities for managing memory scopes with hierarchical inheritance:
|
||||
Global -> Project -> Agent Type -> Agent Instance -> Session
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.types import ScopeContext, ScopeLevel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScopePolicy:
|
||||
"""Access control policy for a scope."""
|
||||
|
||||
scope_type: ScopeLevel
|
||||
scope_id: str
|
||||
can_read: bool = True
|
||||
can_write: bool = True
|
||||
can_inherit: bool = True
|
||||
allowed_memory_types: list[str] = field(default_factory=lambda: ["all"])
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def allows_read(self) -> bool:
|
||||
"""Check if reading is allowed."""
|
||||
return self.can_read
|
||||
|
||||
def allows_write(self) -> bool:
|
||||
"""Check if writing is allowed."""
|
||||
return self.can_write
|
||||
|
||||
def allows_inherit(self) -> bool:
|
||||
"""Check if inheritance from parent is allowed."""
|
||||
return self.can_inherit
|
||||
|
||||
def allows_memory_type(self, memory_type: str) -> bool:
|
||||
"""Check if a specific memory type is allowed."""
|
||||
return (
|
||||
"all" in self.allowed_memory_types
|
||||
or memory_type in self.allowed_memory_types
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScopeInfo:
|
||||
"""Information about a scope including its hierarchy."""
|
||||
|
||||
context: ScopeContext
|
||||
policy: ScopePolicy
|
||||
parent_info: "ScopeInfo | None" = None
|
||||
child_count: int = 0
|
||||
memory_count: int = 0
|
||||
|
||||
@property
|
||||
def depth(self) -> int:
|
||||
"""Get the depth of this scope in the hierarchy."""
|
||||
count = 0
|
||||
current = self.parent_info
|
||||
while current is not None:
|
||||
count += 1
|
||||
current = current.parent_info
|
||||
return count
|
||||
|
||||
|
||||
class ScopeManager:
|
||||
"""
|
||||
Manages memory scopes and their hierarchies.
|
||||
|
||||
Provides:
|
||||
- Scope creation and validation
|
||||
- Hierarchy management
|
||||
- Access control policy management
|
||||
- Scope inheritance rules
|
||||
"""
|
||||
|
||||
# Order of scope levels from root to leaf
|
||||
SCOPE_ORDER: ClassVar[list[ScopeLevel]] = [
|
||||
ScopeLevel.GLOBAL,
|
||||
ScopeLevel.PROJECT,
|
||||
ScopeLevel.AGENT_TYPE,
|
||||
ScopeLevel.AGENT_INSTANCE,
|
||||
ScopeLevel.SESSION,
|
||||
]
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the scope manager."""
|
||||
# In-memory policy cache (would be backed by database in production)
|
||||
self._policies: dict[str, ScopePolicy] = {}
|
||||
self._default_policies = self._create_default_policies()
|
||||
|
||||
def _create_default_policies(self) -> dict[ScopeLevel, ScopePolicy]:
|
||||
"""Create default policies for each scope level."""
|
||||
return {
|
||||
ScopeLevel.GLOBAL: ScopePolicy(
|
||||
scope_type=ScopeLevel.GLOBAL,
|
||||
scope_id="global",
|
||||
can_read=True,
|
||||
can_write=False, # Global writes require special permission
|
||||
can_inherit=True,
|
||||
),
|
||||
ScopeLevel.PROJECT: ScopePolicy(
|
||||
scope_type=ScopeLevel.PROJECT,
|
||||
scope_id="default",
|
||||
can_read=True,
|
||||
can_write=True,
|
||||
can_inherit=True,
|
||||
),
|
||||
ScopeLevel.AGENT_TYPE: ScopePolicy(
|
||||
scope_type=ScopeLevel.AGENT_TYPE,
|
||||
scope_id="default",
|
||||
can_read=True,
|
||||
can_write=True,
|
||||
can_inherit=True,
|
||||
),
|
||||
ScopeLevel.AGENT_INSTANCE: ScopePolicy(
|
||||
scope_type=ScopeLevel.AGENT_INSTANCE,
|
||||
scope_id="default",
|
||||
can_read=True,
|
||||
can_write=True,
|
||||
can_inherit=True,
|
||||
),
|
||||
ScopeLevel.SESSION: ScopePolicy(
|
||||
scope_type=ScopeLevel.SESSION,
|
||||
scope_id="default",
|
||||
can_read=True,
|
||||
can_write=True,
|
||||
can_inherit=True,
|
||||
allowed_memory_types=["working"], # Sessions only allow working memory
|
||||
),
|
||||
}
|
||||
|
||||
def create_scope(
|
||||
self,
|
||||
scope_type: ScopeLevel,
|
||||
scope_id: str,
|
||||
parent: ScopeContext | None = None,
|
||||
) -> ScopeContext:
|
||||
"""
|
||||
Create a new scope context.
|
||||
|
||||
Args:
|
||||
scope_type: Level of the scope
|
||||
scope_id: Unique identifier within the level
|
||||
parent: Optional parent scope
|
||||
|
||||
Returns:
|
||||
Created scope context
|
||||
|
||||
Raises:
|
||||
ValueError: If scope hierarchy is invalid
|
||||
"""
|
||||
# Validate hierarchy
|
||||
if parent is not None:
|
||||
self._validate_parent_child(parent.scope_type, scope_type)
|
||||
|
||||
# For non-global scopes without parent, auto-create parent chain
|
||||
if parent is None and scope_type != ScopeLevel.GLOBAL:
|
||||
parent = self._create_parent_chain(scope_type, scope_id)
|
||||
|
||||
context = ScopeContext(
|
||||
scope_type=scope_type,
|
||||
scope_id=scope_id,
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
logger.debug(f"Created scope: {scope_type.value}:{scope_id}")
|
||||
return context
|
||||
|
||||
def _validate_parent_child(
|
||||
self,
|
||||
parent_type: ScopeLevel,
|
||||
child_type: ScopeLevel,
|
||||
) -> None:
|
||||
"""Validate that parent-child relationship is valid."""
|
||||
parent_idx = self.SCOPE_ORDER.index(parent_type)
|
||||
child_idx = self.SCOPE_ORDER.index(child_type)
|
||||
|
||||
if child_idx <= parent_idx:
|
||||
raise ValueError(
|
||||
f"Invalid scope hierarchy: {child_type.value} cannot be child of {parent_type.value}"
|
||||
)
|
||||
|
||||
# Allow skipping levels (e.g., PROJECT -> SESSION is valid)
|
||||
# This enables flexible scope structures
|
||||
|
||||
def _create_parent_chain(
|
||||
self,
|
||||
target_type: ScopeLevel,
|
||||
scope_id: str,
|
||||
) -> ScopeContext:
|
||||
"""Create parent scope chain up to target type."""
|
||||
target_idx = self.SCOPE_ORDER.index(target_type)
|
||||
|
||||
# Start from global and build chain
|
||||
current: ScopeContext | None = None
|
||||
|
||||
for i in range(target_idx):
|
||||
level = self.SCOPE_ORDER[i]
|
||||
if level == ScopeLevel.GLOBAL:
|
||||
level_id = "global"
|
||||
else:
|
||||
# Use a default ID for intermediate levels
|
||||
level_id = f"default_{level.value}"
|
||||
|
||||
current = ScopeContext(
|
||||
scope_type=level,
|
||||
scope_id=level_id,
|
||||
parent=current,
|
||||
)
|
||||
|
||||
return current # type: ignore[return-value]
|
||||
|
||||
def create_scope_from_ids(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> ScopeContext:
|
||||
"""
|
||||
Create a scope context from individual IDs.
|
||||
|
||||
Automatically determines the most specific scope level
|
||||
based on provided IDs.
|
||||
|
||||
Args:
|
||||
project_id: Project UUID
|
||||
agent_type_id: Agent type UUID
|
||||
agent_instance_id: Agent instance UUID
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
Scope context for the most specific level
|
||||
"""
|
||||
# Build scope chain from most general to most specific
|
||||
current: ScopeContext = ScopeContext(
|
||||
scope_type=ScopeLevel.GLOBAL,
|
||||
scope_id="global",
|
||||
parent=None,
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
current = ScopeContext(
|
||||
scope_type=ScopeLevel.PROJECT,
|
||||
scope_id=str(project_id),
|
||||
parent=current,
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
current = ScopeContext(
|
||||
scope_type=ScopeLevel.AGENT_TYPE,
|
||||
scope_id=str(agent_type_id),
|
||||
parent=current,
|
||||
)
|
||||
|
||||
if agent_instance_id is not None:
|
||||
current = ScopeContext(
|
||||
scope_type=ScopeLevel.AGENT_INSTANCE,
|
||||
scope_id=str(agent_instance_id),
|
||||
parent=current,
|
||||
)
|
||||
|
||||
if session_id is not None:
|
||||
current = ScopeContext(
|
||||
scope_type=ScopeLevel.SESSION,
|
||||
scope_id=session_id,
|
||||
parent=current,
|
||||
)
|
||||
|
||||
return current
|
||||
|
||||
def get_policy(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
) -> ScopePolicy:
|
||||
"""
|
||||
Get the access policy for a scope.
|
||||
|
||||
Args:
|
||||
scope: Scope to get policy for
|
||||
|
||||
Returns:
|
||||
Policy for the scope
|
||||
"""
|
||||
key = self._scope_key(scope)
|
||||
|
||||
if key in self._policies:
|
||||
return self._policies[key]
|
||||
|
||||
# Return default policy for the scope level
|
||||
return self._default_policies.get(
|
||||
scope.scope_type,
|
||||
ScopePolicy(
|
||||
scope_type=scope.scope_type,
|
||||
scope_id=scope.scope_id,
|
||||
),
|
||||
)
|
||||
|
||||
def set_policy(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
policy: ScopePolicy,
|
||||
) -> None:
|
||||
"""
|
||||
Set the access policy for a scope.
|
||||
|
||||
Args:
|
||||
scope: Scope to set policy for
|
||||
policy: Policy to apply
|
||||
"""
|
||||
key = self._scope_key(scope)
|
||||
self._policies[key] = policy
|
||||
logger.info(f"Set policy for scope {key}")
|
||||
|
||||
def _scope_key(self, scope: ScopeContext) -> str:
|
||||
"""Generate a unique key for a scope."""
|
||||
return f"{scope.scope_type.value}:{scope.scope_id}"
|
||||
|
||||
def get_scope_depth(self, scope_type: ScopeLevel) -> int:
|
||||
"""Get the depth of a scope level in the hierarchy."""
|
||||
return self.SCOPE_ORDER.index(scope_type)
|
||||
|
||||
def get_parent_level(self, scope_type: ScopeLevel) -> ScopeLevel | None:
|
||||
"""Get the parent scope level for a given level."""
|
||||
idx = self.SCOPE_ORDER.index(scope_type)
|
||||
if idx == 0:
|
||||
return None
|
||||
return self.SCOPE_ORDER[idx - 1]
|
||||
|
||||
def get_child_level(self, scope_type: ScopeLevel) -> ScopeLevel | None:
|
||||
"""Get the child scope level for a given level."""
|
||||
idx = self.SCOPE_ORDER.index(scope_type)
|
||||
if idx >= len(self.SCOPE_ORDER) - 1:
|
||||
return None
|
||||
return self.SCOPE_ORDER[idx + 1]
|
||||
|
||||
def is_ancestor(
|
||||
self,
|
||||
potential_ancestor: ScopeContext,
|
||||
descendant: ScopeContext,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if one scope is an ancestor of another.
|
||||
|
||||
Args:
|
||||
potential_ancestor: Scope to check as ancestor
|
||||
descendant: Scope to check as descendant
|
||||
|
||||
Returns:
|
||||
True if ancestor relationship exists
|
||||
"""
|
||||
current = descendant.parent
|
||||
while current is not None:
|
||||
if (
|
||||
current.scope_type == potential_ancestor.scope_type
|
||||
and current.scope_id == potential_ancestor.scope_id
|
||||
):
|
||||
return True
|
||||
current = current.parent
|
||||
return False
|
||||
|
||||
def get_common_ancestor(
|
||||
self,
|
||||
scope_a: ScopeContext,
|
||||
scope_b: ScopeContext,
|
||||
) -> ScopeContext | None:
|
||||
"""
|
||||
Find the nearest common ancestor of two scopes.
|
||||
|
||||
Args:
|
||||
scope_a: First scope
|
||||
scope_b: Second scope
|
||||
|
||||
Returns:
|
||||
Common ancestor or None if none exists
|
||||
"""
|
||||
# Get ancestors of scope_a
|
||||
ancestors_a: set[str] = set()
|
||||
current: ScopeContext | None = scope_a
|
||||
while current is not None:
|
||||
ancestors_a.add(self._scope_key(current))
|
||||
current = current.parent
|
||||
|
||||
# Find first ancestor of scope_b that's in ancestors_a
|
||||
current = scope_b
|
||||
while current is not None:
|
||||
if self._scope_key(current) in ancestors_a:
|
||||
return current
|
||||
current = current.parent
|
||||
|
||||
return None
|
||||
|
||||
def can_access(
|
||||
self,
|
||||
accessor_scope: ScopeContext,
|
||||
target_scope: ScopeContext,
|
||||
operation: str = "read",
|
||||
) -> bool:
|
||||
"""
|
||||
Check if accessor scope can access target scope.
|
||||
|
||||
Access rules:
|
||||
- A scope can always access itself
|
||||
- A scope can access ancestors (if inheritance allowed)
|
||||
- A scope CANNOT access descendants (privacy)
|
||||
- Sibling scopes cannot access each other
|
||||
|
||||
Args:
|
||||
accessor_scope: Scope attempting access
|
||||
target_scope: Scope being accessed
|
||||
operation: Type of operation (read/write)
|
||||
|
||||
Returns:
|
||||
True if access is allowed
|
||||
"""
|
||||
# Same scope - always allowed
|
||||
if (
|
||||
accessor_scope.scope_type == target_scope.scope_type
|
||||
and accessor_scope.scope_id == target_scope.scope_id
|
||||
):
|
||||
policy = self.get_policy(target_scope)
|
||||
if operation == "write":
|
||||
return policy.allows_write()
|
||||
return policy.allows_read()
|
||||
|
||||
# Check if target is ancestor (inheritance)
|
||||
if self.is_ancestor(target_scope, accessor_scope):
|
||||
policy = self.get_policy(target_scope)
|
||||
if not policy.allows_inherit():
|
||||
return False
|
||||
if operation == "write":
|
||||
return policy.allows_write()
|
||||
return policy.allows_read()
|
||||
|
||||
# Check if accessor is ancestor of target (downward access)
|
||||
# This is NOT allowed - parents cannot access children's memories
|
||||
if self.is_ancestor(accessor_scope, target_scope):
|
||||
return False
|
||||
|
||||
# Sibling scopes cannot access each other
|
||||
return False
|
||||
|
||||
|
||||
# Singleton manager instance with thread-safe initialization
|
||||
_manager: ScopeManager | None = None
|
||||
_manager_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_scope_manager() -> ScopeManager:
|
||||
"""Get the singleton scope manager instance (thread-safe)."""
|
||||
global _manager
|
||||
if _manager is None:
|
||||
with _manager_lock:
|
||||
# Double-check locking pattern
|
||||
if _manager is None:
|
||||
_manager = ScopeManager()
|
||||
return _manager
|
||||
|
||||
|
||||
def reset_scope_manager() -> None:
|
||||
"""Reset the scope manager singleton (for testing)."""
|
||||
global _manager
|
||||
with _manager_lock:
|
||||
_manager = None
|
||||
27
backend/app/services/memory/semantic/__init__.py
Normal file
27
backend/app/services/memory/semantic/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# app/services/memory/semantic/__init__.py
|
||||
"""
|
||||
Semantic Memory
|
||||
|
||||
Fact storage with triple format (subject, predicate, object)
|
||||
and semantic search capabilities.
|
||||
"""
|
||||
|
||||
from .extraction import (
|
||||
ExtractedFact,
|
||||
ExtractionContext,
|
||||
FactExtractor,
|
||||
get_fact_extractor,
|
||||
)
|
||||
from .memory import SemanticMemory
|
||||
from .verification import FactConflict, FactVerifier, VerificationResult
|
||||
|
||||
__all__ = [
|
||||
"ExtractedFact",
|
||||
"ExtractionContext",
|
||||
"FactConflict",
|
||||
"FactExtractor",
|
||||
"FactVerifier",
|
||||
"SemanticMemory",
|
||||
"VerificationResult",
|
||||
"get_fact_extractor",
|
||||
]
|
||||
313
backend/app/services/memory/semantic/extraction.py
Normal file
313
backend/app/services/memory/semantic/extraction.py
Normal file
@@ -0,0 +1,313 @@
|
||||
# app/services/memory/semantic/extraction.py
|
||||
"""
|
||||
Fact Extraction from Episodes.
|
||||
|
||||
Provides utilities for extracting semantic facts (subject-predicate-object triples)
|
||||
from episodic memories and other text sources.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from app.services.memory.types import Episode, FactCreate, Outcome
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionContext:
|
||||
"""Context for fact extraction."""
|
||||
|
||||
project_id: Any | None = None
|
||||
source_episode_id: Any | None = None
|
||||
min_confidence: float = 0.5
|
||||
max_facts_per_source: int = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedFact:
|
||||
"""A fact extracted from text before storage."""
|
||||
|
||||
subject: str
|
||||
predicate: str
|
||||
object: str
|
||||
confidence: float
|
||||
source_text: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_fact_create(
|
||||
self,
|
||||
project_id: Any | None = None,
|
||||
source_episode_ids: list[Any] | None = None,
|
||||
) -> FactCreate:
|
||||
"""Convert to FactCreate for storage."""
|
||||
return FactCreate(
|
||||
subject=self.subject,
|
||||
predicate=self.predicate,
|
||||
object=self.object,
|
||||
confidence=self.confidence,
|
||||
project_id=project_id,
|
||||
source_episode_ids=source_episode_ids or [],
|
||||
)
|
||||
|
||||
|
||||
class FactExtractor:
|
||||
"""
|
||||
Extracts facts from episodes and text.
|
||||
|
||||
This is a rule-based extractor. In production, this would be
|
||||
replaced or augmented with LLM-based extraction for better accuracy.
|
||||
"""
|
||||
|
||||
# Common predicates we can detect
|
||||
PREDICATE_PATTERNS: ClassVar[dict[str, str]] = {
|
||||
"uses": r"(?:uses?|using|utilizes?)",
|
||||
"requires": r"(?:requires?|needs?|depends?\s+on)",
|
||||
"is_a": r"(?:is\s+a|is\s+an|are\s+a|are)",
|
||||
"has": r"(?:has|have|contains?)",
|
||||
"part_of": r"(?:part\s+of|belongs?\s+to|member\s+of)",
|
||||
"causes": r"(?:causes?|leads?\s+to|results?\s+in)",
|
||||
"prevents": r"(?:prevents?|avoids?|stops?)",
|
||||
"solves": r"(?:solves?|fixes?|resolves?)",
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize extractor."""
|
||||
self._compiled_patterns = {
|
||||
pred: re.compile(pattern, re.IGNORECASE)
|
||||
for pred, pattern in self.PREDICATE_PATTERNS.items()
|
||||
}
|
||||
|
||||
def extract_from_episode(
|
||||
self,
|
||||
episode: Episode,
|
||||
context: ExtractionContext | None = None,
|
||||
) -> list[ExtractedFact]:
|
||||
"""
|
||||
Extract facts from an episode.
|
||||
|
||||
Args:
|
||||
episode: Episode to extract from
|
||||
context: Optional extraction context
|
||||
|
||||
Returns:
|
||||
List of extracted facts
|
||||
"""
|
||||
ctx = context or ExtractionContext()
|
||||
facts: list[ExtractedFact] = []
|
||||
|
||||
# Extract from task description
|
||||
task_facts = self._extract_from_text(
|
||||
episode.task_description,
|
||||
source_prefix=episode.task_type,
|
||||
)
|
||||
facts.extend(task_facts)
|
||||
|
||||
# Extract from lessons learned
|
||||
for lesson in episode.lessons_learned:
|
||||
lesson_facts = self._extract_from_lesson(lesson, episode)
|
||||
facts.extend(lesson_facts)
|
||||
|
||||
# Extract outcome-based facts
|
||||
outcome_facts = self._extract_outcome_facts(episode)
|
||||
facts.extend(outcome_facts)
|
||||
|
||||
# Limit and filter
|
||||
facts = [f for f in facts if f.confidence >= ctx.min_confidence]
|
||||
facts = facts[: ctx.max_facts_per_source]
|
||||
|
||||
logger.debug(f"Extracted {len(facts)} facts from episode {episode.id}")
|
||||
|
||||
return facts
|
||||
|
||||
def _extract_from_text(
|
||||
self,
|
||||
text: str,
|
||||
source_prefix: str = "",
|
||||
) -> list[ExtractedFact]:
|
||||
"""Extract facts from free-form text using pattern matching."""
|
||||
facts: list[ExtractedFact] = []
|
||||
|
||||
if not text or len(text) < 10:
|
||||
return facts
|
||||
|
||||
# Split into sentences
|
||||
sentences = re.split(r"[.!?]+", text)
|
||||
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if len(sentence) < 10:
|
||||
continue
|
||||
|
||||
# Try to match predicate patterns
|
||||
for predicate, pattern in self._compiled_patterns.items():
|
||||
match = pattern.search(sentence)
|
||||
if match:
|
||||
# Extract subject (text before predicate)
|
||||
subject = sentence[: match.start()].strip()
|
||||
# Extract object (text after predicate)
|
||||
obj = sentence[match.end() :].strip()
|
||||
|
||||
if len(subject) > 2 and len(obj) > 2:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=subject[:200], # Limit length
|
||||
predicate=predicate,
|
||||
object=obj[:500],
|
||||
confidence=0.6, # Medium confidence for pattern matching
|
||||
source_text=sentence,
|
||||
)
|
||||
)
|
||||
break # One fact per sentence
|
||||
|
||||
return facts
|
||||
|
||||
def _extract_from_lesson(
|
||||
self,
|
||||
lesson: str,
|
||||
episode: Episode,
|
||||
) -> list[ExtractedFact]:
|
||||
"""Extract facts from a lesson learned."""
|
||||
facts: list[ExtractedFact] = []
|
||||
|
||||
if not lesson or len(lesson) < 10:
|
||||
return facts
|
||||
|
||||
# Lessons are typically in the form "Always do X" or "Never do Y"
|
||||
# or "When X, do Y"
|
||||
|
||||
# Direct lesson fact
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="lesson_learned",
|
||||
object=lesson,
|
||||
confidence=0.8, # High confidence for explicit lessons
|
||||
source_text=lesson,
|
||||
metadata={"outcome": episode.outcome.value},
|
||||
)
|
||||
)
|
||||
|
||||
# Extract conditional patterns
|
||||
conditional_match = re.match(
|
||||
r"(?:when|if)\s+(.+?),\s*(.+)",
|
||||
lesson,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if conditional_match:
|
||||
condition, action = conditional_match.groups()
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=condition.strip(),
|
||||
predicate="requires_action",
|
||||
object=action.strip(),
|
||||
confidence=0.7,
|
||||
source_text=lesson,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract "always/never" patterns
|
||||
always_match = re.match(
|
||||
r"(?:always)\s+(.+)",
|
||||
lesson,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if always_match:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="best_practice",
|
||||
object=always_match.group(1).strip(),
|
||||
confidence=0.85,
|
||||
source_text=lesson,
|
||||
)
|
||||
)
|
||||
|
||||
never_match = re.match(
|
||||
r"(?:never|avoid)\s+(.+)",
|
||||
lesson,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if never_match:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="anti_pattern",
|
||||
object=never_match.group(1).strip(),
|
||||
confidence=0.85,
|
||||
source_text=lesson,
|
||||
)
|
||||
)
|
||||
|
||||
return facts
|
||||
|
||||
def _extract_outcome_facts(
|
||||
self,
|
||||
episode: Episode,
|
||||
) -> list[ExtractedFact]:
|
||||
"""Extract facts based on episode outcome."""
|
||||
facts: list[ExtractedFact] = []
|
||||
|
||||
# Create fact based on outcome
|
||||
if episode.outcome == Outcome.SUCCESS:
|
||||
if episode.outcome_details:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="successful_approach",
|
||||
object=episode.outcome_details[:500],
|
||||
confidence=0.75,
|
||||
source_text=episode.outcome_details,
|
||||
)
|
||||
)
|
||||
elif episode.outcome == Outcome.FAILURE:
|
||||
if episode.outcome_details:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="known_failure_mode",
|
||||
object=episode.outcome_details[:500],
|
||||
confidence=0.8, # High confidence for failures
|
||||
source_text=episode.outcome_details,
|
||||
)
|
||||
)
|
||||
|
||||
return facts
|
||||
|
||||
def extract_from_text(
|
||||
self,
|
||||
text: str,
|
||||
context: ExtractionContext | None = None,
|
||||
) -> list[ExtractedFact]:
|
||||
"""
|
||||
Extract facts from arbitrary text.
|
||||
|
||||
Args:
|
||||
text: Text to extract from
|
||||
context: Optional extraction context
|
||||
|
||||
Returns:
|
||||
List of extracted facts
|
||||
"""
|
||||
ctx = context or ExtractionContext()
|
||||
|
||||
facts = self._extract_from_text(text)
|
||||
|
||||
# Filter by confidence
|
||||
facts = [f for f in facts if f.confidence >= ctx.min_confidence]
|
||||
|
||||
return facts[: ctx.max_facts_per_source]
|
||||
|
||||
|
||||
# Singleton extractor instance
|
||||
_extractor: FactExtractor | None = None
|
||||
|
||||
|
||||
def get_fact_extractor() -> FactExtractor:
|
||||
"""Get the singleton fact extractor instance."""
|
||||
global _extractor
|
||||
if _extractor is None:
|
||||
_extractor = FactExtractor()
|
||||
return _extractor
|
||||
767
backend/app/services/memory/semantic/memory.py
Normal file
767
backend/app/services/memory/semantic/memory.py
Normal file
@@ -0,0 +1,767 @@
|
||||
# app/services/memory/semantic/memory.py
|
||||
"""
|
||||
Semantic Memory Implementation.
|
||||
|
||||
Provides fact storage and retrieval using subject-predicate-object triples.
|
||||
Supports semantic search, confidence scoring, and fact reinforcement.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, desc, or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.fact import Fact as FactModel
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.types import Episode, Fact, FactCreate, RetrievalResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _escape_like_pattern(pattern: str) -> str:
|
||||
"""
|
||||
Escape SQL LIKE/ILIKE special characters to prevent pattern injection.
|
||||
|
||||
Characters escaped:
|
||||
- % (matches zero or more characters)
|
||||
- _ (matches exactly one character)
|
||||
- \\ (escape character itself)
|
||||
|
||||
Args:
|
||||
pattern: Raw search pattern from user input
|
||||
|
||||
Returns:
|
||||
Escaped pattern safe for use in LIKE/ILIKE queries
|
||||
"""
|
||||
# Escape backslash first, then the wildcards
|
||||
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
|
||||
|
||||
def _model_to_fact(model: FactModel) -> Fact:
|
||||
"""Convert SQLAlchemy model to Fact dataclass."""
|
||||
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
|
||||
# they return actual values. We use type: ignore to handle this mismatch.
|
||||
return Fact(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
subject=model.subject, # type: ignore[arg-type]
|
||||
predicate=model.predicate, # type: ignore[arg-type]
|
||||
object=model.object, # type: ignore[arg-type]
|
||||
confidence=model.confidence, # type: ignore[arg-type]
|
||||
source_episode_ids=model.source_episode_ids or [], # type: ignore[arg-type]
|
||||
first_learned=model.first_learned, # type: ignore[arg-type]
|
||||
last_reinforced=model.last_reinforced, # type: ignore[arg-type]
|
||||
reinforcement_count=model.reinforcement_count, # type: ignore[arg-type]
|
||||
embedding=None, # Don't expose raw embedding
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class SemanticMemory:
|
||||
"""
|
||||
Semantic Memory Service.
|
||||
|
||||
Provides fact storage and retrieval:
|
||||
- Store facts as subject-predicate-object triples
|
||||
- Semantic search over facts
|
||||
- Entity-based retrieval
|
||||
- Confidence scoring and decay
|
||||
- Fact reinforcement on repeated learning
|
||||
- Conflict resolution
|
||||
|
||||
Performance target: <100ms P95 for retrieval
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize semantic memory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic search
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._settings = get_memory_settings()
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> "SemanticMemory":
|
||||
"""
|
||||
Factory method to create SemanticMemory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
|
||||
Returns:
|
||||
Configured SemanticMemory instance
|
||||
"""
|
||||
return cls(session=session, embedding_generator=embedding_generator)
|
||||
|
||||
# =========================================================================
|
||||
# Fact Storage
|
||||
# =========================================================================
|
||||
|
||||
async def store_fact(self, fact: FactCreate) -> Fact:
|
||||
"""
|
||||
Store a new fact or reinforce an existing one.
|
||||
|
||||
If a fact with the same triple (subject, predicate, object) exists
|
||||
in the same scope, it will be reinforced instead of duplicated.
|
||||
|
||||
Args:
|
||||
fact: Fact data to store
|
||||
|
||||
Returns:
|
||||
The created or reinforced fact
|
||||
"""
|
||||
# Check for existing fact with same triple
|
||||
existing = await self._find_existing_fact(
|
||||
project_id=fact.project_id,
|
||||
subject=fact.subject,
|
||||
predicate=fact.predicate,
|
||||
object=fact.object,
|
||||
)
|
||||
|
||||
if existing is not None:
|
||||
# Reinforce existing fact
|
||||
return await self.reinforce_fact(
|
||||
existing.id, # type: ignore[arg-type]
|
||||
source_episode_ids=fact.source_episode_ids,
|
||||
)
|
||||
|
||||
# Create new fact
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Generate embedding if possible
|
||||
embedding = None
|
||||
if self._embedding_generator is not None:
|
||||
embedding_text = self._create_embedding_text(fact)
|
||||
embedding = await self._embedding_generator.generate(embedding_text)
|
||||
|
||||
model = FactModel(
|
||||
project_id=fact.project_id,
|
||||
subject=fact.subject,
|
||||
predicate=fact.predicate,
|
||||
object=fact.object,
|
||||
confidence=fact.confidence,
|
||||
source_episode_ids=fact.source_episode_ids,
|
||||
first_learned=now,
|
||||
last_reinforced=now,
|
||||
reinforcement_count=1,
|
||||
embedding=embedding,
|
||||
)
|
||||
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
|
||||
logger.info(
|
||||
f"Stored new fact: {fact.subject} - {fact.predicate} - {fact.object[:50]}..."
|
||||
)
|
||||
|
||||
return _model_to_fact(model)
|
||||
|
||||
async def _find_existing_fact(
|
||||
self,
|
||||
project_id: UUID | None,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
object: str,
|
||||
) -> FactModel | None:
|
||||
"""Find an existing fact with the same triple in the same scope."""
|
||||
query = select(FactModel).where(
|
||||
and_(
|
||||
FactModel.subject == subject,
|
||||
FactModel.predicate == predicate,
|
||||
FactModel.object == object,
|
||||
)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(FactModel.project_id == project_id)
|
||||
else:
|
||||
query = query.where(FactModel.project_id.is_(None))
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def _create_embedding_text(self, fact: FactCreate) -> str:
|
||||
"""Create text for embedding from fact data."""
|
||||
return f"{fact.subject} {fact.predicate} {fact.object}"
|
||||
|
||||
# =========================================================================
|
||||
# Fact Retrieval
|
||||
# =========================================================================
|
||||
|
||||
async def search_facts(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 10,
|
||||
min_confidence: float | None = None,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Search for facts semantically similar to the query.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
project_id: Optional project to search within
|
||||
limit: Maximum results
|
||||
min_confidence: Optional minimum confidence filter
|
||||
|
||||
Returns:
|
||||
List of matching facts
|
||||
"""
|
||||
result = await self._search_facts_with_metadata(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
min_confidence=min_confidence,
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def _search_facts_with_metadata(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 10,
|
||||
min_confidence: float | None = None,
|
||||
) -> RetrievalResult[Fact]:
|
||||
"""Search facts with full result metadata."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
min_conf = min_confidence or self._settings.semantic_min_confidence
|
||||
|
||||
# Build base query
|
||||
stmt = (
|
||||
select(FactModel)
|
||||
.where(FactModel.confidence >= min_conf)
|
||||
.order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
# Apply project filter
|
||||
if project_id is not None:
|
||||
# Include both project-specific and global facts
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: Implement proper vector similarity search when pgvector is integrated
|
||||
# For now, do text-based search on subject/predicate/object
|
||||
search_terms = query.lower().split()
|
||||
if search_terms:
|
||||
conditions = []
|
||||
for term in search_terms[:5]: # Limit to 5 terms
|
||||
# Escape SQL wildcards to prevent pattern injection
|
||||
escaped_term = _escape_like_pattern(term)
|
||||
term_pattern = f"%{escaped_term}%"
|
||||
conditions.append(
|
||||
or_(
|
||||
FactModel.subject.ilike(term_pattern),
|
||||
FactModel.predicate.ilike(term_pattern),
|
||||
FactModel.object.ilike(term_pattern),
|
||||
)
|
||||
)
|
||||
if conditions:
|
||||
stmt = stmt.where(or_(*conditions))
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_fact(m) for m in models],
|
||||
total_count=len(models),
|
||||
query=query,
|
||||
retrieval_type="semantic",
|
||||
latency_ms=latency_ms,
|
||||
metadata={"min_confidence": min_conf},
|
||||
)
|
||||
|
||||
async def get_by_entity(
|
||||
self,
|
||||
entity: str,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Get facts related to an entity (as subject or object).
|
||||
|
||||
Args:
|
||||
entity: Entity to search for
|
||||
project_id: Optional project to search within
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of facts mentioning the entity
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Escape SQL wildcards to prevent pattern injection
|
||||
escaped_entity = _escape_like_pattern(entity)
|
||||
entity_pattern = f"%{escaped_entity}%"
|
||||
|
||||
stmt = (
|
||||
select(FactModel)
|
||||
.where(
|
||||
or_(
|
||||
FactModel.subject.ilike(entity_pattern),
|
||||
FactModel.object.ilike(entity_pattern),
|
||||
)
|
||||
)
|
||||
.order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
logger.debug(
|
||||
f"get_by_entity({entity}) returned {len(models)} facts in {latency_ms:.1f}ms"
|
||||
)
|
||||
|
||||
return [_model_to_fact(m) for m in models]
|
||||
|
||||
async def get_by_subject(
|
||||
self,
|
||||
subject: str,
|
||||
project_id: UUID | None = None,
|
||||
predicate: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Get facts with a specific subject.
|
||||
|
||||
Args:
|
||||
subject: Subject to search for
|
||||
project_id: Optional project to search within
|
||||
predicate: Optional predicate filter
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of facts with matching subject
|
||||
"""
|
||||
stmt = (
|
||||
select(FactModel)
|
||||
.where(FactModel.subject == subject)
|
||||
.order_by(desc(FactModel.confidence))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if predicate is not None:
|
||||
stmt = stmt.where(FactModel.predicate == predicate)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
return [_model_to_fact(m) for m in models]
|
||||
|
||||
async def get_by_id(self, fact_id: UUID) -> Fact | None:
|
||||
"""Get a fact by ID."""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
return _model_to_fact(model) if model else None
|
||||
|
||||
# =========================================================================
|
||||
# Fact Reinforcement
|
||||
# =========================================================================
|
||||
|
||||
async def reinforce_fact(
|
||||
self,
|
||||
fact_id: UUID,
|
||||
confidence_boost: float = 0.1,
|
||||
source_episode_ids: list[UUID] | None = None,
|
||||
) -> Fact:
|
||||
"""
|
||||
Reinforce a fact, increasing its confidence.
|
||||
|
||||
Args:
|
||||
fact_id: Fact to reinforce
|
||||
confidence_boost: Amount to increase confidence (default 0.1)
|
||||
source_episode_ids: Additional source episodes
|
||||
|
||||
Returns:
|
||||
Updated fact
|
||||
|
||||
Raises:
|
||||
ValueError: If fact not found
|
||||
"""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"Fact not found: {fact_id}")
|
||||
|
||||
# Calculate new confidence (max 1.0)
|
||||
current_confidence: float = model.confidence # type: ignore[assignment]
|
||||
new_confidence = min(1.0, current_confidence + confidence_boost)
|
||||
|
||||
# Merge source episode IDs
|
||||
current_sources: list[UUID] = model.source_episode_ids or [] # type: ignore[assignment]
|
||||
if source_episode_ids:
|
||||
# Add new sources, avoiding duplicates
|
||||
new_sources = list(set(current_sources + source_episode_ids))
|
||||
else:
|
||||
new_sources = current_sources
|
||||
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(FactModel)
|
||||
.where(FactModel.id == fact_id)
|
||||
.values(
|
||||
confidence=new_confidence,
|
||||
source_episode_ids=new_sources,
|
||||
last_reinforced=now,
|
||||
reinforcement_count=FactModel.reinforcement_count + 1,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(FactModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(
|
||||
f"Reinforced fact {fact_id}: confidence {current_confidence:.2f} -> {new_confidence:.2f}"
|
||||
)
|
||||
|
||||
return _model_to_fact(updated_model)
|
||||
|
||||
async def deprecate_fact(
|
||||
self,
|
||||
fact_id: UUID,
|
||||
reason: str,
|
||||
new_confidence: float = 0.0,
|
||||
) -> Fact | None:
|
||||
"""
|
||||
Deprecate a fact by lowering its confidence.
|
||||
|
||||
Args:
|
||||
fact_id: Fact to deprecate
|
||||
reason: Reason for deprecation
|
||||
new_confidence: New confidence level (default 0.0)
|
||||
|
||||
Returns:
|
||||
Updated fact or None if not found
|
||||
"""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(FactModel)
|
||||
.where(FactModel.id == fact_id)
|
||||
.values(
|
||||
confidence=max(0.0, new_confidence),
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(FactModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Deprecated fact {fact_id}: {reason}")
|
||||
|
||||
return _model_to_fact(updated_model) if updated_model else None
|
||||
|
||||
# =========================================================================
|
||||
# Fact Extraction from Episodes
|
||||
# =========================================================================
|
||||
|
||||
async def extract_facts_from_episode(
|
||||
self,
|
||||
episode: Episode,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Extract facts from an episode.
|
||||
|
||||
This is a placeholder for LLM-based fact extraction.
|
||||
In production, this would call an LLM to analyze the episode
|
||||
and extract subject-predicate-object triples.
|
||||
|
||||
Args:
|
||||
episode: Episode to extract facts from
|
||||
|
||||
Returns:
|
||||
List of extracted facts
|
||||
"""
|
||||
# For now, extract basic facts from lessons learned
|
||||
extracted_facts: list[Fact] = []
|
||||
|
||||
for lesson in episode.lessons_learned:
|
||||
if len(lesson) > 10: # Skip very short lessons
|
||||
fact_create = FactCreate(
|
||||
subject=episode.task_type,
|
||||
predicate="lesson_learned",
|
||||
object=lesson,
|
||||
confidence=0.7, # Lessons start with moderate confidence
|
||||
project_id=episode.project_id,
|
||||
source_episode_ids=[episode.id],
|
||||
)
|
||||
fact = await self.store_fact(fact_create)
|
||||
extracted_facts.append(fact)
|
||||
|
||||
logger.debug(
|
||||
f"Extracted {len(extracted_facts)} facts from episode {episode.id}"
|
||||
)
|
||||
|
||||
return extracted_facts
|
||||
|
||||
# =========================================================================
|
||||
# Conflict Resolution
|
||||
# =========================================================================
|
||||
|
||||
async def resolve_conflict(
|
||||
self,
|
||||
fact_ids: list[UUID],
|
||||
keep_fact_id: UUID | None = None,
|
||||
) -> Fact | None:
|
||||
"""
|
||||
Resolve a conflict between multiple facts.
|
||||
|
||||
If keep_fact_id is specified, that fact is kept and others are deprecated.
|
||||
Otherwise, the fact with highest confidence is kept.
|
||||
|
||||
Args:
|
||||
fact_ids: IDs of conflicting facts
|
||||
keep_fact_id: Optional ID of fact to keep
|
||||
|
||||
Returns:
|
||||
The winning fact, or None if no facts found
|
||||
"""
|
||||
if not fact_ids:
|
||||
return None
|
||||
|
||||
# Load all facts
|
||||
query = select(FactModel).where(FactModel.id.in_(fact_ids))
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
if not models:
|
||||
return None
|
||||
|
||||
# Determine winner
|
||||
if keep_fact_id is not None:
|
||||
winner = next((m for m in models if m.id == keep_fact_id), None)
|
||||
if winner is None:
|
||||
# Fallback to highest confidence
|
||||
winner = max(models, key=lambda m: m.confidence)
|
||||
else:
|
||||
# Keep the fact with highest confidence
|
||||
winner = max(models, key=lambda m: m.confidence)
|
||||
|
||||
# Deprecate losers
|
||||
for model in models:
|
||||
if model.id != winner.id:
|
||||
await self.deprecate_fact(
|
||||
model.id, # type: ignore[arg-type]
|
||||
reason=f"Conflict resolution: superseded by {winner.id}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Resolved conflict between {len(fact_ids)} facts, keeping {winner.id}"
|
||||
)
|
||||
|
||||
return _model_to_fact(winner)
|
||||
|
||||
# =========================================================================
|
||||
# Confidence Decay
|
||||
# =========================================================================
|
||||
|
||||
async def apply_confidence_decay(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
decay_factor: float = 0.01,
|
||||
) -> int:
|
||||
"""
|
||||
Apply confidence decay to facts that haven't been reinforced recently.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to apply decay to
|
||||
decay_factor: Decay factor per day (default 0.01)
|
||||
|
||||
Returns:
|
||||
Number of facts affected
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
decay_days = self._settings.semantic_confidence_decay_days
|
||||
min_conf = self._settings.semantic_min_confidence
|
||||
|
||||
# Calculate cutoff date
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = now - timedelta(days=decay_days)
|
||||
|
||||
# Find facts needing decay
|
||||
query = select(FactModel).where(
|
||||
and_(
|
||||
FactModel.last_reinforced < cutoff,
|
||||
FactModel.confidence > min_conf,
|
||||
)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(FactModel.project_id == project_id)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Apply decay
|
||||
updated_count = 0
|
||||
for model in models:
|
||||
# Calculate days since last reinforcement
|
||||
days_since: float = (now - model.last_reinforced).days
|
||||
|
||||
# Calculate decay: exponential decay based on days
|
||||
decay = decay_factor * (days_since - decay_days)
|
||||
new_confidence = max(min_conf, model.confidence - decay)
|
||||
|
||||
if new_confidence != model.confidence:
|
||||
await self._session.execute(
|
||||
update(FactModel)
|
||||
.where(FactModel.id == model.id)
|
||||
.values(confidence=new_confidence, updated_at=now)
|
||||
)
|
||||
updated_count += 1
|
||||
|
||||
await self._session.flush()
|
||||
logger.info(f"Applied confidence decay to {updated_count} facts")
|
||||
|
||||
return updated_count
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
async def get_stats(self, project_id: UUID | None = None) -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about semantic memory.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to get stats for
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics
|
||||
"""
|
||||
# Get all facts for this scope
|
||||
query = select(FactModel)
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
if not models:
|
||||
return {
|
||||
"total_facts": 0,
|
||||
"avg_confidence": 0.0,
|
||||
"avg_reinforcement_count": 0.0,
|
||||
"high_confidence_count": 0,
|
||||
"low_confidence_count": 0,
|
||||
}
|
||||
|
||||
confidences = [m.confidence for m in models]
|
||||
reinforcements = [m.reinforcement_count for m in models]
|
||||
|
||||
return {
|
||||
"total_facts": len(models),
|
||||
"avg_confidence": sum(confidences) / len(confidences),
|
||||
"avg_reinforcement_count": sum(reinforcements) / len(reinforcements),
|
||||
"high_confidence_count": sum(1 for c in confidences if c >= 0.8),
|
||||
"low_confidence_count": sum(1 for c in confidences if c < 0.5),
|
||||
}
|
||||
|
||||
async def count(self, project_id: UUID | None = None) -> int:
|
||||
"""
|
||||
Count facts in scope.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to count for
|
||||
|
||||
Returns:
|
||||
Number of facts
|
||||
"""
|
||||
query = select(FactModel)
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return len(list(result.scalars().all()))
|
||||
|
||||
async def delete(self, fact_id: UUID) -> bool:
|
||||
"""
|
||||
Delete a fact.
|
||||
|
||||
Args:
|
||||
fact_id: Fact to delete
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return False
|
||||
|
||||
await self._session.delete(model)
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Deleted fact {fact_id}")
|
||||
return True
|
||||
363
backend/app/services/memory/semantic/verification.py
Normal file
363
backend/app/services/memory/semantic/verification.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# app/services/memory/semantic/verification.py
|
||||
"""
|
||||
Fact Verification.
|
||||
|
||||
Provides utilities for verifying facts, detecting conflicts,
|
||||
and managing fact consistency.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.fact import Fact as FactModel
|
||||
from app.services.memory.types import Fact
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VerificationResult:
|
||||
"""Result of fact verification."""
|
||||
|
||||
is_valid: bool
|
||||
confidence_adjustment: float = 0.0
|
||||
conflicts: list["FactConflict"] = field(default_factory=list)
|
||||
supporting_facts: list[Fact] = field(default_factory=list)
|
||||
messages: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FactConflict:
|
||||
"""Represents a conflict between two facts."""
|
||||
|
||||
fact_a_id: UUID
|
||||
fact_b_id: UUID
|
||||
conflict_type: str # "contradiction", "superseded", "partial_overlap"
|
||||
description: str
|
||||
suggested_resolution: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"fact_a_id": str(self.fact_a_id),
|
||||
"fact_b_id": str(self.fact_b_id),
|
||||
"conflict_type": self.conflict_type,
|
||||
"description": self.description,
|
||||
"suggested_resolution": self.suggested_resolution,
|
||||
}
|
||||
|
||||
|
||||
class FactVerifier:
|
||||
"""
|
||||
Verifies facts and detects conflicts.
|
||||
|
||||
Provides methods to:
|
||||
- Check if a fact conflicts with existing facts
|
||||
- Find supporting evidence for a fact
|
||||
- Detect contradictions in the fact base
|
||||
"""
|
||||
|
||||
# Predicates that are opposites/contradictions
|
||||
CONTRADICTORY_PREDICATES: ClassVar[set[tuple[str, str]]] = {
|
||||
("uses", "does_not_use"),
|
||||
("requires", "does_not_require"),
|
||||
("is_a", "is_not_a"),
|
||||
("causes", "prevents"),
|
||||
("allows", "prevents"),
|
||||
("supports", "does_not_support"),
|
||||
("best_practice", "anti_pattern"),
|
||||
}
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize verifier with database session."""
|
||||
self._session = session
|
||||
|
||||
async def verify_fact(
|
||||
self,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
obj: str,
|
||||
project_id: UUID | None = None,
|
||||
) -> VerificationResult:
|
||||
"""
|
||||
Verify a fact against existing facts.
|
||||
|
||||
Args:
|
||||
subject: Fact subject
|
||||
predicate: Fact predicate
|
||||
obj: Fact object
|
||||
project_id: Optional project scope
|
||||
|
||||
Returns:
|
||||
VerificationResult with verification details
|
||||
"""
|
||||
result = VerificationResult(is_valid=True)
|
||||
|
||||
# Check for direct contradictions
|
||||
conflicts = await self._find_contradictions(
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
obj=obj,
|
||||
project_id=project_id,
|
||||
)
|
||||
result.conflicts = conflicts
|
||||
|
||||
if conflicts:
|
||||
result.is_valid = False
|
||||
result.messages.append(f"Found {len(conflicts)} conflicting fact(s)")
|
||||
# Reduce confidence based on conflicts
|
||||
result.confidence_adjustment = -0.1 * len(conflicts)
|
||||
|
||||
# Find supporting facts
|
||||
supporting = await self._find_supporting_facts(
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
project_id=project_id,
|
||||
)
|
||||
result.supporting_facts = supporting
|
||||
|
||||
if supporting:
|
||||
result.messages.append(f"Found {len(supporting)} supporting fact(s)")
|
||||
# Boost confidence based on support
|
||||
result.confidence_adjustment += 0.05 * min(len(supporting), 3)
|
||||
|
||||
return result
|
||||
|
||||
async def _find_contradictions(
|
||||
self,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
obj: str,
|
||||
project_id: UUID | None = None,
|
||||
) -> list[FactConflict]:
|
||||
"""Find facts that contradict the given fact."""
|
||||
conflicts: list[FactConflict] = []
|
||||
|
||||
# Find opposite predicates
|
||||
opposite_predicates = self._get_opposite_predicates(predicate)
|
||||
|
||||
if not opposite_predicates:
|
||||
return conflicts
|
||||
|
||||
# Search for contradicting facts
|
||||
query = select(FactModel).where(
|
||||
and_(
|
||||
FactModel.subject == subject,
|
||||
FactModel.predicate.in_(opposite_predicates),
|
||||
)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
for model in models:
|
||||
conflicts.append(
|
||||
FactConflict(
|
||||
fact_a_id=model.id, # type: ignore[arg-type]
|
||||
fact_b_id=UUID(
|
||||
"00000000-0000-0000-0000-000000000000"
|
||||
), # Placeholder for new fact
|
||||
conflict_type="contradiction",
|
||||
description=(
|
||||
f"'{subject} {predicate} {obj}' contradicts "
|
||||
f"'{model.subject} {model.predicate} {model.object}'"
|
||||
),
|
||||
suggested_resolution="Keep fact with higher confidence",
|
||||
)
|
||||
)
|
||||
|
||||
return conflicts
|
||||
|
||||
def _get_opposite_predicates(self, predicate: str) -> list[str]:
|
||||
"""Get predicates that are opposite to the given predicate."""
|
||||
opposites: list[str] = []
|
||||
|
||||
for pair in self.CONTRADICTORY_PREDICATES:
|
||||
if predicate in pair:
|
||||
opposites.extend(p for p in pair if p != predicate)
|
||||
|
||||
return opposites
|
||||
|
||||
async def _find_supporting_facts(
|
||||
self,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
project_id: UUID | None = None,
|
||||
) -> list[Fact]:
|
||||
"""Find facts that support the given fact."""
|
||||
# Find facts with same subject and predicate
|
||||
query = (
|
||||
select(FactModel)
|
||||
.where(
|
||||
and_(
|
||||
FactModel.subject == subject,
|
||||
FactModel.predicate == predicate,
|
||||
FactModel.confidence >= 0.5,
|
||||
)
|
||||
)
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
return [self._model_to_fact(m) for m in models]
|
||||
|
||||
async def find_all_conflicts(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
) -> list[FactConflict]:
|
||||
"""
|
||||
Find all conflicts in the fact base.
|
||||
|
||||
Args:
|
||||
project_id: Optional project scope
|
||||
|
||||
Returns:
|
||||
List of all detected conflicts
|
||||
"""
|
||||
conflicts: list[FactConflict] = []
|
||||
|
||||
# Get all facts
|
||||
query = select(FactModel)
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Check each pair for conflicts
|
||||
for i, fact_a in enumerate(models):
|
||||
for fact_b in models[i + 1 :]:
|
||||
conflict = self._check_pair_conflict(fact_a, fact_b)
|
||||
if conflict:
|
||||
conflicts.append(conflict)
|
||||
|
||||
logger.info(f"Found {len(conflicts)} conflicts in fact base")
|
||||
|
||||
return conflicts
|
||||
|
||||
def _check_pair_conflict(
|
||||
self,
|
||||
fact_a: FactModel,
|
||||
fact_b: FactModel,
|
||||
) -> FactConflict | None:
|
||||
"""Check if two facts conflict."""
|
||||
# Same subject?
|
||||
if fact_a.subject != fact_b.subject:
|
||||
return None
|
||||
|
||||
# Contradictory predicates?
|
||||
opposite = self._get_opposite_predicates(fact_a.predicate) # type: ignore[arg-type]
|
||||
if fact_b.predicate not in opposite:
|
||||
return None
|
||||
|
||||
return FactConflict(
|
||||
fact_a_id=fact_a.id, # type: ignore[arg-type]
|
||||
fact_b_id=fact_b.id, # type: ignore[arg-type]
|
||||
conflict_type="contradiction",
|
||||
description=(
|
||||
f"'{fact_a.subject} {fact_a.predicate} {fact_a.object}' "
|
||||
f"contradicts '{fact_b.subject} {fact_b.predicate} {fact_b.object}'"
|
||||
),
|
||||
suggested_resolution="Deprecate fact with lower confidence",
|
||||
)
|
||||
|
||||
async def get_fact_reliability_score(
|
||||
self,
|
||||
fact_id: UUID,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate a reliability score for a fact.
|
||||
|
||||
Based on:
|
||||
- Confidence score
|
||||
- Number of reinforcements
|
||||
- Number of supporting facts
|
||||
- Absence of conflicts
|
||||
|
||||
Args:
|
||||
fact_id: Fact to score
|
||||
|
||||
Returns:
|
||||
Reliability score (0.0 to 1.0)
|
||||
"""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return 0.0
|
||||
|
||||
# Base score from confidence - explicitly typed to avoid Column type issues
|
||||
score: float = float(model.confidence)
|
||||
|
||||
# Boost for reinforcements (diminishing returns)
|
||||
reinforcement_boost = min(0.2, float(model.reinforcement_count) * 0.02)
|
||||
score += reinforcement_boost
|
||||
|
||||
# Find supporting facts
|
||||
supporting = await self._find_supporting_facts(
|
||||
subject=model.subject, # type: ignore[arg-type]
|
||||
predicate=model.predicate, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
)
|
||||
support_boost = min(0.1, len(supporting) * 0.02)
|
||||
score += support_boost
|
||||
|
||||
# Check for conflicts
|
||||
conflicts = await self._find_contradictions(
|
||||
subject=model.subject, # type: ignore[arg-type]
|
||||
predicate=model.predicate, # type: ignore[arg-type]
|
||||
obj=model.object, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
)
|
||||
conflict_penalty = min(0.3, len(conflicts) * 0.1)
|
||||
score -= conflict_penalty
|
||||
|
||||
# Clamp to valid range
|
||||
return max(0.0, min(1.0, score))
|
||||
|
||||
def _model_to_fact(self, model: FactModel) -> Fact:
|
||||
"""Convert SQLAlchemy model to Fact dataclass."""
|
||||
return Fact(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
subject=model.subject, # type: ignore[arg-type]
|
||||
predicate=model.predicate, # type: ignore[arg-type]
|
||||
object=model.object, # type: ignore[arg-type]
|
||||
confidence=model.confidence, # type: ignore[arg-type]
|
||||
source_episode_ids=model.source_episode_ids or [], # type: ignore[arg-type]
|
||||
first_learned=model.first_learned, # type: ignore[arg-type]
|
||||
last_reinforced=model.last_reinforced, # type: ignore[arg-type]
|
||||
reinforcement_count=model.reinforcement_count, # type: ignore[arg-type]
|
||||
embedding=None,
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
328
backend/app/services/memory/types.py
Normal file
328
backend/app/services/memory/types.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
Memory System Types
|
||||
|
||||
Core type definitions and interfaces for the Agent Memory System.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class MemoryType(str, Enum):
|
||||
"""Types of memory in the agent memory system."""
|
||||
|
||||
WORKING = "working"
|
||||
EPISODIC = "episodic"
|
||||
SEMANTIC = "semantic"
|
||||
PROCEDURAL = "procedural"
|
||||
|
||||
|
||||
class ScopeLevel(str, Enum):
|
||||
"""Hierarchical scoping levels for memory."""
|
||||
|
||||
GLOBAL = "global"
|
||||
PROJECT = "project"
|
||||
AGENT_TYPE = "agent_type"
|
||||
AGENT_INSTANCE = "agent_instance"
|
||||
SESSION = "session"
|
||||
|
||||
|
||||
class Outcome(str, Enum):
|
||||
"""Outcome of a task or episode."""
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
PARTIAL = "partial"
|
||||
ABANDONED = "abandoned"
|
||||
|
||||
|
||||
class ConsolidationStatus(str, Enum):
|
||||
"""Status of a memory consolidation job."""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class ConsolidationType(str, Enum):
|
||||
"""Types of memory consolidation."""
|
||||
|
||||
WORKING_TO_EPISODIC = "working_to_episodic"
|
||||
EPISODIC_TO_SEMANTIC = "episodic_to_semantic"
|
||||
EPISODIC_TO_PROCEDURAL = "episodic_to_procedural"
|
||||
PRUNING = "pruning"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScopeContext:
|
||||
"""Represents a memory scope with its hierarchy."""
|
||||
|
||||
scope_type: ScopeLevel
|
||||
scope_id: str
|
||||
parent: "ScopeContext | None" = None
|
||||
|
||||
def get_hierarchy(self) -> list["ScopeContext"]:
|
||||
"""Get the full scope hierarchy from root to this scope."""
|
||||
hierarchy: list[ScopeContext] = []
|
||||
current: ScopeContext | None = self
|
||||
while current is not None:
|
||||
hierarchy.insert(0, current)
|
||||
current = current.parent
|
||||
return hierarchy
|
||||
|
||||
def to_key_prefix(self) -> str:
|
||||
"""Convert scope to a key prefix for storage."""
|
||||
return f"{self.scope_type.value}:{self.scope_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryItem:
|
||||
"""Base class for all memory items."""
|
||||
|
||||
id: UUID
|
||||
memory_type: MemoryType
|
||||
scope_type: ScopeLevel
|
||||
scope_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def get_age_seconds(self) -> float:
|
||||
"""Get the age of this memory item in seconds."""
|
||||
return (_utcnow() - self.created_at).total_seconds()
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkingMemoryItem:
|
||||
"""A key-value item in working memory."""
|
||||
|
||||
id: UUID
|
||||
scope_type: ScopeLevel
|
||||
scope_id: str
|
||||
key: str
|
||||
value: Any
|
||||
expires_at: datetime | None = None
|
||||
created_at: datetime = field(default_factory=_utcnow)
|
||||
updated_at: datetime = field(default_factory=_utcnow)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this item has expired."""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return _utcnow() > self.expires_at
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskState:
|
||||
"""Current state of a task in working memory."""
|
||||
|
||||
task_id: str
|
||||
task_type: str
|
||||
description: str
|
||||
status: str = "in_progress"
|
||||
current_step: int = 0
|
||||
total_steps: int = 0
|
||||
progress_percent: float = 0.0
|
||||
context: dict[str, Any] = field(default_factory=dict)
|
||||
started_at: datetime = field(default_factory=_utcnow)
|
||||
updated_at: datetime = field(default_factory=_utcnow)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Episode:
|
||||
"""An episodic memory - a recorded experience."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
agent_instance_id: UUID | None
|
||||
agent_type_id: UUID | None
|
||||
session_id: str
|
||||
task_type: str
|
||||
task_description: str
|
||||
actions: list[dict[str, Any]]
|
||||
context_summary: str
|
||||
outcome: Outcome
|
||||
outcome_details: str
|
||||
duration_seconds: float
|
||||
tokens_used: int
|
||||
lessons_learned: list[str]
|
||||
importance_score: float
|
||||
embedding: list[float] | None
|
||||
occurred_at: datetime
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeCreate:
|
||||
"""Data required to create a new episode."""
|
||||
|
||||
project_id: UUID
|
||||
session_id: str
|
||||
task_type: str
|
||||
task_description: str
|
||||
actions: list[dict[str, Any]]
|
||||
context_summary: str
|
||||
outcome: Outcome
|
||||
outcome_details: str
|
||||
duration_seconds: float
|
||||
tokens_used: int
|
||||
lessons_learned: list[str] = field(default_factory=list)
|
||||
importance_score: float = 0.5
|
||||
agent_instance_id: UUID | None = None
|
||||
agent_type_id: UUID | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fact:
|
||||
"""A semantic memory fact - a piece of knowledge."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID | None # None for global facts
|
||||
subject: str
|
||||
predicate: str
|
||||
object: str
|
||||
confidence: float
|
||||
source_episode_ids: list[UUID]
|
||||
first_learned: datetime
|
||||
last_reinforced: datetime
|
||||
reinforcement_count: int
|
||||
embedding: list[float] | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class FactCreate:
|
||||
"""Data required to create a new fact."""
|
||||
|
||||
subject: str
|
||||
predicate: str
|
||||
object: str
|
||||
confidence: float = 0.8
|
||||
project_id: UUID | None = None
|
||||
source_episode_ids: list[UUID] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Procedure:
|
||||
"""A procedural memory - a learned skill or procedure."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID | None
|
||||
agent_type_id: UUID | None
|
||||
name: str
|
||||
trigger_pattern: str
|
||||
steps: list[dict[str, Any]]
|
||||
success_count: int
|
||||
failure_count: int
|
||||
last_used: datetime | None
|
||||
embedding: list[float] | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""Calculate the success rate of this procedure."""
|
||||
total = self.success_count + self.failure_count
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return self.success_count / total
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcedureCreate:
|
||||
"""Data required to create a new procedure."""
|
||||
|
||||
name: str
|
||||
trigger_pattern: str
|
||||
steps: list[dict[str, Any]]
|
||||
project_id: UUID | None = None
|
||||
agent_type_id: UUID | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Step:
|
||||
"""A single step in a procedure."""
|
||||
|
||||
order: int
|
||||
action: str
|
||||
parameters: dict[str, Any] = field(default_factory=dict)
|
||||
expected_outcome: str = ""
|
||||
fallback_action: str | None = None
|
||||
|
||||
|
||||
class MemoryStore[T: MemoryItem](ABC):
|
||||
"""Abstract base class for memory storage backends."""
|
||||
|
||||
@abstractmethod
|
||||
async def store(self, item: T) -> T:
|
||||
"""Store a memory item."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, item_id: UUID) -> T | None:
|
||||
"""Get a memory item by ID."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, item_id: UUID) -> bool:
|
||||
"""Delete a memory item."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def list(
|
||||
self,
|
||||
scope_type: ScopeLevel | None = None,
|
||||
scope_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[T]:
|
||||
"""List memory items with optional scope filtering."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def count(
|
||||
self,
|
||||
scope_type: ScopeLevel | None = None,
|
||||
scope_id: str | None = None,
|
||||
) -> int:
|
||||
"""Count memory items with optional scope filtering."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalResult[T]:
|
||||
"""Result of a memory retrieval operation."""
|
||||
|
||||
items: list[T]
|
||||
total_count: int
|
||||
query: str
|
||||
retrieval_type: str
|
||||
latency_ms: float
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryStats:
|
||||
"""Statistics about memory usage."""
|
||||
|
||||
memory_type: MemoryType
|
||||
scope_type: ScopeLevel | None
|
||||
scope_id: str | None
|
||||
item_count: int
|
||||
total_size_bytes: int
|
||||
oldest_item_age_seconds: float
|
||||
newest_item_age_seconds: float
|
||||
avg_item_size_bytes: float
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
16
backend/app/services/memory/working/__init__.py
Normal file
16
backend/app/services/memory/working/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# app/services/memory/working/__init__.py
|
||||
"""
|
||||
Working Memory Implementation.
|
||||
|
||||
Provides short-term memory storage with Redis primary and in-memory fallback.
|
||||
"""
|
||||
|
||||
from .memory import WorkingMemory
|
||||
from .storage import InMemoryStorage, RedisStorage, WorkingMemoryStorage
|
||||
|
||||
__all__ = [
|
||||
"InMemoryStorage",
|
||||
"RedisStorage",
|
||||
"WorkingMemory",
|
||||
"WorkingMemoryStorage",
|
||||
]
|
||||
544
backend/app/services/memory/working/memory.py
Normal file
544
backend/app/services/memory/working/memory.py
Normal file
@@ -0,0 +1,544 @@
|
||||
# app/services/memory/working/memory.py
|
||||
"""
|
||||
Working Memory Implementation.
|
||||
|
||||
Provides session-scoped ephemeral memory with:
|
||||
- Key-value storage with TTL
|
||||
- Task state tracking
|
||||
- Scratchpad for reasoning steps
|
||||
- Checkpoint/snapshot support
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.exceptions import (
|
||||
MemoryConnectionError,
|
||||
MemoryNotFoundError,
|
||||
)
|
||||
from app.services.memory.types import ScopeContext, ScopeLevel, TaskState
|
||||
|
||||
from .storage import InMemoryStorage, RedisStorage, WorkingMemoryStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Reserved key prefixes for internal use
|
||||
_TASK_STATE_KEY = "_task_state"
|
||||
_SCRATCHPAD_KEY = "_scratchpad"
|
||||
_CHECKPOINT_PREFIX = "_checkpoint:"
|
||||
_METADATA_KEY = "_metadata"
|
||||
|
||||
|
||||
class WorkingMemory:
|
||||
"""
|
||||
Session-scoped working memory.
|
||||
|
||||
Provides ephemeral storage for agent's current task context:
|
||||
- Variables and intermediate data
|
||||
- Task state (current step, status, progress)
|
||||
- Scratchpad for reasoning steps
|
||||
- Checkpoints for recovery
|
||||
|
||||
Uses Redis as primary storage with in-memory fallback.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scope: ScopeContext,
|
||||
storage: WorkingMemoryStorage,
|
||||
default_ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize working memory for a scope.
|
||||
|
||||
Args:
|
||||
scope: The scope context (session, agent instance, etc.)
|
||||
storage: Storage backend (use create() factory for auto-configuration)
|
||||
default_ttl_seconds: Default TTL for keys (None = no expiration)
|
||||
"""
|
||||
self._scope = scope
|
||||
self._storage: WorkingMemoryStorage = storage
|
||||
self._default_ttl = default_ttl_seconds
|
||||
self._using_fallback = False
|
||||
self._initialized = False
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
scope: ScopeContext,
|
||||
default_ttl_seconds: int | None = None,
|
||||
) -> "WorkingMemory":
|
||||
"""
|
||||
Factory method to create WorkingMemory with auto-configured storage.
|
||||
|
||||
Attempts Redis first, falls back to in-memory if unavailable.
|
||||
"""
|
||||
settings = get_memory_settings()
|
||||
key_prefix = f"wm:{scope.to_key_prefix()}:"
|
||||
storage: WorkingMemoryStorage
|
||||
|
||||
# Try Redis first
|
||||
if settings.working_memory_backend == "redis":
|
||||
redis_storage = RedisStorage(key_prefix=key_prefix)
|
||||
try:
|
||||
if await redis_storage.is_healthy():
|
||||
logger.debug(f"Using Redis storage for scope {scope.scope_id}")
|
||||
instance = cls(
|
||||
scope=scope,
|
||||
storage=redis_storage,
|
||||
default_ttl_seconds=default_ttl_seconds
|
||||
or settings.working_memory_default_ttl_seconds,
|
||||
)
|
||||
await instance._initialize()
|
||||
return instance
|
||||
except MemoryConnectionError:
|
||||
logger.warning("Redis unavailable, falling back to in-memory storage")
|
||||
await redis_storage.close()
|
||||
|
||||
# Fall back to in-memory
|
||||
storage = InMemoryStorage(
|
||||
max_keys=settings.working_memory_max_items_per_session
|
||||
)
|
||||
instance = cls(
|
||||
scope=scope,
|
||||
storage=storage,
|
||||
default_ttl_seconds=default_ttl_seconds
|
||||
or settings.working_memory_default_ttl_seconds,
|
||||
)
|
||||
instance._using_fallback = True
|
||||
await instance._initialize()
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
async def for_session(
|
||||
cls,
|
||||
session_id: str,
|
||||
project_id: str | None = None,
|
||||
agent_instance_id: str | None = None,
|
||||
) -> "WorkingMemory":
|
||||
"""
|
||||
Convenience factory for session-scoped working memory.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier
|
||||
project_id: Optional project context
|
||||
agent_instance_id: Optional agent instance context
|
||||
"""
|
||||
# Build scope hierarchy
|
||||
parent = None
|
||||
if project_id:
|
||||
parent = ScopeContext(
|
||||
scope_type=ScopeLevel.PROJECT,
|
||||
scope_id=project_id,
|
||||
)
|
||||
if agent_instance_id:
|
||||
parent = ScopeContext(
|
||||
scope_type=ScopeLevel.AGENT_INSTANCE,
|
||||
scope_id=agent_instance_id,
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
scope = ScopeContext(
|
||||
scope_type=ScopeLevel.SESSION,
|
||||
scope_id=session_id,
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
return await cls.create(scope=scope)
|
||||
|
||||
async def _initialize(self) -> None:
|
||||
"""Initialize working memory metadata."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"scope_type": self._scope.scope_type.value,
|
||||
"scope_id": self._scope.scope_id,
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"using_fallback": self._using_fallback,
|
||||
}
|
||||
await self._storage.set(_METADATA_KEY, metadata)
|
||||
self._initialized = True
|
||||
|
||||
@property
|
||||
def scope(self) -> ScopeContext:
|
||||
"""Get the scope context."""
|
||||
return self._scope
|
||||
|
||||
@property
|
||||
def is_using_fallback(self) -> bool:
|
||||
"""Check if using fallback in-memory storage."""
|
||||
return self._using_fallback
|
||||
|
||||
# =========================================================================
|
||||
# Basic Key-Value Operations
|
||||
# =========================================================================
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Store a value.
|
||||
|
||||
Args:
|
||||
key: The key to store under
|
||||
value: The value to store (must be JSON-serializable)
|
||||
ttl_seconds: Optional TTL (uses default if not specified)
|
||||
"""
|
||||
if key.startswith("_"):
|
||||
raise ValueError("Keys starting with '_' are reserved for internal use")
|
||||
|
||||
ttl = ttl_seconds if ttl_seconds is not None else self._default_ttl
|
||||
await self._storage.set(key, value, ttl)
|
||||
|
||||
async def get(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get a value.
|
||||
|
||||
Args:
|
||||
key: The key to retrieve
|
||||
default: Default value if key not found
|
||||
|
||||
Returns:
|
||||
The stored value or default
|
||||
"""
|
||||
result = await self._storage.get(key)
|
||||
return result if result is not None else default
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a key.
|
||||
|
||||
Args:
|
||||
key: The key to delete
|
||||
|
||||
Returns:
|
||||
True if the key existed
|
||||
"""
|
||||
if key.startswith("_"):
|
||||
raise ValueError("Cannot delete internal keys directly")
|
||||
return await self._storage.delete(key)
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if a key exists.
|
||||
|
||||
Args:
|
||||
key: The key to check
|
||||
|
||||
Returns:
|
||||
True if the key exists
|
||||
"""
|
||||
return await self._storage.exists(key)
|
||||
|
||||
async def list_keys(self, pattern: str = "*") -> list[str]:
|
||||
"""
|
||||
List keys matching a pattern.
|
||||
|
||||
Args:
|
||||
pattern: Glob-style pattern (default "*" for all)
|
||||
|
||||
Returns:
|
||||
List of matching keys (excludes internal keys)
|
||||
"""
|
||||
all_keys = await self._storage.list_keys(pattern)
|
||||
return [k for k in all_keys if not k.startswith("_")]
|
||||
|
||||
async def get_all(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get all user key-value pairs.
|
||||
|
||||
Returns:
|
||||
Dictionary of all key-value pairs (excludes internal keys)
|
||||
"""
|
||||
all_data = await self._storage.get_all()
|
||||
return {k: v for k, v in all_data.items() if not k.startswith("_")}
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""
|
||||
Clear all user keys (preserves internal state).
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
"""
|
||||
# Save internal state
|
||||
task_state = await self._storage.get(_TASK_STATE_KEY)
|
||||
scratchpad = await self._storage.get(_SCRATCHPAD_KEY)
|
||||
metadata = await self._storage.get(_METADATA_KEY)
|
||||
|
||||
count = await self._storage.clear()
|
||||
|
||||
# Restore internal state
|
||||
if metadata is not None:
|
||||
await self._storage.set(_METADATA_KEY, metadata)
|
||||
if task_state is not None:
|
||||
await self._storage.set(_TASK_STATE_KEY, task_state)
|
||||
if scratchpad is not None:
|
||||
await self._storage.set(_SCRATCHPAD_KEY, scratchpad)
|
||||
|
||||
# Adjust count for preserved keys
|
||||
preserved = sum(1 for x in [task_state, scratchpad, metadata] if x is not None)
|
||||
return max(0, count - preserved)
|
||||
|
||||
# =========================================================================
|
||||
# Task State Operations
|
||||
# =========================================================================
|
||||
|
||||
async def set_task_state(self, state: TaskState) -> None:
|
||||
"""
|
||||
Set the current task state.
|
||||
|
||||
Args:
|
||||
state: The task state to store
|
||||
"""
|
||||
state.updated_at = datetime.now(UTC)
|
||||
await self._storage.set(_TASK_STATE_KEY, asdict(state))
|
||||
|
||||
async def get_task_state(self) -> TaskState | None:
|
||||
"""
|
||||
Get the current task state.
|
||||
|
||||
Returns:
|
||||
The current TaskState or None if not set
|
||||
"""
|
||||
data = await self._storage.get(_TASK_STATE_KEY)
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
# Convert datetime strings back to datetime objects
|
||||
if isinstance(data.get("started_at"), str):
|
||||
data["started_at"] = datetime.fromisoformat(data["started_at"])
|
||||
if isinstance(data.get("updated_at"), str):
|
||||
data["updated_at"] = datetime.fromisoformat(data["updated_at"])
|
||||
|
||||
return TaskState(**data)
|
||||
|
||||
async def update_task_progress(
|
||||
self,
|
||||
current_step: int | None = None,
|
||||
progress_percent: float | None = None,
|
||||
status: str | None = None,
|
||||
) -> TaskState | None:
|
||||
"""
|
||||
Update task progress fields.
|
||||
|
||||
Args:
|
||||
current_step: New current step number
|
||||
progress_percent: New progress percentage (0.0 to 100.0)
|
||||
status: New status string
|
||||
|
||||
Returns:
|
||||
Updated TaskState or None if no task state exists
|
||||
"""
|
||||
state = await self.get_task_state()
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
if current_step is not None:
|
||||
state.current_step = current_step
|
||||
if progress_percent is not None:
|
||||
state.progress_percent = min(100.0, max(0.0, progress_percent))
|
||||
if status is not None:
|
||||
state.status = status
|
||||
|
||||
await self.set_task_state(state)
|
||||
return state
|
||||
|
||||
# =========================================================================
|
||||
# Scratchpad Operations
|
||||
# =========================================================================
|
||||
|
||||
async def append_scratchpad(self, content: str) -> None:
|
||||
"""
|
||||
Append content to the scratchpad.
|
||||
|
||||
Args:
|
||||
content: Text to append
|
||||
"""
|
||||
settings = get_memory_settings()
|
||||
entries = await self._storage.get(_SCRATCHPAD_KEY) or []
|
||||
|
||||
# Check capacity
|
||||
if len(entries) >= settings.working_memory_max_items_per_session:
|
||||
# Remove oldest entries
|
||||
entries = entries[-(settings.working_memory_max_items_per_session - 1) :]
|
||||
|
||||
entry = {
|
||||
"content": content,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
entries.append(entry)
|
||||
await self._storage.set(_SCRATCHPAD_KEY, entries)
|
||||
|
||||
async def get_scratchpad(self) -> list[str]:
|
||||
"""
|
||||
Get all scratchpad entries.
|
||||
|
||||
Returns:
|
||||
List of scratchpad content strings (ordered by time)
|
||||
"""
|
||||
entries = await self._storage.get(_SCRATCHPAD_KEY) or []
|
||||
return [e["content"] for e in entries]
|
||||
|
||||
async def get_scratchpad_with_timestamps(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get all scratchpad entries with timestamps.
|
||||
|
||||
Returns:
|
||||
List of dicts with 'content' and 'timestamp' keys
|
||||
"""
|
||||
return await self._storage.get(_SCRATCHPAD_KEY) or []
|
||||
|
||||
async def clear_scratchpad(self) -> int:
|
||||
"""
|
||||
Clear the scratchpad.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
entries = await self._storage.get(_SCRATCHPAD_KEY) or []
|
||||
count = len(entries)
|
||||
await self._storage.set(_SCRATCHPAD_KEY, [])
|
||||
return count
|
||||
|
||||
# =========================================================================
|
||||
# Checkpoint Operations
|
||||
# =========================================================================
|
||||
|
||||
async def create_checkpoint(self, description: str = "") -> str:
|
||||
"""
|
||||
Create a checkpoint of current state.
|
||||
|
||||
Args:
|
||||
description: Optional description of the checkpoint
|
||||
|
||||
Returns:
|
||||
Checkpoint ID for later restoration
|
||||
"""
|
||||
# Use full UUID to avoid collision risk (8 chars has ~50k collision at birthday paradox)
|
||||
checkpoint_id = str(uuid.uuid4())
|
||||
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
|
||||
# Capture all current state
|
||||
all_data = await self._storage.get_all()
|
||||
|
||||
checkpoint = {
|
||||
"id": checkpoint_id,
|
||||
"description": description,
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"data": all_data,
|
||||
}
|
||||
|
||||
await self._storage.set(checkpoint_key, checkpoint)
|
||||
logger.debug(f"Created checkpoint {checkpoint_id}")
|
||||
return checkpoint_id
|
||||
|
||||
async def restore_checkpoint(self, checkpoint_id: str) -> None:
|
||||
"""
|
||||
Restore state from a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint to restore
|
||||
|
||||
Raises:
|
||||
MemoryNotFoundError: If checkpoint not found
|
||||
"""
|
||||
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
checkpoint = await self._storage.get(checkpoint_key)
|
||||
|
||||
if checkpoint is None:
|
||||
raise MemoryNotFoundError(f"Checkpoint {checkpoint_id} not found")
|
||||
|
||||
# Clear current state
|
||||
await self._storage.clear()
|
||||
|
||||
# Restore all data from checkpoint
|
||||
for key, value in checkpoint["data"].items():
|
||||
await self._storage.set(key, value)
|
||||
|
||||
# Keep the checkpoint itself
|
||||
await self._storage.set(checkpoint_key, checkpoint)
|
||||
|
||||
logger.debug(f"Restored checkpoint {checkpoint_id}")
|
||||
|
||||
async def list_checkpoints(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
List all available checkpoints.
|
||||
|
||||
Returns:
|
||||
List of checkpoint metadata (id, description, created_at)
|
||||
"""
|
||||
checkpoint_keys = await self._storage.list_keys(f"{_CHECKPOINT_PREFIX}*")
|
||||
checkpoints = []
|
||||
|
||||
for key in checkpoint_keys:
|
||||
data = await self._storage.get(key)
|
||||
if data:
|
||||
checkpoints.append(
|
||||
{
|
||||
"id": data["id"],
|
||||
"description": data["description"],
|
||||
"created_at": data["created_at"],
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by creation time
|
||||
checkpoints.sort(key=lambda x: x["created_at"])
|
||||
return checkpoints
|
||||
|
||||
async def delete_checkpoint(self, checkpoint_id: str) -> bool:
|
||||
"""
|
||||
Delete a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: ID of the checkpoint to delete
|
||||
|
||||
Returns:
|
||||
True if checkpoint existed
|
||||
"""
|
||||
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
return await self._storage.delete(checkpoint_key)
|
||||
|
||||
# =========================================================================
|
||||
# Health and Lifecycle
|
||||
# =========================================================================
|
||||
|
||||
async def is_healthy(self) -> bool:
|
||||
"""Check if the working memory storage is healthy."""
|
||||
return await self._storage.is_healthy()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the working memory storage."""
|
||||
if self._storage:
|
||||
await self._storage.close()
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get working memory statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with stats about current state
|
||||
"""
|
||||
all_keys = await self._storage.list_keys("*")
|
||||
user_keys = [k for k in all_keys if not k.startswith("_")]
|
||||
checkpoint_keys = [k for k in all_keys if k.startswith(_CHECKPOINT_PREFIX)]
|
||||
scratchpad = await self._storage.get(_SCRATCHPAD_KEY) or []
|
||||
|
||||
return {
|
||||
"scope_type": self._scope.scope_type.value,
|
||||
"scope_id": self._scope.scope_id,
|
||||
"using_fallback": self._using_fallback,
|
||||
"total_keys": len(all_keys),
|
||||
"user_keys": len(user_keys),
|
||||
"checkpoint_count": len(checkpoint_keys),
|
||||
"scratchpad_entries": len(scratchpad),
|
||||
"has_task_state": await self._storage.exists(_TASK_STATE_KEY),
|
||||
}
|
||||
406
backend/app/services/memory/working/storage.py
Normal file
406
backend/app/services/memory/working/storage.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# app/services/memory/working/storage.py
|
||||
"""
|
||||
Working Memory Storage Backends.
|
||||
|
||||
Provides abstract storage interface and implementations:
|
||||
- RedisStorage: Primary storage using Redis with connection pooling
|
||||
- InMemoryStorage: Fallback storage when Redis is unavailable
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import fnmatch
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.exceptions import (
|
||||
MemoryConnectionError,
|
||||
MemoryStorageError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkingMemoryStorage(ABC):
|
||||
"""Abstract base class for working memory storage backends."""
|
||||
|
||||
@abstractmethod
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""Store a value with optional TTL."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> Any | None:
|
||||
"""Get a value by key, returns None if not found or expired."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete a key, returns True if existed."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""Check if a key exists and is not expired."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def list_keys(self, pattern: str = "*") -> list[str]:
|
||||
"""List all keys matching a pattern."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_all(self) -> dict[str, Any]:
|
||||
"""Get all key-value pairs."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> int:
|
||||
"""Clear all keys, returns count of deleted keys."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def is_healthy(self) -> bool:
|
||||
"""Check if the storage backend is healthy."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Close the storage connection."""
|
||||
...
|
||||
|
||||
|
||||
class InMemoryStorage(WorkingMemoryStorage):
|
||||
"""
|
||||
In-memory storage backend for working memory.
|
||||
|
||||
Used as fallback when Redis is unavailable. Data is not persisted
|
||||
across restarts and is not shared between processes.
|
||||
"""
|
||||
|
||||
def __init__(self, max_keys: int = 10000) -> None:
|
||||
"""Initialize in-memory storage."""
|
||||
self._data: dict[str, Any] = {}
|
||||
self._expirations: dict[str, datetime] = {}
|
||||
self._max_keys = max_keys
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def _is_expired(self, key: str) -> bool:
|
||||
"""Check if a key has expired."""
|
||||
if key not in self._expirations:
|
||||
return False
|
||||
return datetime.now(UTC) > self._expirations[key]
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove all expired keys."""
|
||||
now = datetime.now(UTC)
|
||||
expired_keys = [
|
||||
key for key, exp_time in self._expirations.items() if now > exp_time
|
||||
]
|
||||
for key in expired_keys:
|
||||
self._data.pop(key, None)
|
||||
self._expirations.pop(key, None)
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""Store a value with optional TTL."""
|
||||
async with self._lock:
|
||||
# Cleanup expired keys periodically
|
||||
if len(self._data) % 100 == 0:
|
||||
self._cleanup_expired()
|
||||
|
||||
# Check capacity
|
||||
if key not in self._data and len(self._data) >= self._max_keys:
|
||||
# Evict expired keys first
|
||||
self._cleanup_expired()
|
||||
if len(self._data) >= self._max_keys:
|
||||
raise MemoryStorageError(
|
||||
f"Working memory capacity exceeded: {self._max_keys} keys"
|
||||
)
|
||||
|
||||
self._data[key] = value
|
||||
if ttl_seconds is not None:
|
||||
self._expirations[key] = datetime.now(UTC) + timedelta(
|
||||
seconds=ttl_seconds
|
||||
)
|
||||
elif key in self._expirations:
|
||||
# Remove existing expiration if no TTL specified
|
||||
del self._expirations[key]
|
||||
|
||||
async def get(self, key: str) -> Any | None:
|
||||
"""Get a value by key."""
|
||||
async with self._lock:
|
||||
if key not in self._data:
|
||||
return None
|
||||
if self._is_expired(key):
|
||||
del self._data[key]
|
||||
del self._expirations[key]
|
||||
return None
|
||||
return self._data[key]
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete a key."""
|
||||
async with self._lock:
|
||||
existed = key in self._data
|
||||
self._data.pop(key, None)
|
||||
self._expirations.pop(key, None)
|
||||
return existed
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""Check if a key exists and is not expired."""
|
||||
async with self._lock:
|
||||
if key not in self._data:
|
||||
return False
|
||||
if self._is_expired(key):
|
||||
del self._data[key]
|
||||
del self._expirations[key]
|
||||
return False
|
||||
return True
|
||||
|
||||
async def list_keys(self, pattern: str = "*") -> list[str]:
|
||||
"""List all keys matching a pattern."""
|
||||
async with self._lock:
|
||||
self._cleanup_expired()
|
||||
if pattern == "*":
|
||||
return list(self._data.keys())
|
||||
return [key for key in self._data.keys() if fnmatch.fnmatch(key, pattern)]
|
||||
|
||||
async def get_all(self) -> dict[str, Any]:
|
||||
"""Get all key-value pairs."""
|
||||
async with self._lock:
|
||||
self._cleanup_expired()
|
||||
return dict(self._data)
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all keys."""
|
||||
async with self._lock:
|
||||
count = len(self._data)
|
||||
self._data.clear()
|
||||
self._expirations.clear()
|
||||
return count
|
||||
|
||||
async def is_healthy(self) -> bool:
|
||||
"""In-memory storage is always healthy."""
|
||||
return True
|
||||
|
||||
async def close(self) -> None:
|
||||
"""No cleanup needed for in-memory storage."""
|
||||
|
||||
|
||||
class RedisStorage(WorkingMemoryStorage):
|
||||
"""
|
||||
Redis storage backend for working memory.
|
||||
|
||||
Primary storage with connection pooling, automatic reconnection,
|
||||
and proper serialization of Python objects.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key_prefix: str = "",
|
||||
connection_timeout: float = 5.0,
|
||||
socket_timeout: float = 5.0,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize Redis storage.
|
||||
|
||||
Args:
|
||||
key_prefix: Prefix for all keys (e.g., "session:abc123:")
|
||||
connection_timeout: Timeout for establishing connections
|
||||
socket_timeout: Timeout for socket operations
|
||||
"""
|
||||
self._key_prefix = key_prefix
|
||||
self._connection_timeout = connection_timeout
|
||||
self._socket_timeout = socket_timeout
|
||||
self._redis: Any = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def _make_key(self, key: str) -> str:
|
||||
"""Add prefix to key."""
|
||||
return f"{self._key_prefix}{key}"
|
||||
|
||||
def _strip_prefix(self, key: str) -> str:
|
||||
"""Remove prefix from key."""
|
||||
if key.startswith(self._key_prefix):
|
||||
return key[len(self._key_prefix) :]
|
||||
return key
|
||||
|
||||
def _serialize(self, value: Any) -> str:
|
||||
"""Serialize a Python value to JSON string."""
|
||||
return json.dumps(value, default=str)
|
||||
|
||||
def _deserialize(self, data: str | bytes | None) -> Any | None:
|
||||
"""Deserialize a JSON string to Python value."""
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode("utf-8")
|
||||
return json.loads(data)
|
||||
|
||||
async def _get_client(self) -> Any:
|
||||
"""Get or create Redis client."""
|
||||
if self._redis is not None:
|
||||
return self._redis
|
||||
|
||||
async with self._lock:
|
||||
if self._redis is not None:
|
||||
return self._redis
|
||||
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
except ImportError as e:
|
||||
raise MemoryConnectionError(
|
||||
"redis package not installed. Install with: pip install redis"
|
||||
) from e
|
||||
|
||||
settings = get_memory_settings()
|
||||
redis_url = settings.redis_url
|
||||
|
||||
try:
|
||||
self._redis = await aioredis.from_url(
|
||||
redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=self._connection_timeout,
|
||||
socket_timeout=self._socket_timeout,
|
||||
)
|
||||
# Test connection
|
||||
await self._redis.ping()
|
||||
logger.info("Connected to Redis for working memory")
|
||||
return self._redis
|
||||
except Exception as e:
|
||||
self._redis = None
|
||||
raise MemoryConnectionError(f"Failed to connect to Redis: {e}") from e
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
"""Store a value with optional TTL."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_key = self._make_key(key)
|
||||
serialized = self._serialize(value)
|
||||
|
||||
if ttl_seconds is not None:
|
||||
await client.setex(full_key, ttl_seconds, serialized)
|
||||
else:
|
||||
await client.set(full_key, serialized)
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to set key {key}: {e}") from e
|
||||
|
||||
async def get(self, key: str) -> Any | None:
|
||||
"""Get a value by key."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_key = self._make_key(key)
|
||||
data = await client.get(full_key)
|
||||
return self._deserialize(data)
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to get key {key}: {e}") from e
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete a key."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_key = self._make_key(key)
|
||||
result = await client.delete(full_key)
|
||||
return bool(result)
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to delete key {key}: {e}") from e
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""Check if a key exists."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_key = self._make_key(key)
|
||||
result = await client.exists(full_key)
|
||||
return bool(result)
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to check key {key}: {e}") from e
|
||||
|
||||
async def list_keys(self, pattern: str = "*") -> list[str]:
|
||||
"""List all keys matching a pattern."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_pattern = self._make_key(pattern)
|
||||
keys = await client.keys(full_pattern)
|
||||
return [self._strip_prefix(key) for key in keys]
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to list keys: {e}") from e
|
||||
|
||||
async def get_all(self) -> dict[str, Any]:
|
||||
"""Get all key-value pairs."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_pattern = self._make_key("*")
|
||||
keys = await client.keys(full_pattern)
|
||||
|
||||
if not keys:
|
||||
return {}
|
||||
|
||||
values = await client.mget(*keys)
|
||||
result = {}
|
||||
for key, value in zip(keys, values, strict=False):
|
||||
stripped_key = self._strip_prefix(key)
|
||||
result[stripped_key] = self._deserialize(value)
|
||||
return result
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to get all keys: {e}") from e
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all keys with this prefix."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
full_pattern = self._make_key("*")
|
||||
keys = await client.keys(full_pattern)
|
||||
|
||||
if not keys:
|
||||
return 0
|
||||
|
||||
return await client.delete(*keys)
|
||||
except MemoryConnectionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise MemoryStorageError(f"Failed to clear keys: {e}") from e
|
||||
|
||||
async def is_healthy(self) -> bool:
|
||||
"""Check if Redis connection is healthy."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
await client.ping()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the Redis connection."""
|
||||
if self._redis is not None:
|
||||
await self._redis.close()
|
||||
self._redis = None
|
||||
@@ -10,14 +10,16 @@ Modules:
|
||||
sync: Issue synchronization tasks (incremental/full sync, webhooks)
|
||||
workflow: Workflow state management tasks
|
||||
cost: Cost tracking and budget monitoring tasks
|
||||
memory_consolidation: Memory consolidation tasks
|
||||
"""
|
||||
|
||||
from app.tasks import agent, cost, git, sync, workflow
|
||||
from app.tasks import agent, cost, git, memory_consolidation, sync, workflow
|
||||
|
||||
__all__ = [
|
||||
"agent",
|
||||
"cost",
|
||||
"git",
|
||||
"memory_consolidation",
|
||||
"sync",
|
||||
"workflow",
|
||||
]
|
||||
|
||||
234
backend/app/tasks/memory_consolidation.py
Normal file
234
backend/app/tasks/memory_consolidation.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# app/tasks/memory_consolidation.py
|
||||
"""
|
||||
Memory consolidation Celery tasks.
|
||||
|
||||
Handles scheduled and on-demand memory consolidation:
|
||||
- Session consolidation (on session end)
|
||||
- Nightly consolidation (scheduled)
|
||||
- On-demand project consolidation
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.consolidate_session",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def consolidate_session(
|
||||
self,
|
||||
project_id: str,
|
||||
session_id: str,
|
||||
task_type: str = "session_task",
|
||||
agent_instance_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Consolidate a session's working memory to episodic memory.
|
||||
|
||||
This task is triggered when an agent session ends to transfer
|
||||
relevant session data into persistent episodic memory.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
session_id: Session identifier
|
||||
task_type: Type of task performed
|
||||
agent_instance_id: Optional agent instance UUID
|
||||
agent_type_id: Optional agent type UUID
|
||||
|
||||
Returns:
|
||||
dict with consolidation results
|
||||
"""
|
||||
logger.info(f"Consolidating session {session_id} for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# This will involve:
|
||||
# 1. Getting database session from async context
|
||||
# 2. Loading working memory for session
|
||||
# 3. Calling consolidation service
|
||||
# 4. Returning results
|
||||
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"session_id": session_id,
|
||||
"episode_created": False,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.run_nightly_consolidation",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def run_nightly_consolidation(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_type_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run nightly memory consolidation for a project.
|
||||
|
||||
This task performs the full consolidation workflow:
|
||||
1. Extract facts from recent episodes to semantic memory
|
||||
2. Learn procedures from successful episode patterns
|
||||
3. Prune old, low-value memories
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project to consolidate
|
||||
agent_type_id: Optional agent type to filter by
|
||||
|
||||
Returns:
|
||||
dict with consolidation results
|
||||
"""
|
||||
logger.info(f"Running nightly consolidation for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# This will involve:
|
||||
# 1. Getting database session from async context
|
||||
# 2. Creating consolidation service instance
|
||||
# 3. Running run_nightly_consolidation
|
||||
# 4. Returning results
|
||||
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"total_facts_created": 0,
|
||||
"total_procedures_created": 0,
|
||||
"total_pruned": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.consolidate_episodes_to_facts",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def consolidate_episodes_to_facts(
|
||||
self,
|
||||
project_id: str,
|
||||
since_hours: int = 24,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Extract facts from episodic memories.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
since_hours: Process episodes from last N hours
|
||||
limit: Maximum episodes to process
|
||||
|
||||
Returns:
|
||||
dict with extraction results
|
||||
"""
|
||||
logger.info(f"Consolidating episodes to facts for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"items_processed": 0,
|
||||
"items_created": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.consolidate_episodes_to_procedures",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def consolidate_episodes_to_procedures(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_type_id: str | None = None,
|
||||
since_days: int = 7,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Learn procedures from episodic patterns.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
agent_type_id: Optional agent type filter
|
||||
since_days: Process episodes from last N days
|
||||
|
||||
Returns:
|
||||
dict with procedure learning results
|
||||
"""
|
||||
logger.info(f"Consolidating episodes to procedures for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"items_processed": 0,
|
||||
"items_created": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.prune_old_memories",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def prune_old_memories(
|
||||
self,
|
||||
project_id: str,
|
||||
max_age_days: int = 90,
|
||||
min_importance: float = 0.2,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Prune old, low-value memories.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
max_age_days: Maximum age in days
|
||||
min_importance: Minimum importance to keep
|
||||
|
||||
Returns:
|
||||
dict with pruning results
|
||||
"""
|
||||
logger.info(f"Pruning old memories for project {project_id}")
|
||||
|
||||
# TODO: Implement actual pruning
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"items_pruned": 0,
|
||||
}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Celery Beat Schedule Configuration
|
||||
# =========================================================================
|
||||
|
||||
# This would typically be configured in celery_app.py or a separate config file
|
||||
# Example schedule for nightly consolidation:
|
||||
#
|
||||
# app.conf.beat_schedule = {
|
||||
# 'nightly-memory-consolidation': {
|
||||
# 'task': 'app.tasks.memory_consolidation.run_nightly_consolidation',
|
||||
# 'schedule': crontab(hour=2, minute=0), # 2 AM daily
|
||||
# 'args': (None,), # Will process all projects
|
||||
# },
|
||||
# }
|
||||
128
backend/data/default_agent_types.json
Normal file
128
backend/data/default_agent_types.json
Normal file
@@ -0,0 +1,128 @@
|
||||
[
|
||||
{
|
||||
"name": "Product Owner",
|
||||
"slug": "product-owner",
|
||||
"description": "Requirements discovery, stakeholder communication, and product vision. Leads the team in defining what to build and why.",
|
||||
"expertise": ["requirements", "stakeholder-management", "product-strategy", "user-stories", "acceptance-criteria", "prioritization"],
|
||||
"personality_prompt": "You are a skilled Product Owner focused on delivering maximum value to stakeholders. You excel at:\n- Understanding and articulating business needs\n- Writing clear user stories with acceptance criteria\n- Prioritizing features based on value and effort\n- Facilitating discussions between stakeholders and technical teams\n- Making trade-off decisions when scope conflicts arise\n\nYou communicate clearly and concisely, always keeping the end user and business goals in mind. You ask clarifying questions to ensure requirements are complete before passing them to the team.",
|
||||
"primary_model": "claude-sonnet-4-20250514",
|
||||
"fallback_models": ["claude-haiku-3-5-20241022"],
|
||||
"model_params": {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4096,
|
||||
"top_p": 0.95
|
||||
},
|
||||
"mcp_servers": ["gitea", "knowledge-base"],
|
||||
"tool_permissions": {
|
||||
"allowed": ["*"],
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:delete_*"]
|
||||
},
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"name": "Business Analyst",
|
||||
"slug": "business-analyst",
|
||||
"description": "Analysis, documentation, and detailed specifications. Bridges the gap between business needs and technical implementation.",
|
||||
"expertise": ["analysis", "documentation", "specifications", "process-modeling", "data-analysis", "requirements-engineering"],
|
||||
"personality_prompt": "You are a meticulous Business Analyst who excels at turning vague requirements into precise specifications. You:\n- Create detailed functional and technical specifications\n- Model business processes and data flows\n- Identify edge cases and potential issues early\n- Document assumptions and dependencies clearly\n- Ensure traceability between requirements and implementation\n\nYou are thorough and detail-oriented, always considering the implications of decisions. You create documentation that developers can follow without ambiguity.",
|
||||
"primary_model": "claude-sonnet-4-20250514",
|
||||
"fallback_models": ["claude-haiku-3-5-20241022"],
|
||||
"model_params": {
|
||||
"temperature": 0.5,
|
||||
"max_tokens": 8192,
|
||||
"top_p": 0.95
|
||||
},
|
||||
"mcp_servers": ["gitea", "knowledge-base"],
|
||||
"tool_permissions": {
|
||||
"allowed": ["*"],
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"name": "Solutions Architect",
|
||||
"slug": "solutions-architect",
|
||||
"description": "System design, architecture decisions, and technical leadership. Defines the technical vision and ensures system coherence.",
|
||||
"expertise": ["system-design", "architecture", "adrs", "technical-decisions", "integration", "scalability", "security"],
|
||||
"personality_prompt": "You are an experienced Solutions Architect who designs robust, scalable systems. You:\n- Create architecture diagrams and technical documentation\n- Write Architecture Decision Records (ADRs) for key decisions\n- Evaluate technology choices based on requirements and constraints\n- Identify potential bottlenecks and security concerns\n- Ensure consistency across the system design\n\nYou think holistically about systems, considering maintainability, scalability, and operational concerns. You document your decisions with clear rationale and trade-off analysis.",
|
||||
"primary_model": "claude-sonnet-4-20250514",
|
||||
"fallback_models": ["claude-haiku-3-5-20241022"],
|
||||
"model_params": {
|
||||
"temperature": 0.6,
|
||||
"max_tokens": 8192,
|
||||
"top_p": 0.95
|
||||
},
|
||||
"mcp_servers": ["gitea", "knowledge-base", "filesystem"],
|
||||
"tool_permissions": {
|
||||
"allowed": ["*"],
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_pull_request"]
|
||||
},
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"name": "Senior Engineer",
|
||||
"slug": "senior-engineer",
|
||||
"description": "Implementation, code review, and refactoring. Writes high-quality, maintainable code following best practices.",
|
||||
"expertise": ["implementation", "code-review", "refactoring", "testing", "debugging", "performance", "clean-code"],
|
||||
"personality_prompt": "You are a Senior Software Engineer who writes clean, maintainable code. You:\n- Implement features following established patterns and standards\n- Write comprehensive tests (unit, integration, e2e)\n- Review code for correctness, performance, and maintainability\n- Refactor code to improve quality without changing behavior\n- Debug complex issues systematically\n\nYou prioritize code quality and follow SOLID principles. You write code that other developers can easily understand and maintain. You always consider edge cases and error handling.",
|
||||
"primary_model": "claude-sonnet-4-20250514",
|
||||
"fallback_models": ["claude-haiku-3-5-20241022"],
|
||||
"model_params": {
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 16384,
|
||||
"top_p": 0.95
|
||||
},
|
||||
"mcp_servers": ["gitea", "knowledge-base", "filesystem"],
|
||||
"tool_permissions": {
|
||||
"allowed": ["*"],
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_pull_request", "gitea:delete_*"]
|
||||
},
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"name": "QA Engineer",
|
||||
"slug": "qa-engineer",
|
||||
"description": "Testing, quality assurance, and bug verification. Ensures the product meets quality standards before release.",
|
||||
"expertise": ["testing", "quality-assurance", "test-automation", "bug-verification", "test-planning", "regression-testing"],
|
||||
"personality_prompt": "You are a thorough QA Engineer who ensures product quality. You:\n- Design comprehensive test plans and test cases\n- Write automated tests (unit, integration, e2e)\n- Verify bug fixes and perform regression testing\n- Identify edge cases and boundary conditions\n- Document defects clearly with reproduction steps\n\nYou have a critical eye for quality and think like a user who might break things. You balance thoroughness with efficiency, focusing on high-risk areas while ensuring broad coverage.",
|
||||
"primary_model": "claude-sonnet-4-20250514",
|
||||
"fallback_models": ["claude-haiku-3-5-20241022"],
|
||||
"model_params": {
|
||||
"temperature": 0.4,
|
||||
"max_tokens": 8192,
|
||||
"top_p": 0.95
|
||||
},
|
||||
"mcp_servers": ["gitea", "knowledge-base", "filesystem"],
|
||||
"tool_permissions": {
|
||||
"allowed": ["*"],
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"name": "DevOps Engineer",
|
||||
"slug": "devops-engineer",
|
||||
"description": "CI/CD, deployment, and infrastructure. Ensures reliable, automated delivery pipelines and operational excellence.",
|
||||
"expertise": ["ci-cd", "deployment", "infrastructure", "docker", "kubernetes", "monitoring", "automation"],
|
||||
"personality_prompt": "You are a skilled DevOps Engineer who builds reliable delivery pipelines. You:\n- Design and maintain CI/CD pipelines\n- Configure infrastructure as code\n- Set up monitoring, logging, and alerting\n- Automate repetitive operational tasks\n- Ensure security and compliance in deployments\n\nYou think about reliability, observability, and automation. You design systems that fail gracefully and are easy to troubleshoot. You document runbooks and operational procedures clearly.",
|
||||
"primary_model": "claude-sonnet-4-20250514",
|
||||
"fallback_models": ["claude-haiku-3-5-20241022"],
|
||||
"model_params": {
|
||||
"temperature": 0.4,
|
||||
"max_tokens": 8192,
|
||||
"top_p": 0.95
|
||||
},
|
||||
"mcp_servers": ["gitea", "knowledge-base", "filesystem"],
|
||||
"tool_permissions": {
|
||||
"allowed": ["*"],
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_release", "gitea:delete_*"]
|
||||
},
|
||||
"is_active": true
|
||||
}
|
||||
]
|
||||
879
backend/data/demo_data.json
Normal file
879
backend/data/demo_data.json
Normal file
@@ -0,0 +1,879 @@
|
||||
{
|
||||
"organizations": [
|
||||
{
|
||||
"name": "Acme Corp",
|
||||
"slug": "acme-corp",
|
||||
"description": "A leading provider of coyote-catching equipment."
|
||||
},
|
||||
{
|
||||
"name": "Globex Corporation",
|
||||
"slug": "globex",
|
||||
"description": "We own the East Coast."
|
||||
},
|
||||
{
|
||||
"name": "Soylent Corp",
|
||||
"slug": "soylent",
|
||||
"description": "Making food for the future."
|
||||
},
|
||||
{
|
||||
"name": "Initech",
|
||||
"slug": "initech",
|
||||
"description": "Software for the soul."
|
||||
},
|
||||
{
|
||||
"name": "Umbrella Corporation",
|
||||
"slug": "umbrella",
|
||||
"description": "Our business is life itself."
|
||||
},
|
||||
{
|
||||
"name": "Massive Dynamic",
|
||||
"slug": "massive-dynamic",
|
||||
"description": "What don't we do?"
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"email": "demo@example.com",
|
||||
"password": "DemoPass1234!",
|
||||
"first_name": "Demo",
|
||||
"last_name": "User",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "alice@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Alice",
|
||||
"last_name": "Smith",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "bob@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Bob",
|
||||
"last_name": "Jones",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "charlie@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Charlie",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "diana@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Diana",
|
||||
"last_name": "Prince",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "carol@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Carol",
|
||||
"last_name": "Williams",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dan@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dan",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ellen@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ellen",
|
||||
"last_name": "Ripley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "fred@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Fred",
|
||||
"last_name": "Flintstone",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dave@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dave",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "gina@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Gina",
|
||||
"last_name": "Torres",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "harry@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Harry",
|
||||
"last_name": "Potter",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "eve@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Eve",
|
||||
"last_name": "Davis",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "iris@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Iris",
|
||||
"last_name": "West",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "jack@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Jack",
|
||||
"last_name": "Sparrow",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "frank@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Frank",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "george@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "George",
|
||||
"last_name": "Costanza",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "kate@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Kate",
|
||||
"last_name": "Bishop",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "leo@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Leo",
|
||||
"last_name": "Messi",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "mary@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Mary",
|
||||
"last_name": "Jane",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "nathan@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Nathan",
|
||||
"last_name": "Drake",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "olivia@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Olivia",
|
||||
"last_name": "Dunham",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "peter@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Peter",
|
||||
"last_name": "Parker",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "quinn@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Quinn",
|
||||
"last_name": "Mallory",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "grace@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Grace",
|
||||
"last_name": "Hopper",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "heidi@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Heidi",
|
||||
"last_name": "Klum",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ivan@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ivan",
|
||||
"last_name": "Drago",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "rachel@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Rachel",
|
||||
"last_name": "Green",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "sam@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Sam",
|
||||
"last_name": "Wilson",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "tony@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Tony",
|
||||
"last_name": "Stark",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "una@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Una",
|
||||
"last_name": "Chin-Riley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "victor@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Victor",
|
||||
"last_name": "Von Doom",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "wanda@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Wanda",
|
||||
"last_name": "Maximoff",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
}
|
||||
],
|
||||
"projects": [
|
||||
{
|
||||
"name": "E-Commerce Platform Redesign",
|
||||
"slug": "ecommerce-redesign",
|
||||
"description": "Complete redesign of the e-commerce platform with modern UX, improved checkout flow, and mobile-first approach.",
|
||||
"owner_email": "__admin__",
|
||||
"autonomy_level": "milestone",
|
||||
"status": "active",
|
||||
"complexity": "complex",
|
||||
"client_mode": "technical",
|
||||
"settings": {
|
||||
"mcp_servers": ["gitea", "knowledge-base"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Mobile Banking App",
|
||||
"slug": "mobile-banking",
|
||||
"description": "Secure mobile banking application with biometric authentication, transaction history, and real-time notifications.",
|
||||
"owner_email": "__admin__",
|
||||
"autonomy_level": "full_control",
|
||||
"status": "active",
|
||||
"complexity": "complex",
|
||||
"client_mode": "technical",
|
||||
"settings": {
|
||||
"mcp_servers": ["gitea", "knowledge-base"],
|
||||
"security_level": "high"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Internal HR Portal",
|
||||
"slug": "hr-portal",
|
||||
"description": "Employee self-service portal for leave requests, performance reviews, and document management.",
|
||||
"owner_email": "__admin__",
|
||||
"autonomy_level": "autonomous",
|
||||
"status": "active",
|
||||
"complexity": "medium",
|
||||
"client_mode": "auto",
|
||||
"settings": {
|
||||
"mcp_servers": ["gitea", "knowledge-base"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "API Gateway Modernization",
|
||||
"slug": "api-gateway",
|
||||
"description": "Migrate legacy REST API gateway to modern GraphQL-based architecture with improved caching and rate limiting.",
|
||||
"owner_email": "__admin__",
|
||||
"autonomy_level": "milestone",
|
||||
"status": "active",
|
||||
"complexity": "complex",
|
||||
"client_mode": "technical",
|
||||
"settings": {
|
||||
"mcp_servers": ["gitea", "knowledge-base"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Customer Analytics Dashboard",
|
||||
"slug": "analytics-dashboard",
|
||||
"description": "Real-time analytics dashboard for customer behavior insights, cohort analysis, and predictive modeling.",
|
||||
"owner_email": "__admin__",
|
||||
"autonomy_level": "autonomous",
|
||||
"status": "completed",
|
||||
"complexity": "medium",
|
||||
"client_mode": "auto",
|
||||
"settings": {
|
||||
"mcp_servers": ["gitea", "knowledge-base"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "DevOps Pipeline Automation",
|
||||
"slug": "devops-automation",
|
||||
"description": "Automate CI/CD pipelines with AI-assisted deployments, rollback detection, and infrastructure as code.",
|
||||
"owner_email": "__admin__",
|
||||
"autonomy_level": "full_control",
|
||||
"status": "active",
|
||||
"complexity": "complex",
|
||||
"client_mode": "technical",
|
||||
"settings": {
|
||||
"mcp_servers": ["gitea", "knowledge-base"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"sprints": [
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"name": "Sprint 1: Foundation",
|
||||
"number": 1,
|
||||
"goal": "Set up project infrastructure, design system, and core navigation components.",
|
||||
"start_date": "2026-01-06",
|
||||
"end_date": "2026-01-20",
|
||||
"status": "active",
|
||||
"planned_points": 21
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"name": "Sprint 2: Product Catalog",
|
||||
"number": 2,
|
||||
"goal": "Implement product listing, filtering, search, and detail pages.",
|
||||
"start_date": "2026-01-20",
|
||||
"end_date": "2026-02-03",
|
||||
"status": "planned",
|
||||
"planned_points": 34
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"name": "Sprint 1: Authentication",
|
||||
"number": 1,
|
||||
"goal": "Implement secure login, biometric authentication, and session management.",
|
||||
"start_date": "2026-01-06",
|
||||
"end_date": "2026-01-20",
|
||||
"status": "active",
|
||||
"planned_points": 26
|
||||
},
|
||||
{
|
||||
"project_slug": "hr-portal",
|
||||
"name": "Sprint 1: Core Features",
|
||||
"number": 1,
|
||||
"goal": "Build employee dashboard, leave request system, and basic document management.",
|
||||
"start_date": "2026-01-06",
|
||||
"end_date": "2026-01-20",
|
||||
"status": "active",
|
||||
"planned_points": 18
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"name": "Sprint 1: GraphQL Schema",
|
||||
"number": 1,
|
||||
"goal": "Define GraphQL schema and implement core resolvers for existing REST endpoints.",
|
||||
"start_date": "2025-12-23",
|
||||
"end_date": "2026-01-06",
|
||||
"status": "completed",
|
||||
"planned_points": 21
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"name": "Sprint 2: Caching Layer",
|
||||
"number": 2,
|
||||
"goal": "Implement Redis-based caching layer and query batching.",
|
||||
"start_date": "2026-01-06",
|
||||
"end_date": "2026-01-20",
|
||||
"status": "active",
|
||||
"planned_points": 26
|
||||
},
|
||||
{
|
||||
"project_slug": "analytics-dashboard",
|
||||
"name": "Sprint 1: Data Pipeline",
|
||||
"number": 1,
|
||||
"goal": "Set up data ingestion pipeline and real-time event processing.",
|
||||
"start_date": "2025-11-15",
|
||||
"end_date": "2025-11-29",
|
||||
"status": "completed",
|
||||
"planned_points": 18
|
||||
},
|
||||
{
|
||||
"project_slug": "analytics-dashboard",
|
||||
"name": "Sprint 2: Dashboard UI",
|
||||
"number": 2,
|
||||
"goal": "Build interactive dashboard with charts and filtering capabilities.",
|
||||
"start_date": "2025-11-29",
|
||||
"end_date": "2025-12-13",
|
||||
"status": "completed",
|
||||
"planned_points": 21
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"name": "Sprint 1: Pipeline Templates",
|
||||
"number": 1,
|
||||
"goal": "Create reusable CI/CD pipeline templates for common deployment patterns.",
|
||||
"start_date": "2026-01-06",
|
||||
"end_date": "2026-01-20",
|
||||
"status": "active",
|
||||
"planned_points": 24
|
||||
}
|
||||
],
|
||||
"agent_instances": [
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"agent_type_slug": "product-owner",
|
||||
"name": "Aria",
|
||||
"status": "idle"
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"agent_type_slug": "solutions-architect",
|
||||
"name": "Marcus",
|
||||
"status": "idle"
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"agent_type_slug": "senior-engineer",
|
||||
"name": "Zara",
|
||||
"status": "working",
|
||||
"current_task": "Implementing responsive navigation component"
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"agent_type_slug": "product-owner",
|
||||
"name": "Felix",
|
||||
"status": "waiting",
|
||||
"current_task": "Awaiting security requirements clarification"
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"agent_type_slug": "senior-engineer",
|
||||
"name": "Luna",
|
||||
"status": "working",
|
||||
"current_task": "Implementing biometric authentication flow"
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"agent_type_slug": "qa-engineer",
|
||||
"name": "Rex",
|
||||
"status": "idle"
|
||||
},
|
||||
{
|
||||
"project_slug": "hr-portal",
|
||||
"agent_type_slug": "business-analyst",
|
||||
"name": "Nova",
|
||||
"status": "working",
|
||||
"current_task": "Documenting leave request workflow"
|
||||
},
|
||||
{
|
||||
"project_slug": "hr-portal",
|
||||
"agent_type_slug": "senior-engineer",
|
||||
"name": "Atlas",
|
||||
"status": "working",
|
||||
"current_task": "Building employee dashboard API"
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"agent_type_slug": "solutions-architect",
|
||||
"name": "Orion",
|
||||
"status": "working",
|
||||
"current_task": "Designing caching strategy for GraphQL queries"
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"agent_type_slug": "senior-engineer",
|
||||
"name": "Cleo",
|
||||
"status": "working",
|
||||
"current_task": "Implementing Redis cache invalidation"
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"agent_type_slug": "devops-engineer",
|
||||
"name": "Volt",
|
||||
"status": "working",
|
||||
"current_task": "Creating Terraform modules for AWS ECS"
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"agent_type_slug": "senior-engineer",
|
||||
"name": "Sage",
|
||||
"status": "idle"
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"agent_type_slug": "qa-engineer",
|
||||
"name": "Echo",
|
||||
"status": "waiting",
|
||||
"current_task": "Waiting for pipeline templates to test"
|
||||
}
|
||||
],
|
||||
"issues": [
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"sprint_number": 1,
|
||||
"type": "story",
|
||||
"title": "Design responsive navigation component",
|
||||
"body": "As a user, I want a navigation menu that works seamlessly on both desktop and mobile devices.\n\n## Acceptance Criteria\n- Hamburger menu on mobile viewports\n- Sticky header on scroll\n- Keyboard accessible\n- Screen reader compatible",
|
||||
"status": "in_progress",
|
||||
"priority": "high",
|
||||
"labels": ["frontend", "design-system"],
|
||||
"story_points": 5,
|
||||
"assigned_agent_name": "Zara"
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Set up Tailwind CSS configuration",
|
||||
"body": "Configure Tailwind CSS with custom design tokens for the e-commerce platform.\n\n- Define color palette\n- Set up typography scale\n- Configure breakpoints\n- Add custom utilities",
|
||||
"status": "closed",
|
||||
"priority": "high",
|
||||
"labels": ["frontend", "infrastructure"],
|
||||
"story_points": 3
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Create base component library structure",
|
||||
"body": "Set up the foundational component library with:\n- Button variants\n- Form inputs\n- Card component\n- Modal system",
|
||||
"status": "open",
|
||||
"priority": "medium",
|
||||
"labels": ["frontend", "design-system"],
|
||||
"story_points": 8
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"sprint_number": 1,
|
||||
"type": "story",
|
||||
"title": "Implement user authentication flow",
|
||||
"body": "As a user, I want to sign up, log in, and manage my account.\n\n## Features\n- Email/password registration\n- Social login (Google, GitHub)\n- Password reset flow\n- Email verification",
|
||||
"status": "open",
|
||||
"priority": "critical",
|
||||
"labels": ["auth", "backend", "frontend"],
|
||||
"story_points": 13
|
||||
},
|
||||
{
|
||||
"project_slug": "ecommerce-redesign",
|
||||
"sprint_number": 2,
|
||||
"type": "epic",
|
||||
"title": "Product Catalog System",
|
||||
"body": "Complete product catalog implementation including:\n- Product listing with pagination\n- Advanced filtering and search\n- Product detail pages\n- Category navigation",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["catalog", "backend", "frontend"],
|
||||
"story_points": null
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"sprint_number": 1,
|
||||
"type": "story",
|
||||
"title": "Implement biometric authentication",
|
||||
"body": "As a user, I want to log in using Face ID or Touch ID for quick and secure access.\n\n## Requirements\n- Support Face ID on iOS\n- Support fingerprint on Android\n- Fallback to PIN/password\n- Secure keychain storage",
|
||||
"status": "in_progress",
|
||||
"priority": "critical",
|
||||
"labels": ["auth", "security", "mobile"],
|
||||
"story_points": 8,
|
||||
"assigned_agent_name": "Luna"
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Set up secure session management",
|
||||
"body": "Implement secure session handling with:\n- JWT tokens with short expiry\n- Refresh token rotation\n- Session timeout handling\n- Multi-device session management",
|
||||
"status": "open",
|
||||
"priority": "critical",
|
||||
"labels": ["auth", "security", "backend"],
|
||||
"story_points": 5
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"sprint_number": 1,
|
||||
"type": "bug",
|
||||
"title": "Fix token refresh race condition",
|
||||
"body": "When multiple API calls happen simultaneously after token expiry, multiple refresh requests are made causing 401 errors.\n\n## Steps to Reproduce\n1. Wait for token to expire\n2. Trigger multiple API calls at once\n3. Observe multiple 401 errors",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["bug", "auth", "backend"],
|
||||
"story_points": 3
|
||||
},
|
||||
{
|
||||
"project_slug": "mobile-banking",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Implement PIN entry screen",
|
||||
"body": "Create secure PIN entry component with:\n- Masked input display\n- Haptic feedback\n- Brute force protection (lockout after 5 attempts)\n- Secure PIN storage",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["auth", "mobile", "frontend"],
|
||||
"story_points": 5
|
||||
},
|
||||
{
|
||||
"project_slug": "hr-portal",
|
||||
"sprint_number": 1,
|
||||
"type": "story",
|
||||
"title": "Build employee dashboard",
|
||||
"body": "As an employee, I want a dashboard showing my key information at a glance.\n\n## Dashboard Widgets\n- Leave balance\n- Pending approvals\n- Upcoming holidays\n- Recent announcements",
|
||||
"status": "in_progress",
|
||||
"priority": "high",
|
||||
"labels": ["frontend", "dashboard"],
|
||||
"story_points": 5,
|
||||
"assigned_agent_name": "Atlas"
|
||||
},
|
||||
{
|
||||
"project_slug": "hr-portal",
|
||||
"sprint_number": 1,
|
||||
"type": "story",
|
||||
"title": "Implement leave request system",
|
||||
"body": "As an employee, I want to submit and track leave requests.\n\n## Features\n- Submit leave request with date range\n- View leave balance by type\n- Track request status\n- Manager approval workflow",
|
||||
"status": "in_progress",
|
||||
"priority": "high",
|
||||
"labels": ["backend", "frontend", "workflow"],
|
||||
"story_points": 8,
|
||||
"assigned_agent_name": "Nova"
|
||||
},
|
||||
{
|
||||
"project_slug": "hr-portal",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Set up document storage integration",
|
||||
"body": "Integrate with S3-compatible storage for employee documents:\n- Secure upload/download\n- File type validation\n- Size limits\n- Virus scanning",
|
||||
"status": "open",
|
||||
"priority": "medium",
|
||||
"labels": ["backend", "infrastructure", "storage"],
|
||||
"story_points": 5
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"sprint_number": 2,
|
||||
"type": "story",
|
||||
"title": "Implement Redis caching layer",
|
||||
"body": "As an API consumer, I want responses to be cached for improved performance.\n\n## Requirements\n- Cache GraphQL query results\n- Configurable TTL per query type\n- Cache invalidation on mutations\n- Cache hit/miss metrics",
|
||||
"status": "in_progress",
|
||||
"priority": "critical",
|
||||
"labels": ["backend", "performance", "redis"],
|
||||
"story_points": 8,
|
||||
"assigned_agent_name": "Cleo"
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"sprint_number": 2,
|
||||
"type": "task",
|
||||
"title": "Set up query batching and deduplication",
|
||||
"body": "Implement DataLoader pattern for:\n- Batching multiple queries into single database calls\n- Deduplicating identical queries within request scope\n- N+1 query prevention",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["backend", "performance", "graphql"],
|
||||
"story_points": 5
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"sprint_number": 2,
|
||||
"type": "task",
|
||||
"title": "Implement rate limiting middleware",
|
||||
"body": "Add rate limiting to prevent API abuse:\n- Per-user rate limits\n- Per-IP fallback for anonymous requests\n- Sliding window algorithm\n- Custom limits per operation type",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["backend", "security", "middleware"],
|
||||
"story_points": 5,
|
||||
"assigned_agent_name": "Orion"
|
||||
},
|
||||
{
|
||||
"project_slug": "api-gateway",
|
||||
"sprint_number": 2,
|
||||
"type": "bug",
|
||||
"title": "Fix N+1 query in user resolver",
|
||||
"body": "The user resolver is making separate database calls for each user's organization.\n\n## Steps to Reproduce\n1. Query users with organization field\n2. Check database logs\n3. Observe N+1 queries",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["bug", "performance", "graphql"],
|
||||
"story_points": 3
|
||||
},
|
||||
{
|
||||
"project_slug": "analytics-dashboard",
|
||||
"sprint_number": 2,
|
||||
"type": "story",
|
||||
"title": "Build cohort analysis charts",
|
||||
"body": "As a product manager, I want to analyze user cohorts over time.\n\n## Features\n- Weekly/monthly cohort grouping\n- Retention curve visualization\n- Cohort comparison view",
|
||||
"status": "closed",
|
||||
"priority": "high",
|
||||
"labels": ["frontend", "charts", "analytics"],
|
||||
"story_points": 8
|
||||
},
|
||||
{
|
||||
"project_slug": "analytics-dashboard",
|
||||
"sprint_number": 2,
|
||||
"type": "task",
|
||||
"title": "Implement real-time event streaming",
|
||||
"body": "Set up WebSocket connection for live event updates:\n- Event type filtering\n- Buffering for high-volume periods\n- Reconnection handling",
|
||||
"status": "closed",
|
||||
"priority": "high",
|
||||
"labels": ["backend", "websocket", "realtime"],
|
||||
"story_points": 5
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"sprint_number": 1,
|
||||
"type": "epic",
|
||||
"title": "CI/CD Pipeline Templates",
|
||||
"body": "Create reusable pipeline templates for common deployment patterns.\n\n## Templates Needed\n- Node.js applications\n- Python applications\n- Docker-based deployments\n- Kubernetes deployments",
|
||||
"status": "in_progress",
|
||||
"priority": "critical",
|
||||
"labels": ["infrastructure", "cicd", "templates"],
|
||||
"story_points": null
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"sprint_number": 1,
|
||||
"type": "story",
|
||||
"title": "Create Terraform modules for AWS ECS",
|
||||
"body": "As a DevOps engineer, I want Terraform modules for ECS deployments.\n\n## Modules\n- ECS cluster configuration\n- Service and task definitions\n- Load balancer integration\n- Auto-scaling policies",
|
||||
"status": "in_progress",
|
||||
"priority": "high",
|
||||
"labels": ["terraform", "aws", "ecs"],
|
||||
"story_points": 8,
|
||||
"assigned_agent_name": "Volt"
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Set up Gitea Actions runners",
|
||||
"body": "Configure self-hosted Gitea Actions runners:\n- Docker-in-Docker support\n- Caching for npm/pip\n- Secrets management\n- Resource limits",
|
||||
"status": "open",
|
||||
"priority": "high",
|
||||
"labels": ["infrastructure", "gitea", "cicd"],
|
||||
"story_points": 5
|
||||
},
|
||||
{
|
||||
"project_slug": "devops-automation",
|
||||
"sprint_number": 1,
|
||||
"type": "task",
|
||||
"title": "Implement rollback detection system",
|
||||
"body": "AI-assisted rollback detection:\n- Monitor deployment health metrics\n- Automatic rollback triggers\n- Notification system\n- Post-rollback analysis",
|
||||
"status": "open",
|
||||
"priority": "medium",
|
||||
"labels": ["ai", "monitoring", "automation"],
|
||||
"story_points": 8
|
||||
}
|
||||
]
|
||||
}
|
||||
507
backend/docs/MEMORY_SYSTEM.md
Normal file
507
backend/docs/MEMORY_SYSTEM.md
Normal file
@@ -0,0 +1,507 @@
|
||||
# Agent Memory System
|
||||
|
||||
Comprehensive multi-tier cognitive memory for AI agents, enabling state persistence, experiential learning, and context continuity across sessions.
|
||||
|
||||
## Overview
|
||||
|
||||
The Agent Memory System implements a cognitive architecture inspired by human memory:
|
||||
|
||||
```
|
||||
+------------------------------------------------------------------+
|
||||
| Agent Memory System |
|
||||
+------------------------------------------------------------------+
|
||||
| |
|
||||
| +------------------+ +------------------+ |
|
||||
| | Working Memory |----consolidate---->| Episodic Memory | |
|
||||
| | (Redis/In-Mem) | | (PostgreSQL) | |
|
||||
| | | | | |
|
||||
| | - Current task | | - Past sessions | |
|
||||
| | - Variables | | - Experiences | |
|
||||
| | - Scratchpad | | - Outcomes | |
|
||||
| +------------------+ +--------+---------+ |
|
||||
| | |
|
||||
| extract | |
|
||||
| v |
|
||||
| +------------------+ +------------------+ |
|
||||
| |Procedural Memory |<-----learn from----| Semantic Memory | |
|
||||
| | (PostgreSQL) | | (PostgreSQL + | |
|
||||
| | | | pgvector) | |
|
||||
| | - Procedures | | | |
|
||||
| | - Skills | | - Facts | |
|
||||
| | - Patterns | | - Entities | |
|
||||
| +------------------+ | - Relationships | |
|
||||
| +------------------+ |
|
||||
+------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
## Memory Types
|
||||
|
||||
### Working Memory
|
||||
Short-term, session-scoped memory for current task state.
|
||||
|
||||
**Features:**
|
||||
- Key-value storage with TTL
|
||||
- Task state tracking
|
||||
- Scratchpad for reasoning
|
||||
- Checkpoint/restore support
|
||||
- Redis primary with in-memory fallback
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from app.services.memory.working import WorkingMemory
|
||||
|
||||
memory = WorkingMemory(scope_context)
|
||||
await memory.set("key", {"data": "value"}, ttl_seconds=3600)
|
||||
value = await memory.get("key")
|
||||
|
||||
# Task state
|
||||
await memory.set_task_state(TaskState(task_id="t1", status="running"))
|
||||
state = await memory.get_task_state()
|
||||
|
||||
# Checkpoints
|
||||
checkpoint_id = await memory.create_checkpoint()
|
||||
await memory.restore_checkpoint(checkpoint_id)
|
||||
```
|
||||
|
||||
### Episodic Memory
|
||||
Experiential records of past agent actions and outcomes.
|
||||
|
||||
**Features:**
|
||||
- Records task completions and failures
|
||||
- Semantic similarity search (pgvector)
|
||||
- Temporal and outcome-based retrieval
|
||||
- Importance scoring
|
||||
- Episode summarization
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from app.services.memory.episodic import EpisodicMemory
|
||||
|
||||
memory = EpisodicMemory(session, embedder)
|
||||
|
||||
# Record an episode
|
||||
episode = await memory.record_episode(
|
||||
project_id=project_id,
|
||||
episode=EpisodeCreate(
|
||||
task_type="code_review",
|
||||
task_description="Review PR #42",
|
||||
outcome=Outcome.SUCCESS,
|
||||
actions=[{"type": "analyze", "target": "src/"}],
|
||||
)
|
||||
)
|
||||
|
||||
# Search similar experiences
|
||||
similar = await memory.search_similar(
|
||||
project_id=project_id,
|
||||
query="debugging memory leak",
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Get recent episodes
|
||||
recent = await memory.get_recent(project_id, limit=10)
|
||||
```
|
||||
|
||||
### Semantic Memory
|
||||
Learned facts and knowledge with confidence scoring.
|
||||
|
||||
**Features:**
|
||||
- Triple format (subject, predicate, object)
|
||||
- Confidence scoring with decay
|
||||
- Fact extraction from episodes
|
||||
- Conflict resolution
|
||||
- Entity-based retrieval
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from app.services.memory.semantic import SemanticMemory
|
||||
|
||||
memory = SemanticMemory(session, embedder)
|
||||
|
||||
# Store a fact
|
||||
fact = await memory.store_fact(
|
||||
project_id=project_id,
|
||||
fact=FactCreate(
|
||||
subject="UserService",
|
||||
predicate="handles",
|
||||
object="authentication",
|
||||
confidence=0.9,
|
||||
)
|
||||
)
|
||||
|
||||
# Search facts
|
||||
facts = await memory.search_facts(project_id, "authentication flow")
|
||||
|
||||
# Reinforce on repeated learning
|
||||
await memory.reinforce_fact(fact.id)
|
||||
```
|
||||
|
||||
### Procedural Memory
|
||||
Learned skills and procedures from successful patterns.
|
||||
|
||||
**Features:**
|
||||
- Procedure recording from task patterns
|
||||
- Trigger-based matching
|
||||
- Success rate tracking
|
||||
- Procedure suggestions
|
||||
- Step-by-step storage
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from app.services.memory.procedural import ProceduralMemory
|
||||
|
||||
memory = ProceduralMemory(session, embedder)
|
||||
|
||||
# Record a procedure
|
||||
procedure = await memory.record_procedure(
|
||||
project_id=project_id,
|
||||
procedure=ProcedureCreate(
|
||||
name="PR Review Process",
|
||||
trigger_pattern="code review requested",
|
||||
steps=[
|
||||
Step(action="fetch_diff"),
|
||||
Step(action="analyze_changes"),
|
||||
Step(action="check_tests"),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Find matching procedures
|
||||
matches = await memory.find_matching(project_id, "need to review code")
|
||||
|
||||
# Record outcomes
|
||||
await memory.record_outcome(procedure.id, success=True)
|
||||
```
|
||||
|
||||
## Memory Scoping
|
||||
|
||||
Memory is organized in a hierarchical scope structure:
|
||||
|
||||
```
|
||||
Global Memory (shared by all)
|
||||
└── Project Memory (per project)
|
||||
└── Agent Type Memory (per agent type)
|
||||
└── Agent Instance Memory (per instance)
|
||||
└── Session Memory (ephemeral)
|
||||
```
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from app.services.memory.scoping import ScopeManager, ScopeLevel
|
||||
|
||||
manager = ScopeManager(session)
|
||||
|
||||
# Get scoped memories with inheritance
|
||||
memories = await manager.get_scoped_memories(
|
||||
context=ScopeContext(
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
),
|
||||
include_inherited=True, # Include parent scopes
|
||||
)
|
||||
```
|
||||
|
||||
## Memory Consolidation
|
||||
|
||||
Automatic background processes transfer and extract knowledge:
|
||||
|
||||
```
|
||||
Working Memory ──> Episodic Memory ──> Semantic Memory
|
||||
└──> Procedural Memory
|
||||
```
|
||||
|
||||
**Consolidation Types:**
|
||||
- `working_to_episodic`: Transfer session state to episodes (on session end)
|
||||
- `episodic_to_semantic`: Extract facts from experiences
|
||||
- `episodic_to_procedural`: Learn procedures from patterns
|
||||
- `prune`: Remove low-value memories
|
||||
|
||||
**Celery Tasks:**
|
||||
```python
|
||||
from app.tasks.memory_consolidation import (
|
||||
consolidate_session,
|
||||
run_nightly_consolidation,
|
||||
prune_old_memories,
|
||||
)
|
||||
|
||||
# Manual consolidation
|
||||
consolidate_session.delay(session_id)
|
||||
|
||||
# Scheduled nightly (3 AM by default)
|
||||
run_nightly_consolidation.delay()
|
||||
```
|
||||
|
||||
## Memory Retrieval
|
||||
|
||||
### Hybrid Retrieval
|
||||
Combine multiple retrieval strategies:
|
||||
|
||||
```python
|
||||
from app.services.memory.indexing import RetrievalEngine
|
||||
|
||||
engine = RetrievalEngine(session, embedder)
|
||||
|
||||
# Hybrid search across memory types
|
||||
results = await engine.retrieve_hybrid(
|
||||
project_id=project_id,
|
||||
query="authentication error handling",
|
||||
memory_types=["episodic", "semantic", "procedural"],
|
||||
filters={"outcome": "success"},
|
||||
limit=10,
|
||||
)
|
||||
```
|
||||
|
||||
### Index Types
|
||||
- **Vector Index**: Semantic similarity (HNSW/pgvector)
|
||||
- **Temporal Index**: Time-based retrieval
|
||||
- **Entity Index**: Entity mention lookup
|
||||
- **Outcome Index**: Success/failure filtering
|
||||
|
||||
## MCP Tools
|
||||
|
||||
The memory system exposes MCP tools for agent use:
|
||||
|
||||
### `remember`
|
||||
Store information in memory.
|
||||
```json
|
||||
{
|
||||
"memory_type": "working",
|
||||
"content": {"key": "value"},
|
||||
"importance": 0.8,
|
||||
"ttl_seconds": 3600
|
||||
}
|
||||
```
|
||||
|
||||
### `recall`
|
||||
Retrieve from memory.
|
||||
```json
|
||||
{
|
||||
"query": "authentication patterns",
|
||||
"memory_types": ["episodic", "semantic"],
|
||||
"limit": 10,
|
||||
"filters": {"outcome": "success"}
|
||||
}
|
||||
```
|
||||
|
||||
### `forget`
|
||||
Remove from memory.
|
||||
```json
|
||||
{
|
||||
"memory_type": "working",
|
||||
"key": "temp_data"
|
||||
}
|
||||
```
|
||||
|
||||
### `reflect`
|
||||
Analyze memory patterns.
|
||||
```json
|
||||
{
|
||||
"analysis_type": "success_factors",
|
||||
"task_type": "code_review",
|
||||
"time_range_days": 30
|
||||
}
|
||||
```
|
||||
|
||||
### `get_memory_stats`
|
||||
Get memory usage statistics.
|
||||
|
||||
### `record_outcome`
|
||||
Record task success/failure for learning.
|
||||
|
||||
## Memory Reflection
|
||||
|
||||
Analyze patterns and generate insights from memory:
|
||||
|
||||
```python
|
||||
from app.services.memory.reflection import MemoryReflection, TimeRange
|
||||
|
||||
reflection = MemoryReflection(session)
|
||||
|
||||
# Detect patterns
|
||||
patterns = await reflection.analyze_patterns(
|
||||
project_id=project_id,
|
||||
time_range=TimeRange.last_days(30),
|
||||
)
|
||||
|
||||
# Identify success factors
|
||||
factors = await reflection.identify_success_factors(
|
||||
project_id=project_id,
|
||||
task_type="code_review",
|
||||
)
|
||||
|
||||
# Detect anomalies
|
||||
anomalies = await reflection.detect_anomalies(
|
||||
project_id=project_id,
|
||||
baseline_days=30,
|
||||
)
|
||||
|
||||
# Generate insights
|
||||
insights = await reflection.generate_insights(project_id)
|
||||
|
||||
# Comprehensive reflection
|
||||
result = await reflection.reflect(project_id)
|
||||
print(result.summary)
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
All settings use the `MEM_` environment variable prefix:
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `MEM_WORKING_MEMORY_BACKEND` | `redis` | Backend: `redis` or `memory` |
|
||||
| `MEM_WORKING_MEMORY_DEFAULT_TTL_SECONDS` | `3600` | Default TTL (1 hour) |
|
||||
| `MEM_REDIS_URL` | `redis://localhost:6379/0` | Redis connection URL |
|
||||
| `MEM_EPISODIC_MAX_EPISODES_PER_PROJECT` | `10000` | Max episodes per project |
|
||||
| `MEM_EPISODIC_RETENTION_DAYS` | `365` | Episode retention period |
|
||||
| `MEM_SEMANTIC_MAX_FACTS_PER_PROJECT` | `50000` | Max facts per project |
|
||||
| `MEM_SEMANTIC_CONFIDENCE_DECAY_DAYS` | `90` | Confidence half-life |
|
||||
| `MEM_EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model |
|
||||
| `MEM_EMBEDDING_DIMENSIONS` | `1536` | Vector dimensions |
|
||||
| `MEM_RETRIEVAL_MIN_SIMILARITY` | `0.5` | Minimum similarity score |
|
||||
| `MEM_CONSOLIDATION_ENABLED` | `true` | Enable auto-consolidation |
|
||||
| `MEM_CONSOLIDATION_SCHEDULE_CRON` | `0 3 * * *` | Nightly schedule |
|
||||
| `MEM_CACHE_ENABLED` | `true` | Enable retrieval caching |
|
||||
| `MEM_CACHE_TTL_SECONDS` | `300` | Cache TTL (5 minutes) |
|
||||
|
||||
See `app/services/memory/config.py` for complete configuration options.
|
||||
|
||||
## Integration with Context Engine
|
||||
|
||||
Memory integrates with the Context Engine as a context source:
|
||||
|
||||
```python
|
||||
from app.services.memory.integration import MemoryContextSource
|
||||
|
||||
# Register as context source
|
||||
source = MemoryContextSource(memory_manager)
|
||||
context_engine.register_source(source)
|
||||
|
||||
# Memory is automatically included in context assembly
|
||||
context = await context_engine.assemble_context(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
current_task="Review authentication code",
|
||||
)
|
||||
```
|
||||
|
||||
## Caching
|
||||
|
||||
Multi-layer caching for performance:
|
||||
|
||||
- **Hot Cache**: Frequently accessed memories (LRU)
|
||||
- **Retrieval Cache**: Query result caching
|
||||
- **Embedding Cache**: Pre-computed embeddings
|
||||
|
||||
```python
|
||||
from app.services.memory.cache import CacheManager
|
||||
|
||||
cache = CacheManager(settings)
|
||||
await cache.warm_hot_cache(project_id) # Pre-warm common memories
|
||||
```
|
||||
|
||||
## Metrics
|
||||
|
||||
Prometheus-compatible metrics:
|
||||
|
||||
| Metric | Type | Labels |
|
||||
|--------|------|--------|
|
||||
| `memory_operations_total` | Counter | operation, memory_type, scope, success |
|
||||
| `memory_retrievals_total` | Counter | memory_type, strategy |
|
||||
| `memory_cache_hits_total` | Counter | cache_type |
|
||||
| `memory_retrieval_latency_seconds` | Histogram | - |
|
||||
| `memory_consolidation_duration_seconds` | Histogram | - |
|
||||
| `memory_items_count` | Gauge | memory_type, scope |
|
||||
|
||||
```python
|
||||
from app.services.memory.metrics import get_memory_metrics
|
||||
|
||||
metrics = await get_memory_metrics()
|
||||
summary = await metrics.get_summary()
|
||||
prometheus_output = await metrics.get_prometheus_format()
|
||||
```
|
||||
|
||||
## Performance Targets
|
||||
|
||||
| Operation | Target P95 |
|
||||
|-----------|------------|
|
||||
| Working memory get/set | < 5ms |
|
||||
| Episodic memory retrieval | < 100ms |
|
||||
| Semantic memory search | < 100ms |
|
||||
| Procedural memory matching | < 50ms |
|
||||
| Consolidation batch (1000 items) | < 30s |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Redis Connection Issues
|
||||
```bash
|
||||
# Check Redis connectivity
|
||||
redis-cli ping
|
||||
|
||||
# Verify memory settings
|
||||
MEM_REDIS_URL=redis://localhost:6379/0
|
||||
```
|
||||
|
||||
### Slow Retrieval
|
||||
1. Check if caching is enabled: `MEM_CACHE_ENABLED=true`
|
||||
2. Verify HNSW indexes exist on vector columns
|
||||
3. Monitor `memory_retrieval_latency_seconds` metric
|
||||
|
||||
### High Memory Usage
|
||||
1. Review `MEM_EPISODIC_MAX_EPISODES_PER_PROJECT` limit
|
||||
2. Ensure pruning is enabled: `MEM_PRUNING_ENABLED=true`
|
||||
3. Check consolidation is running (cron schedule)
|
||||
|
||||
### Embedding Errors
|
||||
1. Verify LLM Gateway is accessible
|
||||
2. Check embedding model is valid
|
||||
3. Review batch size if hitting rate limits
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
app/services/memory/
|
||||
├── __init__.py # Public exports
|
||||
├── config.py # MemorySettings
|
||||
├── exceptions.py # Memory-specific errors
|
||||
├── manager.py # MemoryManager facade
|
||||
├── types.py # Core types
|
||||
├── working/ # Working memory
|
||||
│ ├── memory.py
|
||||
│ └── storage.py
|
||||
├── episodic/ # Episodic memory
|
||||
│ ├── memory.py
|
||||
│ ├── recorder.py
|
||||
│ └── retrieval.py
|
||||
├── semantic/ # Semantic memory
|
||||
│ ├── memory.py
|
||||
│ ├── extraction.py
|
||||
│ └── verification.py
|
||||
├── procedural/ # Procedural memory
|
||||
│ ├── memory.py
|
||||
│ └── matching.py
|
||||
├── scoping/ # Memory scoping
|
||||
│ ├── scope.py
|
||||
│ └── resolver.py
|
||||
├── indexing/ # Indexing & retrieval
|
||||
│ ├── index.py
|
||||
│ └── retrieval.py
|
||||
├── consolidation/ # Memory consolidation
|
||||
│ └── service.py
|
||||
├── reflection/ # Memory reflection
|
||||
│ ├── service.py
|
||||
│ └── types.py
|
||||
├── integration/ # External integrations
|
||||
│ ├── context_source.py
|
||||
│ └── lifecycle.py
|
||||
├── cache/ # Caching layer
|
||||
│ ├── cache_manager.py
|
||||
│ ├── hot_cache.py
|
||||
│ └── embedding_cache.py
|
||||
├── mcp/ # MCP tools
|
||||
│ ├── service.py
|
||||
│ └── tools.py
|
||||
└── metrics/ # Observability
|
||||
└── collector.py
|
||||
```
|
||||
@@ -188,13 +188,14 @@ class TestPasswordResetConfirm:
|
||||
@pytest.mark.asyncio
|
||||
async def test_password_reset_confirm_expired_token(self, client, async_test_user):
|
||||
"""Test password reset confirmation with expired token."""
|
||||
import time as time_module
|
||||
import asyncio
|
||||
|
||||
# Create token that expires immediately
|
||||
token = create_password_reset_token(async_test_user.email, expires_in=1)
|
||||
# Create token that expires at current second (expires_in=0)
|
||||
# Token expires when exp < current_time, so we need to cross a second boundary
|
||||
token = create_password_reset_token(async_test_user.email, expires_in=0)
|
||||
|
||||
# Wait for token to expire
|
||||
time_module.sleep(2)
|
||||
# Wait for token to expire (need to cross second boundary)
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/password-reset/confirm",
|
||||
|
||||
2
backend/tests/models/memory/__init__.py
Normal file
2
backend/tests/models/memory/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/models/memory/__init__.py
|
||||
"""Unit tests for memory database models."""
|
||||
71
backend/tests/models/memory/test_enums.py
Normal file
71
backend/tests/models/memory/test_enums.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# tests/unit/models/memory/test_enums.py
|
||||
"""Unit tests for memory model enums."""
|
||||
|
||||
from app.models.memory.enums import (
|
||||
ConsolidationStatus,
|
||||
ConsolidationType,
|
||||
EpisodeOutcome,
|
||||
ScopeType,
|
||||
)
|
||||
|
||||
|
||||
class TestScopeType:
|
||||
"""Tests for ScopeType enum."""
|
||||
|
||||
def test_all_values_exist(self) -> None:
|
||||
"""Test all expected scope types exist."""
|
||||
assert ScopeType.GLOBAL.value == "global"
|
||||
assert ScopeType.PROJECT.value == "project"
|
||||
assert ScopeType.AGENT_TYPE.value == "agent_type"
|
||||
assert ScopeType.AGENT_INSTANCE.value == "agent_instance"
|
||||
assert ScopeType.SESSION.value == "session"
|
||||
|
||||
def test_scope_count(self) -> None:
|
||||
"""Test we have exactly 5 scope types."""
|
||||
assert len(ScopeType) == 5
|
||||
|
||||
|
||||
class TestEpisodeOutcome:
|
||||
"""Tests for EpisodeOutcome enum."""
|
||||
|
||||
def test_all_values_exist(self) -> None:
|
||||
"""Test all expected outcome values exist."""
|
||||
assert EpisodeOutcome.SUCCESS.value == "success"
|
||||
assert EpisodeOutcome.FAILURE.value == "failure"
|
||||
assert EpisodeOutcome.PARTIAL.value == "partial"
|
||||
|
||||
def test_outcome_count(self) -> None:
|
||||
"""Test we have exactly 3 outcome types."""
|
||||
assert len(EpisodeOutcome) == 3
|
||||
|
||||
|
||||
class TestConsolidationType:
|
||||
"""Tests for ConsolidationType enum."""
|
||||
|
||||
def test_all_values_exist(self) -> None:
|
||||
"""Test all expected consolidation types exist."""
|
||||
assert ConsolidationType.WORKING_TO_EPISODIC.value == "working_to_episodic"
|
||||
assert ConsolidationType.EPISODIC_TO_SEMANTIC.value == "episodic_to_semantic"
|
||||
assert (
|
||||
ConsolidationType.EPISODIC_TO_PROCEDURAL.value == "episodic_to_procedural"
|
||||
)
|
||||
assert ConsolidationType.PRUNING.value == "pruning"
|
||||
|
||||
def test_consolidation_count(self) -> None:
|
||||
"""Test we have exactly 4 consolidation types."""
|
||||
assert len(ConsolidationType) == 4
|
||||
|
||||
|
||||
class TestConsolidationStatus:
|
||||
"""Tests for ConsolidationStatus enum."""
|
||||
|
||||
def test_all_values_exist(self) -> None:
|
||||
"""Test all expected status values exist."""
|
||||
assert ConsolidationStatus.PENDING.value == "pending"
|
||||
assert ConsolidationStatus.RUNNING.value == "running"
|
||||
assert ConsolidationStatus.COMPLETED.value == "completed"
|
||||
assert ConsolidationStatus.FAILED.value == "failed"
|
||||
|
||||
def test_status_count(self) -> None:
|
||||
"""Test we have exactly 4 status types."""
|
||||
assert len(ConsolidationStatus) == 4
|
||||
249
backend/tests/models/memory/test_models.py
Normal file
249
backend/tests/models/memory/test_models.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# tests/unit/models/memory/test_models.py
|
||||
"""Unit tests for memory database models."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.memory import (
|
||||
ConsolidationStatus,
|
||||
ConsolidationType,
|
||||
Episode,
|
||||
EpisodeOutcome,
|
||||
Fact,
|
||||
MemoryConsolidationLog,
|
||||
Procedure,
|
||||
ScopeType,
|
||||
WorkingMemory,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkingMemoryModel:
|
||||
"""Tests for WorkingMemory model."""
|
||||
|
||||
def test_tablename(self) -> None:
|
||||
"""Test table name is correct."""
|
||||
assert WorkingMemory.__tablename__ == "working_memory"
|
||||
|
||||
def test_has_required_columns(self) -> None:
|
||||
"""Test all required columns exist."""
|
||||
columns = WorkingMemory.__table__.columns
|
||||
assert "id" in columns
|
||||
assert "scope_type" in columns
|
||||
assert "scope_id" in columns
|
||||
assert "key" in columns
|
||||
assert "value" in columns
|
||||
assert "expires_at" in columns
|
||||
assert "created_at" in columns
|
||||
assert "updated_at" in columns
|
||||
|
||||
def test_has_unique_constraint(self) -> None:
|
||||
"""Test unique constraint on scope+key."""
|
||||
indexes = {idx.name: idx for idx in WorkingMemory.__table__.indexes}
|
||||
assert "ix_working_memory_scope_key" in indexes
|
||||
assert indexes["ix_working_memory_scope_key"].unique
|
||||
|
||||
|
||||
class TestEpisodeModel:
|
||||
"""Tests for Episode model."""
|
||||
|
||||
def test_tablename(self) -> None:
|
||||
"""Test table name is correct."""
|
||||
assert Episode.__tablename__ == "episodes"
|
||||
|
||||
def test_has_required_columns(self) -> None:
|
||||
"""Test all required columns exist."""
|
||||
columns = Episode.__table__.columns
|
||||
required = [
|
||||
"id",
|
||||
"project_id",
|
||||
"agent_instance_id",
|
||||
"agent_type_id",
|
||||
"session_id",
|
||||
"task_type",
|
||||
"task_description",
|
||||
"actions",
|
||||
"context_summary",
|
||||
"outcome",
|
||||
"outcome_details",
|
||||
"duration_seconds",
|
||||
"tokens_used",
|
||||
"lessons_learned",
|
||||
"importance_score",
|
||||
"embedding",
|
||||
"occurred_at",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
for col in required:
|
||||
assert col in columns, f"Missing column: {col}"
|
||||
|
||||
def test_has_foreign_keys(self) -> None:
|
||||
"""Test foreign key relationships exist."""
|
||||
columns = Episode.__table__.columns
|
||||
assert columns["project_id"].foreign_keys
|
||||
assert columns["agent_instance_id"].foreign_keys
|
||||
assert columns["agent_type_id"].foreign_keys
|
||||
|
||||
def test_has_relationships(self) -> None:
|
||||
"""Test ORM relationships exist."""
|
||||
mapper = Episode.__mapper__
|
||||
assert "project" in mapper.relationships
|
||||
assert "agent_instance" in mapper.relationships
|
||||
assert "agent_type" in mapper.relationships
|
||||
|
||||
|
||||
class TestFactModel:
|
||||
"""Tests for Fact model."""
|
||||
|
||||
def test_tablename(self) -> None:
|
||||
"""Test table name is correct."""
|
||||
assert Fact.__tablename__ == "facts"
|
||||
|
||||
def test_has_required_columns(self) -> None:
|
||||
"""Test all required columns exist."""
|
||||
columns = Fact.__table__.columns
|
||||
required = [
|
||||
"id",
|
||||
"project_id",
|
||||
"subject",
|
||||
"predicate",
|
||||
"object",
|
||||
"confidence",
|
||||
"source_episode_ids",
|
||||
"first_learned",
|
||||
"last_reinforced",
|
||||
"reinforcement_count",
|
||||
"embedding",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
for col in required:
|
||||
assert col in columns, f"Missing column: {col}"
|
||||
|
||||
def test_project_id_nullable(self) -> None:
|
||||
"""Test project_id is nullable for global facts."""
|
||||
columns = Fact.__table__.columns
|
||||
assert columns["project_id"].nullable
|
||||
|
||||
|
||||
class TestProcedureModel:
|
||||
"""Tests for Procedure model."""
|
||||
|
||||
def test_tablename(self) -> None:
|
||||
"""Test table name is correct."""
|
||||
assert Procedure.__tablename__ == "procedures"
|
||||
|
||||
def test_has_required_columns(self) -> None:
|
||||
"""Test all required columns exist."""
|
||||
columns = Procedure.__table__.columns
|
||||
required = [
|
||||
"id",
|
||||
"project_id",
|
||||
"agent_type_id",
|
||||
"name",
|
||||
"trigger_pattern",
|
||||
"steps",
|
||||
"success_count",
|
||||
"failure_count",
|
||||
"last_used",
|
||||
"embedding",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
for col in required:
|
||||
assert col in columns, f"Missing column: {col}"
|
||||
|
||||
def test_success_rate_property(self) -> None:
|
||||
"""Test success_rate calculated property."""
|
||||
proc = Procedure()
|
||||
proc.success_count = 8
|
||||
proc.failure_count = 2
|
||||
assert proc.success_rate == 0.8
|
||||
|
||||
def test_success_rate_zero_total(self) -> None:
|
||||
"""Test success_rate with zero total uses."""
|
||||
proc = Procedure()
|
||||
proc.success_count = 0
|
||||
proc.failure_count = 0
|
||||
assert proc.success_rate == 0.0
|
||||
|
||||
def test_total_uses_property(self) -> None:
|
||||
"""Test total_uses calculated property."""
|
||||
proc = Procedure()
|
||||
proc.success_count = 5
|
||||
proc.failure_count = 3
|
||||
assert proc.total_uses == 8
|
||||
|
||||
|
||||
class TestMemoryConsolidationLogModel:
|
||||
"""Tests for MemoryConsolidationLog model."""
|
||||
|
||||
def test_tablename(self) -> None:
|
||||
"""Test table name is correct."""
|
||||
assert MemoryConsolidationLog.__tablename__ == "memory_consolidation_log"
|
||||
|
||||
def test_has_required_columns(self) -> None:
|
||||
"""Test all required columns exist."""
|
||||
columns = MemoryConsolidationLog.__table__.columns
|
||||
required = [
|
||||
"id",
|
||||
"consolidation_type",
|
||||
"source_count",
|
||||
"result_count",
|
||||
"started_at",
|
||||
"completed_at",
|
||||
"status",
|
||||
"error",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
for col in required:
|
||||
assert col in columns, f"Missing column: {col}"
|
||||
|
||||
def test_duration_seconds_property_completed(self) -> None:
|
||||
"""Test duration_seconds with completed job."""
|
||||
log = MemoryConsolidationLog()
|
||||
log.started_at = datetime.now(UTC)
|
||||
log.completed_at = log.started_at + timedelta(seconds=10)
|
||||
assert log.duration_seconds == pytest.approx(10.0)
|
||||
|
||||
def test_duration_seconds_property_incomplete(self) -> None:
|
||||
"""Test duration_seconds with incomplete job."""
|
||||
log = MemoryConsolidationLog()
|
||||
log.started_at = datetime.now(UTC)
|
||||
log.completed_at = None
|
||||
assert log.duration_seconds is None
|
||||
|
||||
def test_default_status(self) -> None:
|
||||
"""Test default status is PENDING."""
|
||||
columns = MemoryConsolidationLog.__table__.columns
|
||||
assert columns["status"].default.arg == ConsolidationStatus.PENDING
|
||||
|
||||
|
||||
class TestModelExports:
|
||||
"""Tests for model package exports."""
|
||||
|
||||
def test_all_models_exported(self) -> None:
|
||||
"""Test all models are exported from package."""
|
||||
from app.models.memory import (
|
||||
Episode,
|
||||
Fact,
|
||||
MemoryConsolidationLog,
|
||||
Procedure,
|
||||
WorkingMemory,
|
||||
)
|
||||
|
||||
# Verify these are the actual classes
|
||||
assert Episode.__tablename__ == "episodes"
|
||||
assert Fact.__tablename__ == "facts"
|
||||
assert Procedure.__tablename__ == "procedures"
|
||||
assert WorkingMemory.__tablename__ == "working_memory"
|
||||
assert MemoryConsolidationLog.__tablename__ == "memory_consolidation_log"
|
||||
|
||||
def test_enums_exported(self) -> None:
|
||||
"""Test all enums are exported."""
|
||||
assert ScopeType.GLOBAL.value == "global"
|
||||
assert EpisodeOutcome.SUCCESS.value == "success"
|
||||
assert ConsolidationType.WORKING_TO_EPISODIC.value == "working_to_episodic"
|
||||
assert ConsolidationStatus.PENDING.value == "pending"
|
||||
@@ -304,10 +304,18 @@ class TestTaskModuleExports:
|
||||
assert hasattr(tasks, "sync")
|
||||
assert hasattr(tasks, "workflow")
|
||||
assert hasattr(tasks, "cost")
|
||||
assert hasattr(tasks, "memory_consolidation")
|
||||
|
||||
def test_tasks_all_attribute_is_correct(self):
|
||||
"""Test that __all__ contains all expected module names."""
|
||||
from app import tasks
|
||||
|
||||
expected_modules = ["agent", "git", "sync", "workflow", "cost"]
|
||||
expected_modules = [
|
||||
"agent",
|
||||
"git",
|
||||
"sync",
|
||||
"workflow",
|
||||
"cost",
|
||||
"memory_consolidation",
|
||||
]
|
||||
assert set(tasks.__all__) == set(expected_modules)
|
||||
|
||||
2
backend/tests/unit/models/__init__.py
Normal file
2
backend/tests/unit/models/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/models/__init__.py
|
||||
"""Unit tests for database models."""
|
||||
260
backend/tests/unit/services/context/types/test_memory.py
Normal file
260
backend/tests/unit/services/context/types/test_memory.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# tests/unit/services/context/types/test_memory.py
|
||||
"""Tests for MemoryContext type."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from app.services.context.types import ContextType
|
||||
from app.services.context.types.memory import MemoryContext, MemorySubtype
|
||||
|
||||
|
||||
class TestMemorySubtype:
|
||||
"""Tests for MemorySubtype enum."""
|
||||
|
||||
def test_all_types_defined(self) -> None:
|
||||
"""All memory subtypes should be defined."""
|
||||
assert MemorySubtype.WORKING == "working"
|
||||
assert MemorySubtype.EPISODIC == "episodic"
|
||||
assert MemorySubtype.SEMANTIC == "semantic"
|
||||
assert MemorySubtype.PROCEDURAL == "procedural"
|
||||
|
||||
def test_enum_values(self) -> None:
|
||||
"""Enum values should match strings."""
|
||||
assert MemorySubtype.WORKING.value == "working"
|
||||
assert MemorySubtype("episodic") == MemorySubtype.EPISODIC
|
||||
|
||||
|
||||
class TestMemoryContext:
|
||||
"""Tests for MemoryContext class."""
|
||||
|
||||
def test_get_type_returns_memory(self) -> None:
|
||||
"""get_type should return MEMORY."""
|
||||
ctx = MemoryContext(content="test", source="test_source")
|
||||
assert ctx.get_type() == ContextType.MEMORY
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Default values should be set correctly."""
|
||||
ctx = MemoryContext(content="test", source="test_source")
|
||||
assert ctx.memory_subtype == MemorySubtype.EPISODIC
|
||||
assert ctx.memory_id is None
|
||||
assert ctx.relevance_score == 0.0
|
||||
assert ctx.importance == 0.5
|
||||
|
||||
def test_to_dict_includes_memory_fields(self) -> None:
|
||||
"""to_dict should include memory-specific fields."""
|
||||
ctx = MemoryContext(
|
||||
content="test content",
|
||||
source="test_source",
|
||||
memory_subtype=MemorySubtype.SEMANTIC,
|
||||
memory_id="mem-123",
|
||||
relevance_score=0.8,
|
||||
subject="User",
|
||||
predicate="prefers",
|
||||
object_value="dark mode",
|
||||
)
|
||||
|
||||
data = ctx.to_dict()
|
||||
|
||||
assert data["memory_subtype"] == "semantic"
|
||||
assert data["memory_id"] == "mem-123"
|
||||
assert data["relevance_score"] == 0.8
|
||||
assert data["subject"] == "User"
|
||||
assert data["predicate"] == "prefers"
|
||||
assert data["object_value"] == "dark mode"
|
||||
|
||||
def test_from_dict(self) -> None:
|
||||
"""from_dict should create correct MemoryContext."""
|
||||
data = {
|
||||
"content": "test content",
|
||||
"source": "test_source",
|
||||
"timestamp": "2024-01-01T00:00:00+00:00",
|
||||
"memory_subtype": "semantic",
|
||||
"memory_id": "mem-123",
|
||||
"relevance_score": 0.8,
|
||||
"subject": "Test",
|
||||
}
|
||||
|
||||
ctx = MemoryContext.from_dict(data)
|
||||
|
||||
assert ctx.content == "test content"
|
||||
assert ctx.memory_subtype == MemorySubtype.SEMANTIC
|
||||
assert ctx.memory_id == "mem-123"
|
||||
assert ctx.subject == "Test"
|
||||
|
||||
|
||||
class TestMemoryContextFromWorkingMemory:
|
||||
"""Tests for MemoryContext.from_working_memory."""
|
||||
|
||||
def test_creates_working_memory_context(self) -> None:
|
||||
"""Should create working memory context from key/value."""
|
||||
ctx = MemoryContext.from_working_memory(
|
||||
key="user_preferences",
|
||||
value={"theme": "dark"},
|
||||
source="working:sess-123",
|
||||
query="preferences",
|
||||
)
|
||||
|
||||
assert ctx.memory_subtype == MemorySubtype.WORKING
|
||||
assert ctx.key == "user_preferences"
|
||||
assert "{'theme': 'dark'}" in ctx.content
|
||||
assert ctx.relevance_score == 1.0 # Working memory is always relevant
|
||||
assert ctx.importance == 0.8 # Higher importance
|
||||
|
||||
def test_string_value(self) -> None:
|
||||
"""Should handle string values."""
|
||||
ctx = MemoryContext.from_working_memory(
|
||||
key="current_task",
|
||||
value="Build authentication",
|
||||
)
|
||||
|
||||
assert ctx.content == "Build authentication"
|
||||
|
||||
|
||||
class TestMemoryContextFromEpisodicMemory:
|
||||
"""Tests for MemoryContext.from_episodic_memory."""
|
||||
|
||||
def test_creates_episodic_memory_context(self) -> None:
|
||||
"""Should create episodic memory context from episode."""
|
||||
episode = MagicMock()
|
||||
episode.id = uuid4()
|
||||
episode.task_description = "Implemented login feature"
|
||||
episode.task_type = "feature_implementation"
|
||||
episode.outcome = MagicMock(value="success")
|
||||
episode.importance_score = 0.9
|
||||
episode.session_id = "sess-123"
|
||||
episode.occurred_at = datetime.now(UTC)
|
||||
episode.lessons_learned = ["Use proper validation"]
|
||||
|
||||
ctx = MemoryContext.from_episodic_memory(episode, query="login")
|
||||
|
||||
assert ctx.memory_subtype == MemorySubtype.EPISODIC
|
||||
assert ctx.memory_id == str(episode.id)
|
||||
assert ctx.content == "Implemented login feature"
|
||||
assert ctx.task_type == "feature_implementation"
|
||||
assert ctx.outcome == "success"
|
||||
assert ctx.importance == 0.9
|
||||
|
||||
def test_handles_missing_outcome(self) -> None:
|
||||
"""Should handle episodes with no outcome."""
|
||||
episode = MagicMock()
|
||||
episode.id = uuid4()
|
||||
episode.task_description = "WIP task"
|
||||
episode.outcome = None
|
||||
episode.importance_score = 0.5
|
||||
episode.occurred_at = None
|
||||
|
||||
ctx = MemoryContext.from_episodic_memory(episode)
|
||||
|
||||
assert ctx.outcome is None
|
||||
|
||||
|
||||
class TestMemoryContextFromSemanticMemory:
|
||||
"""Tests for MemoryContext.from_semantic_memory."""
|
||||
|
||||
def test_creates_semantic_memory_context(self) -> None:
|
||||
"""Should create semantic memory context from fact."""
|
||||
fact = MagicMock()
|
||||
fact.id = uuid4()
|
||||
fact.subject = "User"
|
||||
fact.predicate = "prefers"
|
||||
fact.object = "dark mode"
|
||||
fact.confidence = 0.95
|
||||
|
||||
ctx = MemoryContext.from_semantic_memory(fact, query="user preferences")
|
||||
|
||||
assert ctx.memory_subtype == MemorySubtype.SEMANTIC
|
||||
assert ctx.memory_id == str(fact.id)
|
||||
assert ctx.content == "User prefers dark mode"
|
||||
assert ctx.subject == "User"
|
||||
assert ctx.predicate == "prefers"
|
||||
assert ctx.object_value == "dark mode"
|
||||
assert ctx.relevance_score == 0.95
|
||||
|
||||
|
||||
class TestMemoryContextFromProceduralMemory:
|
||||
"""Tests for MemoryContext.from_procedural_memory."""
|
||||
|
||||
def test_creates_procedural_memory_context(self) -> None:
|
||||
"""Should create procedural memory context from procedure."""
|
||||
procedure = MagicMock()
|
||||
procedure.id = uuid4()
|
||||
procedure.name = "Deploy to Production"
|
||||
procedure.trigger_pattern = "When deploying to production"
|
||||
procedure.steps = [
|
||||
{"action": "run_tests"},
|
||||
{"action": "build_docker"},
|
||||
{"action": "deploy"},
|
||||
]
|
||||
procedure.success_rate = 0.85
|
||||
procedure.success_count = 10
|
||||
procedure.failure_count = 2
|
||||
|
||||
ctx = MemoryContext.from_procedural_memory(procedure, query="deploy")
|
||||
|
||||
assert ctx.memory_subtype == MemorySubtype.PROCEDURAL
|
||||
assert ctx.memory_id == str(procedure.id)
|
||||
assert "Deploy to Production" in ctx.content
|
||||
assert "When deploying to production" in ctx.content
|
||||
assert ctx.trigger == "When deploying to production"
|
||||
assert ctx.success_rate == 0.85
|
||||
assert ctx.metadata["steps_count"] == 3
|
||||
assert ctx.metadata["execution_count"] == 12
|
||||
|
||||
|
||||
class TestMemoryContextHelpers:
|
||||
"""Tests for MemoryContext helper methods."""
|
||||
|
||||
def test_is_working_memory(self) -> None:
|
||||
"""is_working_memory should return True for working memory."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="test",
|
||||
memory_subtype=MemorySubtype.WORKING,
|
||||
)
|
||||
assert ctx.is_working_memory() is True
|
||||
assert ctx.is_episodic_memory() is False
|
||||
|
||||
def test_is_episodic_memory(self) -> None:
|
||||
"""is_episodic_memory should return True for episodic memory."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="test",
|
||||
memory_subtype=MemorySubtype.EPISODIC,
|
||||
)
|
||||
assert ctx.is_episodic_memory() is True
|
||||
assert ctx.is_semantic_memory() is False
|
||||
|
||||
def test_is_semantic_memory(self) -> None:
|
||||
"""is_semantic_memory should return True for semantic memory."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="test",
|
||||
memory_subtype=MemorySubtype.SEMANTIC,
|
||||
)
|
||||
assert ctx.is_semantic_memory() is True
|
||||
assert ctx.is_procedural_memory() is False
|
||||
|
||||
def test_is_procedural_memory(self) -> None:
|
||||
"""is_procedural_memory should return True for procedural memory."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="test",
|
||||
memory_subtype=MemorySubtype.PROCEDURAL,
|
||||
)
|
||||
assert ctx.is_procedural_memory() is True
|
||||
assert ctx.is_working_memory() is False
|
||||
|
||||
def test_get_formatted_source(self) -> None:
|
||||
"""get_formatted_source should return formatted string."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="episodic:12345678-1234-1234-1234-123456789012",
|
||||
memory_subtype=MemorySubtype.EPISODIC,
|
||||
memory_id="12345678-1234-1234-1234-123456789012",
|
||||
)
|
||||
|
||||
formatted = ctx.get_formatted_source()
|
||||
|
||||
assert "[episodic]" in formatted
|
||||
assert "12345678..." in formatted
|
||||
1
backend/tests/unit/services/memory/__init__.py
Normal file
1
backend/tests/unit/services/memory/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for the Agent Memory System."""
|
||||
2
backend/tests/unit/services/memory/cache/__init__.py
vendored
Normal file
2
backend/tests/unit/services/memory/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/cache/__init__.py
|
||||
"""Tests for memory caching layer."""
|
||||
331
backend/tests/unit/services/memory/cache/test_cache_manager.py
vendored
Normal file
331
backend/tests/unit/services/memory/cache/test_cache_manager.py
vendored
Normal file
@@ -0,0 +1,331 @@
|
||||
# tests/unit/services/memory/cache/test_cache_manager.py
|
||||
"""Tests for CacheManager."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.cache.cache_manager import (
|
||||
CacheManager,
|
||||
CacheStats,
|
||||
get_cache_manager,
|
||||
reset_cache_manager,
|
||||
)
|
||||
from app.services.memory.cache.embedding_cache import EmbeddingCache
|
||||
from app.services.memory.cache.hot_cache import HotMemoryCache
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton() -> None:
|
||||
"""Reset singleton before each test."""
|
||||
reset_cache_manager()
|
||||
|
||||
|
||||
class TestCacheStats:
|
||||
"""Tests for CacheStats."""
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Should convert to dictionary."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
stats = CacheStats(
|
||||
hot_cache={"hits": 10},
|
||||
embedding_cache={"hits": 20},
|
||||
overall_hit_rate=0.75,
|
||||
last_cleanup=datetime.now(UTC),
|
||||
cleanup_count=5,
|
||||
)
|
||||
|
||||
result = stats.to_dict()
|
||||
|
||||
assert result["hot_cache"] == {"hits": 10}
|
||||
assert result["overall_hit_rate"] == 0.75
|
||||
assert result["cleanup_count"] == 5
|
||||
assert result["last_cleanup"] is not None
|
||||
|
||||
|
||||
class TestCacheManager:
|
||||
"""Tests for CacheManager."""
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self) -> CacheManager:
|
||||
"""Create a cache manager."""
|
||||
return CacheManager()
|
||||
|
||||
def test_is_enabled(self, manager: CacheManager) -> None:
|
||||
"""Should check if caching is enabled."""
|
||||
# Default is enabled from settings
|
||||
assert manager.is_enabled is True
|
||||
|
||||
def test_has_hot_cache(self, manager: CacheManager) -> None:
|
||||
"""Should have hot memory cache."""
|
||||
assert manager.hot_cache is not None
|
||||
assert isinstance(manager.hot_cache, HotMemoryCache)
|
||||
|
||||
def test_has_embedding_cache(self, manager: CacheManager) -> None:
|
||||
"""Should have embedding cache."""
|
||||
assert manager.embedding_cache is not None
|
||||
assert isinstance(manager.embedding_cache, EmbeddingCache)
|
||||
|
||||
def test_cache_memory(self, manager: CacheManager) -> None:
|
||||
"""Should cache memory in hot cache."""
|
||||
memory_id = uuid4()
|
||||
memory = {"task": "test", "data": "value"}
|
||||
|
||||
manager.cache_memory("episodic", memory_id, memory)
|
||||
result = manager.get_memory("episodic", memory_id)
|
||||
|
||||
assert result == memory
|
||||
|
||||
def test_cache_memory_with_scope(self, manager: CacheManager) -> None:
|
||||
"""Should cache memory with scope."""
|
||||
memory_id = uuid4()
|
||||
memory = {"task": "test"}
|
||||
|
||||
manager.cache_memory("semantic", memory_id, memory, scope="proj-123")
|
||||
result = manager.get_memory("semantic", memory_id, scope="proj-123")
|
||||
|
||||
assert result == memory
|
||||
|
||||
async def test_cache_embedding(self, manager: CacheManager) -> None:
|
||||
"""Should cache embedding."""
|
||||
content = "test content"
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
content_hash = await manager.cache_embedding(content, embedding)
|
||||
result = await manager.get_embedding(content)
|
||||
|
||||
assert result == embedding
|
||||
assert len(content_hash) == 32
|
||||
|
||||
async def test_invalidate_memory(self, manager: CacheManager) -> None:
|
||||
"""Should invalidate memory from hot cache."""
|
||||
memory_id = uuid4()
|
||||
manager.cache_memory("episodic", memory_id, {"data": "test"})
|
||||
|
||||
count = await manager.invalidate_memory("episodic", memory_id)
|
||||
|
||||
assert count >= 1
|
||||
assert manager.get_memory("episodic", memory_id) is None
|
||||
|
||||
async def test_invalidate_by_type(self, manager: CacheManager) -> None:
|
||||
"""Should invalidate all entries of a type."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "1"})
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "2"})
|
||||
manager.cache_memory("semantic", uuid4(), {"data": "3"})
|
||||
|
||||
count = await manager.invalidate_by_type("episodic")
|
||||
|
||||
assert count >= 2
|
||||
|
||||
async def test_invalidate_by_scope(self, manager: CacheManager) -> None:
|
||||
"""Should invalidate all entries in a scope."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "1"}, scope="proj-1")
|
||||
manager.cache_memory("semantic", uuid4(), {"data": "2"}, scope="proj-1")
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "3"}, scope="proj-2")
|
||||
|
||||
count = await manager.invalidate_by_scope("proj-1")
|
||||
|
||||
assert count >= 2
|
||||
|
||||
async def test_invalidate_embedding(self, manager: CacheManager) -> None:
|
||||
"""Should invalidate cached embedding."""
|
||||
content = "test content"
|
||||
await manager.cache_embedding(content, [0.1, 0.2])
|
||||
|
||||
result = await manager.invalidate_embedding(content)
|
||||
|
||||
assert result is True
|
||||
assert await manager.get_embedding(content) is None
|
||||
|
||||
async def test_clear_all(self, manager: CacheManager) -> None:
|
||||
"""Should clear all caches."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "test"})
|
||||
await manager.cache_embedding("content", [0.1])
|
||||
|
||||
count = await manager.clear_all()
|
||||
|
||||
assert count >= 2
|
||||
|
||||
async def test_cleanup_expired(self, manager: CacheManager) -> None:
|
||||
"""Should clean up expired entries."""
|
||||
count = await manager.cleanup_expired()
|
||||
|
||||
# May be 0 if no expired entries
|
||||
assert count >= 0
|
||||
assert manager._cleanup_count == 1
|
||||
assert manager._last_cleanup is not None
|
||||
|
||||
def test_get_stats(self, manager: CacheManager) -> None:
|
||||
"""Should return aggregated statistics."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "test"})
|
||||
|
||||
stats = manager.get_stats()
|
||||
|
||||
assert "hot_cache" in stats.to_dict()
|
||||
assert "embedding_cache" in stats.to_dict()
|
||||
assert "overall_hit_rate" in stats.to_dict()
|
||||
|
||||
def test_get_hot_memories(self, manager: CacheManager) -> None:
|
||||
"""Should return most accessed memories."""
|
||||
id1 = uuid4()
|
||||
id2 = uuid4()
|
||||
|
||||
manager.cache_memory("episodic", id1, {"data": "1"})
|
||||
manager.cache_memory("episodic", id2, {"data": "2"})
|
||||
|
||||
# Access first multiple times
|
||||
for _ in range(5):
|
||||
manager.get_memory("episodic", id1)
|
||||
|
||||
hot = manager.get_hot_memories(limit=2)
|
||||
|
||||
assert len(hot) == 2
|
||||
|
||||
def test_reset_stats(self, manager: CacheManager) -> None:
|
||||
"""Should reset all statistics."""
|
||||
manager.cache_memory("episodic", uuid4(), {"data": "test"})
|
||||
manager.get_memory("episodic", uuid4()) # Miss
|
||||
|
||||
manager.reset_stats()
|
||||
|
||||
stats = manager.get_stats()
|
||||
assert stats.hot_cache.get("hits", 0) == 0
|
||||
|
||||
async def test_warmup(self, manager: CacheManager) -> None:
|
||||
"""Should warm up cache with memories."""
|
||||
memories = [
|
||||
("episodic", uuid4(), {"data": "1"}),
|
||||
("episodic", uuid4(), {"data": "2"}),
|
||||
("semantic", uuid4(), {"data": "3"}),
|
||||
]
|
||||
|
||||
count = await manager.warmup(memories)
|
||||
|
||||
assert count == 3
|
||||
|
||||
|
||||
class TestCacheManagerWithRetrieval:
|
||||
"""Tests for CacheManager with retrieval cache."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_retrieval_cache(self) -> MagicMock:
|
||||
"""Create mock retrieval cache."""
|
||||
cache = MagicMock()
|
||||
cache.invalidate_by_memory = MagicMock(return_value=1)
|
||||
cache.clear = MagicMock(return_value=5)
|
||||
cache.get_stats = MagicMock(return_value={"entries": 10})
|
||||
return cache
|
||||
|
||||
@pytest.fixture
|
||||
def manager_with_retrieval(
|
||||
self,
|
||||
mock_retrieval_cache: MagicMock,
|
||||
) -> CacheManager:
|
||||
"""Create manager with retrieval cache."""
|
||||
manager = CacheManager()
|
||||
manager.set_retrieval_cache(mock_retrieval_cache)
|
||||
return manager
|
||||
|
||||
async def test_invalidate_clears_retrieval(
|
||||
self,
|
||||
manager_with_retrieval: CacheManager,
|
||||
mock_retrieval_cache: MagicMock,
|
||||
) -> None:
|
||||
"""Should invalidate retrieval cache entries."""
|
||||
memory_id = uuid4()
|
||||
|
||||
await manager_with_retrieval.invalidate_memory("episodic", memory_id)
|
||||
|
||||
mock_retrieval_cache.invalidate_by_memory.assert_called_once_with(memory_id)
|
||||
|
||||
def test_stats_includes_retrieval(
|
||||
self,
|
||||
manager_with_retrieval: CacheManager,
|
||||
) -> None:
|
||||
"""Should include retrieval cache stats."""
|
||||
stats = manager_with_retrieval.get_stats()
|
||||
|
||||
assert "retrieval_cache" in stats.to_dict()
|
||||
|
||||
|
||||
class TestCacheManagerDisabled:
|
||||
"""Tests for CacheManager when disabled."""
|
||||
|
||||
@pytest.fixture
|
||||
def disabled_manager(self) -> CacheManager:
|
||||
"""Create a disabled cache manager."""
|
||||
with patch(
|
||||
"app.services.memory.cache.cache_manager.get_memory_settings"
|
||||
) as mock_settings:
|
||||
settings = MagicMock()
|
||||
settings.cache_enabled = False
|
||||
settings.cache_max_items = 1000
|
||||
settings.cache_ttl_seconds = 300
|
||||
mock_settings.return_value = settings
|
||||
|
||||
return CacheManager()
|
||||
|
||||
def test_get_memory_returns_none(self, disabled_manager: CacheManager) -> None:
|
||||
"""Should return None when disabled."""
|
||||
disabled_manager.cache_memory("episodic", uuid4(), {"data": "test"})
|
||||
result = disabled_manager.get_memory("episodic", uuid4())
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_embedding_returns_none(
|
||||
self,
|
||||
disabled_manager: CacheManager,
|
||||
) -> None:
|
||||
"""Should return None for embeddings when disabled."""
|
||||
result = await disabled_manager.get_embedding("content")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_warmup_returns_zero(self, disabled_manager: CacheManager) -> None:
|
||||
"""Should return 0 from warmup when disabled."""
|
||||
count = await disabled_manager.warmup([("episodic", uuid4(), {})])
|
||||
|
||||
assert count == 0
|
||||
|
||||
|
||||
class TestGetCacheManager:
|
||||
"""Tests for get_cache_manager factory."""
|
||||
|
||||
def test_returns_singleton(self) -> None:
|
||||
"""Should return same instance."""
|
||||
manager1 = get_cache_manager()
|
||||
manager2 = get_cache_manager()
|
||||
|
||||
assert manager1 is manager2
|
||||
|
||||
def test_reset_creates_new(self) -> None:
|
||||
"""Should create new instance after reset."""
|
||||
manager1 = get_cache_manager()
|
||||
reset_cache_manager()
|
||||
manager2 = get_cache_manager()
|
||||
|
||||
assert manager1 is not manager2
|
||||
|
||||
def test_reset_parameter(self) -> None:
|
||||
"""Should create new instance with reset=True."""
|
||||
manager1 = get_cache_manager()
|
||||
manager2 = get_cache_manager(reset=True)
|
||||
|
||||
assert manager1 is not manager2
|
||||
|
||||
|
||||
class TestResetCacheManager:
|
||||
"""Tests for reset_cache_manager."""
|
||||
|
||||
def test_resets_singleton(self) -> None:
|
||||
"""Should reset the singleton."""
|
||||
get_cache_manager()
|
||||
reset_cache_manager()
|
||||
|
||||
# Next call should create new instance
|
||||
manager = get_cache_manager()
|
||||
assert manager is not None
|
||||
391
backend/tests/unit/services/memory/cache/test_embedding_cache.py
vendored
Normal file
391
backend/tests/unit/services/memory/cache/test_embedding_cache.py
vendored
Normal file
@@ -0,0 +1,391 @@
|
||||
# tests/unit/services/memory/cache/test_embedding_cache.py
|
||||
"""Tests for EmbeddingCache."""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.cache.embedding_cache import (
|
||||
CachedEmbeddingGenerator,
|
||||
EmbeddingCache,
|
||||
EmbeddingCacheStats,
|
||||
EmbeddingEntry,
|
||||
create_embedding_cache,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
class TestEmbeddingEntry:
|
||||
"""Tests for EmbeddingEntry."""
|
||||
|
||||
def test_creates_entry(self) -> None:
|
||||
"""Should create entry with embedding."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
entry = EmbeddingEntry(
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
content_hash="abc123",
|
||||
model="text-embedding-3-small",
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
assert entry.embedding == [0.1, 0.2, 0.3]
|
||||
assert entry.content_hash == "abc123"
|
||||
assert entry.ttl_seconds == 3600.0
|
||||
|
||||
def test_is_expired(self) -> None:
|
||||
"""Should detect expired entries."""
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
old_time = datetime.now(UTC) - timedelta(seconds=4000)
|
||||
entry = EmbeddingEntry(
|
||||
embedding=[0.1],
|
||||
content_hash="abc",
|
||||
model="default",
|
||||
created_at=old_time,
|
||||
ttl_seconds=3600.0,
|
||||
)
|
||||
|
||||
assert entry.is_expired() is True
|
||||
|
||||
def test_not_expired(self) -> None:
|
||||
"""Should detect non-expired entries."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
entry = EmbeddingEntry(
|
||||
embedding=[0.1],
|
||||
content_hash="abc",
|
||||
model="default",
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
assert entry.is_expired() is False
|
||||
|
||||
|
||||
class TestEmbeddingCacheStats:
|
||||
"""Tests for EmbeddingCacheStats."""
|
||||
|
||||
def test_hit_rate_calculation(self) -> None:
|
||||
"""Should calculate hit rate correctly."""
|
||||
stats = EmbeddingCacheStats(hits=90, misses=10)
|
||||
|
||||
assert stats.hit_rate == 0.9
|
||||
|
||||
def test_hit_rate_zero_requests(self) -> None:
|
||||
"""Should return 0 for no requests."""
|
||||
stats = EmbeddingCacheStats()
|
||||
|
||||
assert stats.hit_rate == 0.0
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Should convert to dictionary."""
|
||||
stats = EmbeddingCacheStats(hits=10, misses=5, bytes_saved=1000)
|
||||
|
||||
result = stats.to_dict()
|
||||
|
||||
assert result["hits"] == 10
|
||||
assert result["bytes_saved"] == 1000
|
||||
|
||||
|
||||
class TestEmbeddingCache:
|
||||
"""Tests for EmbeddingCache."""
|
||||
|
||||
@pytest.fixture
|
||||
def cache(self) -> EmbeddingCache:
|
||||
"""Create an embedding cache."""
|
||||
return EmbeddingCache(max_size=100, default_ttl_seconds=300.0)
|
||||
|
||||
async def test_put_and_get(self, cache: EmbeddingCache) -> None:
|
||||
"""Should store and retrieve embeddings."""
|
||||
content = "Hello world"
|
||||
embedding = [0.1, 0.2, 0.3, 0.4]
|
||||
|
||||
content_hash = await cache.put(content, embedding)
|
||||
result = await cache.get(content)
|
||||
|
||||
assert result == embedding
|
||||
assert len(content_hash) == 32
|
||||
|
||||
async def test_get_missing(self, cache: EmbeddingCache) -> None:
|
||||
"""Should return None for missing content."""
|
||||
result = await cache.get("nonexistent content")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_get_by_hash(self, cache: EmbeddingCache) -> None:
|
||||
"""Should get by content hash."""
|
||||
content = "Test content"
|
||||
embedding = [0.1, 0.2]
|
||||
|
||||
content_hash = await cache.put(content, embedding)
|
||||
result = await cache.get_by_hash(content_hash)
|
||||
|
||||
assert result == embedding
|
||||
|
||||
async def test_model_separation(self, cache: EmbeddingCache) -> None:
|
||||
"""Should separate embeddings by model."""
|
||||
content = "Same content"
|
||||
emb1 = [0.1, 0.2]
|
||||
emb2 = [0.3, 0.4]
|
||||
|
||||
await cache.put(content, emb1, model="model-a")
|
||||
await cache.put(content, emb2, model="model-b")
|
||||
|
||||
result1 = await cache.get(content, model="model-a")
|
||||
result2 = await cache.get(content, model="model-b")
|
||||
|
||||
assert result1 == emb1
|
||||
assert result2 == emb2
|
||||
|
||||
async def test_lru_eviction(self) -> None:
|
||||
"""Should evict LRU entries when at capacity."""
|
||||
cache = EmbeddingCache(max_size=3)
|
||||
|
||||
await cache.put("content1", [0.1])
|
||||
await cache.put("content2", [0.2])
|
||||
await cache.put("content3", [0.3])
|
||||
|
||||
# Access first to make it recent
|
||||
await cache.get("content1")
|
||||
|
||||
# Add fourth, should evict second (LRU)
|
||||
await cache.put("content4", [0.4])
|
||||
|
||||
assert await cache.get("content1") is not None
|
||||
assert await cache.get("content2") is None # Evicted
|
||||
assert await cache.get("content3") is not None
|
||||
assert await cache.get("content4") is not None
|
||||
|
||||
async def test_ttl_expiration(self) -> None:
|
||||
"""Should expire entries after TTL."""
|
||||
cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.05)
|
||||
|
||||
await cache.put("content", [0.1, 0.2])
|
||||
|
||||
time.sleep(0.06)
|
||||
|
||||
result = await cache.get("content")
|
||||
|
||||
assert result is None
|
||||
|
||||
async def test_put_batch(self, cache: EmbeddingCache) -> None:
|
||||
"""Should cache multiple embeddings."""
|
||||
items = [
|
||||
("content1", [0.1]),
|
||||
("content2", [0.2]),
|
||||
("content3", [0.3]),
|
||||
]
|
||||
|
||||
hashes = await cache.put_batch(items)
|
||||
|
||||
assert len(hashes) == 3
|
||||
assert await cache.get("content1") == [0.1]
|
||||
assert await cache.get("content2") == [0.2]
|
||||
|
||||
async def test_invalidate(self, cache: EmbeddingCache) -> None:
|
||||
"""Should invalidate cached embedding."""
|
||||
await cache.put("content", [0.1, 0.2])
|
||||
|
||||
result = await cache.invalidate("content")
|
||||
|
||||
assert result is True
|
||||
assert await cache.get("content") is None
|
||||
|
||||
async def test_invalidate_by_hash(self, cache: EmbeddingCache) -> None:
|
||||
"""Should invalidate by hash."""
|
||||
content_hash = await cache.put("content", [0.1, 0.2])
|
||||
|
||||
result = await cache.invalidate_by_hash(content_hash)
|
||||
|
||||
assert result is True
|
||||
assert await cache.get("content") is None
|
||||
|
||||
async def test_invalidate_by_model(self, cache: EmbeddingCache) -> None:
|
||||
"""Should invalidate all embeddings for a model."""
|
||||
await cache.put("content1", [0.1], model="model-a")
|
||||
await cache.put("content2", [0.2], model="model-a")
|
||||
await cache.put("content3", [0.3], model="model-b")
|
||||
|
||||
count = await cache.invalidate_by_model("model-a")
|
||||
|
||||
assert count == 2
|
||||
assert await cache.get("content1", model="model-a") is None
|
||||
assert await cache.get("content3", model="model-b") is not None
|
||||
|
||||
async def test_clear(self, cache: EmbeddingCache) -> None:
|
||||
"""Should clear all entries."""
|
||||
await cache.put("content1", [0.1])
|
||||
await cache.put("content2", [0.2])
|
||||
|
||||
count = await cache.clear()
|
||||
|
||||
assert count == 2
|
||||
assert cache.size == 0
|
||||
|
||||
def test_cleanup_expired(self) -> None:
|
||||
"""Should remove expired entries."""
|
||||
cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.05)
|
||||
|
||||
# Use synchronous put for setup
|
||||
cache._put_memory("hash1", "default", [0.1])
|
||||
cache._put_memory("hash2", "default", [0.2], ttl_seconds=10)
|
||||
|
||||
time.sleep(0.06)
|
||||
|
||||
count = cache.cleanup_expired()
|
||||
|
||||
assert count == 1
|
||||
|
||||
def test_get_stats(self, cache: EmbeddingCache) -> None:
|
||||
"""Should return accurate statistics."""
|
||||
# Put synchronously for setup
|
||||
cache._put_memory("hash1", "default", [0.1])
|
||||
|
||||
stats = cache.get_stats()
|
||||
|
||||
assert stats.current_size == 1
|
||||
|
||||
def test_hash_content(self) -> None:
|
||||
"""Should produce consistent hashes."""
|
||||
hash1 = EmbeddingCache.hash_content("test content")
|
||||
hash2 = EmbeddingCache.hash_content("test content")
|
||||
hash3 = EmbeddingCache.hash_content("different content")
|
||||
|
||||
assert hash1 == hash2
|
||||
assert hash1 != hash3
|
||||
assert len(hash1) == 32
|
||||
|
||||
|
||||
class TestEmbeddingCacheWithRedis:
|
||||
"""Tests for EmbeddingCache with Redis."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self) -> MagicMock:
|
||||
"""Create mock Redis."""
|
||||
redis = MagicMock()
|
||||
redis.get = AsyncMock(return_value=None)
|
||||
redis.setex = AsyncMock()
|
||||
redis.delete = AsyncMock()
|
||||
redis.scan_iter = MagicMock(return_value=iter([]))
|
||||
return redis
|
||||
|
||||
@pytest.fixture
|
||||
def cache_with_redis(self, mock_redis: MagicMock) -> EmbeddingCache:
|
||||
"""Create cache with mock Redis."""
|
||||
return EmbeddingCache(
|
||||
max_size=100,
|
||||
default_ttl_seconds=300.0,
|
||||
redis=mock_redis,
|
||||
)
|
||||
|
||||
async def test_put_stores_in_redis(
|
||||
self,
|
||||
cache_with_redis: EmbeddingCache,
|
||||
mock_redis: MagicMock,
|
||||
) -> None:
|
||||
"""Should store in Redis when available."""
|
||||
await cache_with_redis.put("content", [0.1, 0.2])
|
||||
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
async def test_get_checks_redis_on_miss(
|
||||
self,
|
||||
cache_with_redis: EmbeddingCache,
|
||||
mock_redis: MagicMock,
|
||||
) -> None:
|
||||
"""Should check Redis when memory cache misses."""
|
||||
import json
|
||||
|
||||
mock_redis.get.return_value = json.dumps([0.1, 0.2])
|
||||
|
||||
result = await cache_with_redis.get("content")
|
||||
|
||||
assert result == [0.1, 0.2]
|
||||
mock_redis.get.assert_called_once()
|
||||
|
||||
|
||||
class TestCachedEmbeddingGenerator:
|
||||
"""Tests for CachedEmbeddingGenerator."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_generator(self) -> MagicMock:
|
||||
"""Create mock embedding generator."""
|
||||
gen = MagicMock()
|
||||
gen.generate = AsyncMock(return_value=[0.1, 0.2, 0.3])
|
||||
gen.generate_batch = AsyncMock(return_value=[[0.1], [0.2], [0.3]])
|
||||
return gen
|
||||
|
||||
@pytest.fixture
|
||||
def cache(self) -> EmbeddingCache:
|
||||
"""Create embedding cache."""
|
||||
return EmbeddingCache(max_size=100)
|
||||
|
||||
@pytest.fixture
|
||||
def cached_gen(
|
||||
self,
|
||||
mock_generator: MagicMock,
|
||||
cache: EmbeddingCache,
|
||||
) -> CachedEmbeddingGenerator:
|
||||
"""Create cached generator."""
|
||||
return CachedEmbeddingGenerator(mock_generator, cache)
|
||||
|
||||
async def test_generate_caches_result(
|
||||
self,
|
||||
cached_gen: CachedEmbeddingGenerator,
|
||||
mock_generator: MagicMock,
|
||||
) -> None:
|
||||
"""Should cache generated embedding."""
|
||||
result1 = await cached_gen.generate("test text")
|
||||
result2 = await cached_gen.generate("test text")
|
||||
|
||||
assert result1 == [0.1, 0.2, 0.3]
|
||||
assert result2 == [0.1, 0.2, 0.3]
|
||||
mock_generator.generate.assert_called_once() # Only called once
|
||||
|
||||
async def test_generate_batch_uses_cache(
|
||||
self,
|
||||
cached_gen: CachedEmbeddingGenerator,
|
||||
mock_generator: MagicMock,
|
||||
cache: EmbeddingCache,
|
||||
) -> None:
|
||||
"""Should use cache for batch generation."""
|
||||
# Pre-cache one embedding
|
||||
await cache.put("text1", [0.5])
|
||||
|
||||
# Mock returns 2 embeddings for the 2 uncached texts
|
||||
mock_generator.generate_batch = AsyncMock(return_value=[[0.2], [0.3]])
|
||||
|
||||
results = await cached_gen.generate_batch(["text1", "text2", "text3"])
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[0] == [0.5] # From cache
|
||||
assert results[1] == [0.2] # Generated
|
||||
assert results[2] == [0.3] # Generated
|
||||
|
||||
async def test_get_stats(self, cached_gen: CachedEmbeddingGenerator) -> None:
|
||||
"""Should return generator statistics."""
|
||||
await cached_gen.generate("text1")
|
||||
await cached_gen.generate("text1") # Cache hit
|
||||
|
||||
stats = cached_gen.get_stats()
|
||||
|
||||
assert stats["call_count"] == 2
|
||||
assert stats["cache_hit_count"] == 1
|
||||
|
||||
|
||||
class TestCreateEmbeddingCache:
|
||||
"""Tests for factory function."""
|
||||
|
||||
def test_creates_cache(self) -> None:
|
||||
"""Should create cache with defaults."""
|
||||
cache = create_embedding_cache()
|
||||
|
||||
assert cache.max_size == 50000
|
||||
|
||||
def test_creates_cache_with_options(self) -> None:
|
||||
"""Should create cache with custom options."""
|
||||
cache = create_embedding_cache(max_size=1000, default_ttl_seconds=600.0)
|
||||
|
||||
assert cache.max_size == 1000
|
||||
355
backend/tests/unit/services/memory/cache/test_hot_cache.py
vendored
Normal file
355
backend/tests/unit/services/memory/cache/test_hot_cache.py
vendored
Normal file
@@ -0,0 +1,355 @@
|
||||
# tests/unit/services/memory/cache/test_hot_cache.py
|
||||
"""Tests for HotMemoryCache."""
|
||||
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.cache.hot_cache import (
|
||||
CacheEntry,
|
||||
CacheKey,
|
||||
HotCacheStats,
|
||||
HotMemoryCache,
|
||||
create_hot_cache,
|
||||
)
|
||||
|
||||
|
||||
class TestCacheKey:
|
||||
"""Tests for CacheKey."""
|
||||
|
||||
def test_creates_key(self) -> None:
|
||||
"""Should create key with required fields."""
|
||||
key = CacheKey(memory_type="episodic", memory_id="123")
|
||||
|
||||
assert key.memory_type == "episodic"
|
||||
assert key.memory_id == "123"
|
||||
assert key.scope is None
|
||||
|
||||
def test_creates_key_with_scope(self) -> None:
|
||||
"""Should create key with scope."""
|
||||
key = CacheKey(memory_type="semantic", memory_id="456", scope="proj-123")
|
||||
|
||||
assert key.scope == "proj-123"
|
||||
|
||||
def test_hash_and_equality(self) -> None:
|
||||
"""Keys with same values should be equal and have same hash."""
|
||||
key1 = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")
|
||||
key2 = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")
|
||||
|
||||
assert key1 == key2
|
||||
assert hash(key1) == hash(key2)
|
||||
|
||||
def test_str_representation(self) -> None:
|
||||
"""Should produce readable string."""
|
||||
key = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")
|
||||
|
||||
assert str(key) == "episodic:proj-1:123"
|
||||
|
||||
def test_str_without_scope(self) -> None:
|
||||
"""Should produce string without scope."""
|
||||
key = CacheKey(memory_type="episodic", memory_id="123")
|
||||
|
||||
assert str(key) == "episodic:123"
|
||||
|
||||
|
||||
class TestCacheEntry:
|
||||
"""Tests for CacheEntry."""
|
||||
|
||||
def test_creates_entry(self) -> None:
|
||||
"""Should create entry with value."""
|
||||
entry = CacheEntry(
|
||||
value={"data": "test"},
|
||||
created_at=pytest.importorskip("datetime").datetime.now(
|
||||
pytest.importorskip("datetime").UTC
|
||||
),
|
||||
last_accessed_at=pytest.importorskip("datetime").datetime.now(
|
||||
pytest.importorskip("datetime").UTC
|
||||
),
|
||||
)
|
||||
|
||||
assert entry.value == {"data": "test"}
|
||||
assert entry.access_count == 1
|
||||
assert entry.ttl_seconds == 300.0
|
||||
|
||||
def test_is_expired(self) -> None:
|
||||
"""Should detect expired entries."""
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
old_time = datetime.now(UTC) - timedelta(seconds=400)
|
||||
entry = CacheEntry(
|
||||
value="test",
|
||||
created_at=old_time,
|
||||
last_accessed_at=old_time,
|
||||
ttl_seconds=300.0,
|
||||
)
|
||||
|
||||
assert entry.is_expired() is True
|
||||
|
||||
def test_not_expired(self) -> None:
|
||||
"""Should detect non-expired entries."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
entry = CacheEntry(
|
||||
value="test",
|
||||
created_at=datetime.now(UTC),
|
||||
last_accessed_at=datetime.now(UTC),
|
||||
ttl_seconds=300.0,
|
||||
)
|
||||
|
||||
assert entry.is_expired() is False
|
||||
|
||||
def test_touch_updates_access(self) -> None:
|
||||
"""Touch should update access time and count."""
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
old_time = datetime.now(UTC) - timedelta(seconds=10)
|
||||
entry = CacheEntry(
|
||||
value="test",
|
||||
created_at=old_time,
|
||||
last_accessed_at=old_time,
|
||||
access_count=5,
|
||||
)
|
||||
|
||||
entry.touch()
|
||||
|
||||
assert entry.access_count == 6
|
||||
assert entry.last_accessed_at > old_time
|
||||
|
||||
|
||||
class TestHotCacheStats:
|
||||
"""Tests for HotCacheStats."""
|
||||
|
||||
def test_hit_rate_calculation(self) -> None:
|
||||
"""Should calculate hit rate correctly."""
|
||||
stats = HotCacheStats(hits=80, misses=20)
|
||||
|
||||
assert stats.hit_rate == 0.8
|
||||
|
||||
def test_hit_rate_zero_requests(self) -> None:
|
||||
"""Should return 0 for no requests."""
|
||||
stats = HotCacheStats()
|
||||
|
||||
assert stats.hit_rate == 0.0
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Should convert to dictionary."""
|
||||
stats = HotCacheStats(hits=10, misses=5, evictions=2)
|
||||
|
||||
result = stats.to_dict()
|
||||
|
||||
assert result["hits"] == 10
|
||||
assert result["misses"] == 5
|
||||
assert result["evictions"] == 2
|
||||
assert "hit_rate" in result
|
||||
|
||||
|
||||
class TestHotMemoryCache:
|
||||
"""Tests for HotMemoryCache."""
|
||||
|
||||
@pytest.fixture
|
||||
def cache(self) -> HotMemoryCache[dict]:
|
||||
"""Create a hot memory cache."""
|
||||
return HotMemoryCache[dict](max_size=100, default_ttl_seconds=300.0)
|
||||
|
||||
def test_put_and_get(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should store and retrieve values."""
|
||||
key = CacheKey(memory_type="episodic", memory_id="123")
|
||||
value = {"data": "test"}
|
||||
|
||||
cache.put(key, value)
|
||||
result = cache.get(key)
|
||||
|
||||
assert result == value
|
||||
|
||||
def test_get_missing_key(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should return None for missing keys."""
|
||||
key = CacheKey(memory_type="episodic", memory_id="nonexistent")
|
||||
|
||||
result = cache.get(key)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_put_by_id(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should store by type and ID."""
|
||||
memory_id = uuid4()
|
||||
value = {"data": "test"}
|
||||
|
||||
cache.put_by_id("episodic", memory_id, value)
|
||||
result = cache.get_by_id("episodic", memory_id)
|
||||
|
||||
assert result == value
|
||||
|
||||
def test_put_by_id_with_scope(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should store with scope."""
|
||||
memory_id = uuid4()
|
||||
value = {"data": "test"}
|
||||
|
||||
cache.put_by_id("semantic", memory_id, value, scope="proj-123")
|
||||
result = cache.get_by_id("semantic", memory_id, scope="proj-123")
|
||||
|
||||
assert result == value
|
||||
|
||||
def test_lru_eviction(self) -> None:
|
||||
"""Should evict LRU entries when at capacity."""
|
||||
cache = HotMemoryCache[str](max_size=3)
|
||||
|
||||
# Fill cache
|
||||
cache.put_by_id("test", "1", "first")
|
||||
cache.put_by_id("test", "2", "second")
|
||||
cache.put_by_id("test", "3", "third")
|
||||
|
||||
# Access first to make it recent
|
||||
cache.get_by_id("test", "1")
|
||||
|
||||
# Add fourth, should evict second (LRU)
|
||||
cache.put_by_id("test", "4", "fourth")
|
||||
|
||||
assert cache.get_by_id("test", "1") is not None # Accessed, kept
|
||||
assert cache.get_by_id("test", "2") is None # Evicted (LRU)
|
||||
assert cache.get_by_id("test", "3") is not None
|
||||
assert cache.get_by_id("test", "4") is not None
|
||||
|
||||
def test_ttl_expiration(self) -> None:
|
||||
"""Should expire entries after TTL."""
|
||||
cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.05)
|
||||
|
||||
cache.put_by_id("test", "1", "value")
|
||||
|
||||
# Wait for expiration
|
||||
time.sleep(0.06)
|
||||
|
||||
result = cache.get_by_id("test", "1")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_invalidate(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should invalidate specific entry."""
|
||||
key = CacheKey(memory_type="episodic", memory_id="123")
|
||||
cache.put(key, {"data": "test"})
|
||||
|
||||
result = cache.invalidate(key)
|
||||
|
||||
assert result is True
|
||||
assert cache.get(key) is None
|
||||
|
||||
def test_invalidate_by_id(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should invalidate by ID."""
|
||||
memory_id = uuid4()
|
||||
cache.put_by_id("episodic", memory_id, {"data": "test"})
|
||||
|
||||
result = cache.invalidate_by_id("episodic", memory_id)
|
||||
|
||||
assert result is True
|
||||
assert cache.get_by_id("episodic", memory_id) is None
|
||||
|
||||
def test_invalidate_by_type(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should invalidate all entries of a type."""
|
||||
cache.put_by_id("episodic", "1", {"data": "1"})
|
||||
cache.put_by_id("episodic", "2", {"data": "2"})
|
||||
cache.put_by_id("semantic", "3", {"data": "3"})
|
||||
|
||||
count = cache.invalidate_by_type("episodic")
|
||||
|
||||
assert count == 2
|
||||
assert cache.get_by_id("episodic", "1") is None
|
||||
assert cache.get_by_id("episodic", "2") is None
|
||||
assert cache.get_by_id("semantic", "3") is not None
|
||||
|
||||
def test_invalidate_by_scope(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should invalidate all entries in a scope."""
|
||||
cache.put_by_id("episodic", "1", {"data": "1"}, scope="proj-1")
|
||||
cache.put_by_id("semantic", "2", {"data": "2"}, scope="proj-1")
|
||||
cache.put_by_id("episodic", "3", {"data": "3"}, scope="proj-2")
|
||||
|
||||
count = cache.invalidate_by_scope("proj-1")
|
||||
|
||||
assert count == 2
|
||||
assert cache.get_by_id("episodic", "3", scope="proj-2") is not None
|
||||
|
||||
def test_invalidate_pattern(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should invalidate entries matching pattern."""
|
||||
cache.put_by_id("episodic", "123", {"data": "1"})
|
||||
cache.put_by_id("episodic", "124", {"data": "2"})
|
||||
cache.put_by_id("semantic", "125", {"data": "3"})
|
||||
|
||||
count = cache.invalidate_pattern("episodic:*")
|
||||
|
||||
assert count == 2
|
||||
|
||||
def test_clear(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should clear all entries."""
|
||||
cache.put_by_id("episodic", "1", {"data": "1"})
|
||||
cache.put_by_id("semantic", "2", {"data": "2"})
|
||||
|
||||
count = cache.clear()
|
||||
|
||||
assert count == 2
|
||||
assert cache.size == 0
|
||||
|
||||
def test_cleanup_expired(self) -> None:
|
||||
"""Should remove expired entries."""
|
||||
cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.05)
|
||||
|
||||
cache.put_by_id("test", "1", "value1")
|
||||
cache.put_by_id("test", "2", "value2", ttl_seconds=10)
|
||||
|
||||
time.sleep(0.06)
|
||||
|
||||
count = cache.cleanup_expired()
|
||||
|
||||
assert count == 1 # Only the first one expired
|
||||
assert cache.size == 1
|
||||
|
||||
def test_get_hot_memories(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should return most accessed memories."""
|
||||
cache.put_by_id("episodic", "1", {"data": "1"})
|
||||
cache.put_by_id("episodic", "2", {"data": "2"})
|
||||
|
||||
# Access first one multiple times
|
||||
for _ in range(5):
|
||||
cache.get_by_id("episodic", "1")
|
||||
|
||||
hot = cache.get_hot_memories(limit=2)
|
||||
|
||||
assert len(hot) == 2
|
||||
assert hot[0][1] >= hot[1][1] # Sorted by access count
|
||||
|
||||
def test_get_stats(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should return accurate statistics."""
|
||||
cache.put_by_id("episodic", "1", {"data": "1"})
|
||||
cache.get_by_id("episodic", "1") # Hit
|
||||
cache.get_by_id("episodic", "2") # Miss
|
||||
|
||||
stats = cache.get_stats()
|
||||
|
||||
assert stats.hits == 1
|
||||
assert stats.misses == 1
|
||||
assert stats.current_size == 1
|
||||
|
||||
def test_reset_stats(self, cache: HotMemoryCache[dict]) -> None:
|
||||
"""Should reset statistics."""
|
||||
cache.put_by_id("episodic", "1", {"data": "1"})
|
||||
cache.get_by_id("episodic", "1")
|
||||
|
||||
cache.reset_stats()
|
||||
stats = cache.get_stats()
|
||||
|
||||
assert stats.hits == 0
|
||||
assert stats.misses == 0
|
||||
|
||||
|
||||
class TestCreateHotCache:
|
||||
"""Tests for factory function."""
|
||||
|
||||
def test_creates_cache(self) -> None:
|
||||
"""Should create cache with defaults."""
|
||||
cache = create_hot_cache()
|
||||
|
||||
assert cache.max_size == 10000
|
||||
|
||||
def test_creates_cache_with_options(self) -> None:
|
||||
"""Should create cache with custom options."""
|
||||
cache = create_hot_cache(max_size=500, default_ttl_seconds=60.0)
|
||||
|
||||
assert cache.max_size == 500
|
||||
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/consolidation/__init__.py
|
||||
"""Tests for memory consolidation."""
|
||||
736
backend/tests/unit/services/memory/consolidation/test_service.py
Normal file
736
backend/tests/unit/services/memory/consolidation/test_service.py
Normal file
@@ -0,0 +1,736 @@
|
||||
# tests/unit/services/memory/consolidation/test_service.py
|
||||
"""Unit tests for memory consolidation service."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.consolidation.service import (
|
||||
ConsolidationConfig,
|
||||
ConsolidationResult,
|
||||
MemoryConsolidationService,
|
||||
NightlyConsolidationResult,
|
||||
SessionConsolidationResult,
|
||||
)
|
||||
from app.services.memory.types import Episode, Outcome, TaskState
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def make_episode(
|
||||
outcome: Outcome = Outcome.SUCCESS,
|
||||
occurred_at: datetime | None = None,
|
||||
task_type: str = "test_task",
|
||||
lessons_learned: list[str] | None = None,
|
||||
importance_score: float = 0.5,
|
||||
actions: list[dict] | None = None,
|
||||
) -> Episode:
|
||||
"""Create a test episode."""
|
||||
return Episode(
|
||||
id=uuid4(),
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
agent_type_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type=task_type,
|
||||
task_description="Test task description",
|
||||
actions=actions or [{"action": "test"}],
|
||||
context_summary="Test context",
|
||||
outcome=outcome,
|
||||
outcome_details="Test outcome",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=100,
|
||||
lessons_learned=lessons_learned or [],
|
||||
importance_score=importance_score,
|
||||
embedding=None,
|
||||
occurred_at=occurred_at or _utcnow(),
|
||||
created_at=_utcnow(),
|
||||
updated_at=_utcnow(),
|
||||
)
|
||||
|
||||
|
||||
def make_task_state(
|
||||
current_step: int = 5,
|
||||
total_steps: int = 10,
|
||||
progress_percent: float = 50.0,
|
||||
status: str = "in_progress",
|
||||
description: str = "Test Task",
|
||||
) -> TaskState:
|
||||
"""Create a test task state."""
|
||||
now = _utcnow()
|
||||
return TaskState(
|
||||
task_id="test-task-id",
|
||||
task_type="test_task",
|
||||
description=description,
|
||||
current_step=current_step,
|
||||
total_steps=total_steps,
|
||||
status=status,
|
||||
progress_percent=progress_percent,
|
||||
started_at=now - timedelta(hours=1),
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
class TestConsolidationConfig:
|
||||
"""Tests for ConsolidationConfig."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default configuration values."""
|
||||
config = ConsolidationConfig()
|
||||
|
||||
assert config.min_steps_for_episode == 2
|
||||
assert config.min_duration_seconds == 5.0
|
||||
assert config.min_confidence_for_fact == 0.6
|
||||
assert config.max_facts_per_episode == 10
|
||||
assert config.min_episodes_for_procedure == 3
|
||||
assert config.max_episode_age_days == 90
|
||||
assert config.batch_size == 100
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
"""Test custom configuration values."""
|
||||
config = ConsolidationConfig(
|
||||
min_steps_for_episode=5,
|
||||
batch_size=50,
|
||||
)
|
||||
|
||||
assert config.min_steps_for_episode == 5
|
||||
assert config.batch_size == 50
|
||||
|
||||
|
||||
class TestConsolidationResult:
|
||||
"""Tests for ConsolidationResult."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test creating a consolidation result."""
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
items_processed=10,
|
||||
items_created=5,
|
||||
)
|
||||
|
||||
assert result.source_type == "episodic"
|
||||
assert result.target_type == "semantic"
|
||||
assert result.items_processed == 10
|
||||
assert result.items_created == 5
|
||||
assert result.items_skipped == 0
|
||||
assert result.errors == []
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test converting to dictionary."""
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
items_processed=10,
|
||||
items_created=5,
|
||||
errors=["test error"],
|
||||
)
|
||||
|
||||
d = result.to_dict()
|
||||
|
||||
assert d["source_type"] == "episodic"
|
||||
assert d["target_type"] == "semantic"
|
||||
assert d["items_processed"] == 10
|
||||
assert d["items_created"] == 5
|
||||
assert "test error" in d["errors"]
|
||||
|
||||
|
||||
class TestSessionConsolidationResult:
|
||||
"""Tests for SessionConsolidationResult."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test creating a session consolidation result."""
|
||||
result = SessionConsolidationResult(
|
||||
session_id="test-session",
|
||||
episode_created=True,
|
||||
episode_id=uuid4(),
|
||||
scratchpad_entries=5,
|
||||
)
|
||||
|
||||
assert result.session_id == "test-session"
|
||||
assert result.episode_created is True
|
||||
assert result.episode_id is not None
|
||||
|
||||
|
||||
class TestNightlyConsolidationResult:
|
||||
"""Tests for NightlyConsolidationResult."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test creating a nightly consolidation result."""
|
||||
result = NightlyConsolidationResult(
|
||||
started_at=_utcnow(),
|
||||
)
|
||||
|
||||
assert result.started_at is not None
|
||||
assert result.completed_at is None
|
||||
assert result.total_episodes_processed == 0
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test converting to dictionary."""
|
||||
result = NightlyConsolidationResult(
|
||||
started_at=_utcnow(),
|
||||
completed_at=_utcnow(),
|
||||
total_facts_created=5,
|
||||
total_procedures_created=2,
|
||||
)
|
||||
|
||||
d = result.to_dict()
|
||||
|
||||
assert "started_at" in d
|
||||
assert "completed_at" in d
|
||||
assert d["total_facts_created"] == 5
|
||||
assert d["total_procedures_created"] == 2
|
||||
|
||||
|
||||
class TestMemoryConsolidationService:
|
||||
"""Tests for MemoryConsolidationService."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, mock_session: AsyncMock) -> MemoryConsolidationService:
|
||||
"""Create a consolidation service with mocked dependencies."""
|
||||
return MemoryConsolidationService(
|
||||
session=mock_session,
|
||||
config=ConsolidationConfig(),
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Session Consolidation Tests
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_session_insufficient_steps(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test session not consolidated when insufficient steps."""
|
||||
mock_working_memory = AsyncMock()
|
||||
task_state = make_task_state(current_step=1) # Less than min_steps_for_episode
|
||||
mock_working_memory.get_task_state.return_value = task_state
|
||||
|
||||
result = await service.consolidate_session(
|
||||
working_memory=mock_working_memory,
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert result.episode_created is False
|
||||
assert result.episode_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_session_no_task_state(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test session not consolidated when no task state."""
|
||||
mock_working_memory = AsyncMock()
|
||||
mock_working_memory.get_task_state.return_value = None
|
||||
|
||||
result = await service.consolidate_session(
|
||||
working_memory=mock_working_memory,
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert result.episode_created is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_session_success(
|
||||
self, service: MemoryConsolidationService, mock_session: AsyncMock
|
||||
) -> None:
|
||||
"""Test successful session consolidation."""
|
||||
mock_working_memory = AsyncMock()
|
||||
task_state = make_task_state(
|
||||
current_step=5,
|
||||
progress_percent=100.0,
|
||||
status="complete",
|
||||
)
|
||||
mock_working_memory.get_task_state.return_value = task_state
|
||||
mock_working_memory.get_scratchpad.return_value = ["step1", "step2"]
|
||||
mock_working_memory.get_all.return_value = {"key1": "value1"}
|
||||
|
||||
# Mock episodic memory
|
||||
mock_episode = make_episode()
|
||||
with patch.object(
|
||||
service, "_get_episodic", new_callable=AsyncMock
|
||||
) as mock_get_episodic:
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.record_episode.return_value = mock_episode
|
||||
mock_get_episodic.return_value = mock_episodic
|
||||
|
||||
result = await service.consolidate_session(
|
||||
working_memory=mock_working_memory,
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert result.episode_created is True
|
||||
assert result.episode_id == mock_episode.id
|
||||
assert result.scratchpad_entries == 2
|
||||
|
||||
# =========================================================================
|
||||
# Outcome Determination Tests
|
||||
# =========================================================================
|
||||
|
||||
def test_determine_session_outcome_success(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test outcome determination for successful session."""
|
||||
task_state = make_task_state(status="complete", progress_percent=100.0)
|
||||
outcome = service._determine_session_outcome(task_state)
|
||||
assert outcome == Outcome.SUCCESS
|
||||
|
||||
def test_determine_session_outcome_failure(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test outcome determination for failed session."""
|
||||
task_state = make_task_state(status="error", progress_percent=25.0)
|
||||
outcome = service._determine_session_outcome(task_state)
|
||||
assert outcome == Outcome.FAILURE
|
||||
|
||||
def test_determine_session_outcome_partial(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test outcome determination for partial session."""
|
||||
task_state = make_task_state(status="stopped", progress_percent=60.0)
|
||||
outcome = service._determine_session_outcome(task_state)
|
||||
assert outcome == Outcome.PARTIAL
|
||||
|
||||
def test_determine_session_outcome_none(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test outcome determination with no task state."""
|
||||
outcome = service._determine_session_outcome(None)
|
||||
assert outcome == Outcome.PARTIAL
|
||||
|
||||
# =========================================================================
|
||||
# Action Building Tests
|
||||
# =========================================================================
|
||||
|
||||
def test_build_actions_from_session(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test building actions from session data."""
|
||||
scratchpad = ["thought 1", "thought 2"]
|
||||
variables = {"var1": "value1"}
|
||||
task_state = make_task_state()
|
||||
|
||||
actions = service._build_actions_from_session(scratchpad, variables, task_state)
|
||||
|
||||
assert len(actions) == 3 # 2 scratchpad + 1 final state
|
||||
assert actions[0]["type"] == "reasoning"
|
||||
assert actions[2]["type"] == "final_state"
|
||||
|
||||
def test_build_context_summary(self, service: MemoryConsolidationService) -> None:
|
||||
"""Test building context summary."""
|
||||
task_state = make_task_state(
|
||||
description="Test Task",
|
||||
progress_percent=75.0,
|
||||
)
|
||||
variables = {"key": "value"}
|
||||
|
||||
summary = service._build_context_summary(task_state, variables)
|
||||
|
||||
assert "Test Task" in summary
|
||||
assert "75.0%" in summary
|
||||
|
||||
# =========================================================================
|
||||
# Importance Calculation Tests
|
||||
# =========================================================================
|
||||
|
||||
def test_calculate_session_importance_base(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test base importance calculation."""
|
||||
task_state = make_task_state(total_steps=3) # Below threshold
|
||||
importance = service._calculate_session_importance(
|
||||
task_state, Outcome.SUCCESS, []
|
||||
)
|
||||
|
||||
assert importance == 0.5 # Base score
|
||||
|
||||
def test_calculate_session_importance_failure(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test importance boost for failures."""
|
||||
task_state = make_task_state(total_steps=3) # Below threshold
|
||||
importance = service._calculate_session_importance(
|
||||
task_state, Outcome.FAILURE, []
|
||||
)
|
||||
|
||||
assert importance == 0.8 # Base (0.5) + failure boost (0.3)
|
||||
|
||||
def test_calculate_session_importance_complex(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test importance for complex session."""
|
||||
task_state = make_task_state(total_steps=10)
|
||||
actions = [{"step": i} for i in range(6)]
|
||||
importance = service._calculate_session_importance(
|
||||
task_state, Outcome.SUCCESS, actions
|
||||
)
|
||||
|
||||
# Base (0.5) + many steps (0.1) + many actions (0.1)
|
||||
assert importance == 0.7
|
||||
|
||||
# =========================================================================
|
||||
# Episode to Fact Consolidation Tests
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_episodes_to_facts_empty(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test consolidation with no episodes."""
|
||||
with patch.object(
|
||||
service, "_get_episodic", new_callable=AsyncMock
|
||||
) as mock_get_episodic:
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent.return_value = []
|
||||
mock_get_episodic.return_value = mock_episodic
|
||||
|
||||
result = await service.consolidate_episodes_to_facts(
|
||||
project_id=uuid4(),
|
||||
)
|
||||
|
||||
assert result.items_processed == 0
|
||||
assert result.items_created == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_episodes_to_facts_success(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test successful fact extraction."""
|
||||
episode = make_episode(
|
||||
lessons_learned=["Always check return values"],
|
||||
)
|
||||
|
||||
mock_fact = MagicMock()
|
||||
mock_fact.reinforcement_count = 1 # New fact
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
service, "_get_episodic", new_callable=AsyncMock
|
||||
) as mock_get_episodic,
|
||||
patch.object(
|
||||
service, "_get_semantic", new_callable=AsyncMock
|
||||
) as mock_get_semantic,
|
||||
):
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent.return_value = [episode]
|
||||
mock_get_episodic.return_value = mock_episodic
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.store_fact.return_value = mock_fact
|
||||
mock_get_semantic.return_value = mock_semantic
|
||||
|
||||
result = await service.consolidate_episodes_to_facts(
|
||||
project_id=uuid4(),
|
||||
)
|
||||
|
||||
assert result.items_processed == 1
|
||||
# At least one fact should be created from lesson
|
||||
assert result.items_created >= 0
|
||||
|
||||
# =========================================================================
|
||||
# Episode to Procedure Consolidation Tests
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_episodes_to_procedures_insufficient(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test consolidation with insufficient episodes."""
|
||||
# Only 1 episode - less than min_episodes_for_procedure (3)
|
||||
episode = make_episode()
|
||||
|
||||
with patch.object(
|
||||
service, "_get_episodic", new_callable=AsyncMock
|
||||
) as mock_get_episodic:
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_by_outcome.return_value = [episode]
|
||||
mock_get_episodic.return_value = mock_episodic
|
||||
|
||||
result = await service.consolidate_episodes_to_procedures(
|
||||
project_id=uuid4(),
|
||||
)
|
||||
|
||||
assert result.items_processed == 1
|
||||
assert result.items_created == 0
|
||||
assert result.items_skipped == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_episodes_to_procedures_success(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test successful procedure creation."""
|
||||
# Create enough episodes for a procedure
|
||||
episodes = [
|
||||
make_episode(
|
||||
task_type="deploy",
|
||||
actions=[{"type": "step1"}, {"type": "step2"}, {"type": "step3"}],
|
||||
)
|
||||
for _ in range(5)
|
||||
]
|
||||
|
||||
mock_procedure = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
service, "_get_episodic", new_callable=AsyncMock
|
||||
) as mock_get_episodic,
|
||||
patch.object(
|
||||
service, "_get_procedural", new_callable=AsyncMock
|
||||
) as mock_get_procedural,
|
||||
):
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_by_outcome.return_value = episodes
|
||||
mock_get_episodic.return_value = mock_episodic
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.find_matching.return_value = [] # No existing procedure
|
||||
mock_procedural.record_procedure.return_value = mock_procedure
|
||||
mock_get_procedural.return_value = mock_procedural
|
||||
|
||||
result = await service.consolidate_episodes_to_procedures(
|
||||
project_id=uuid4(),
|
||||
)
|
||||
|
||||
assert result.items_processed == 5
|
||||
assert result.items_created == 1
|
||||
|
||||
# =========================================================================
|
||||
# Common Steps Extraction Tests
|
||||
# =========================================================================
|
||||
|
||||
def test_extract_common_steps(self, service: MemoryConsolidationService) -> None:
|
||||
"""Test extracting steps from episodes."""
|
||||
episodes = [
|
||||
make_episode(
|
||||
outcome=Outcome.SUCCESS,
|
||||
importance_score=0.8,
|
||||
actions=[
|
||||
{"type": "step1", "content": "First step"},
|
||||
{"type": "step2", "content": "Second step"},
|
||||
],
|
||||
),
|
||||
make_episode(
|
||||
outcome=Outcome.SUCCESS,
|
||||
importance_score=0.5,
|
||||
actions=[{"type": "simple"}],
|
||||
),
|
||||
]
|
||||
|
||||
steps = service._extract_common_steps(episodes)
|
||||
|
||||
assert len(steps) == 2
|
||||
assert steps[0]["order"] == 1
|
||||
assert steps[0]["action"] == "step1"
|
||||
|
||||
# =========================================================================
|
||||
# Pruning Tests
|
||||
# =========================================================================
|
||||
|
||||
def test_should_prune_episode_old_low_importance(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test pruning old, low-importance episode."""
|
||||
old_date = _utcnow() - timedelta(days=100)
|
||||
episode = make_episode(
|
||||
occurred_at=old_date,
|
||||
importance_score=0.1,
|
||||
outcome=Outcome.SUCCESS,
|
||||
)
|
||||
cutoff = _utcnow() - timedelta(days=90)
|
||||
|
||||
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
|
||||
|
||||
assert should_prune is True
|
||||
|
||||
def test_should_prune_episode_recent(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test not pruning recent episode."""
|
||||
recent_date = _utcnow() - timedelta(days=30)
|
||||
episode = make_episode(
|
||||
occurred_at=recent_date,
|
||||
importance_score=0.1,
|
||||
)
|
||||
cutoff = _utcnow() - timedelta(days=90)
|
||||
|
||||
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
|
||||
|
||||
assert should_prune is False
|
||||
|
||||
def test_should_prune_episode_failure_protected(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test not pruning failure (with keep_all_failures=True)."""
|
||||
old_date = _utcnow() - timedelta(days=100)
|
||||
episode = make_episode(
|
||||
occurred_at=old_date,
|
||||
importance_score=0.1,
|
||||
outcome=Outcome.FAILURE,
|
||||
)
|
||||
cutoff = _utcnow() - timedelta(days=90)
|
||||
|
||||
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
|
||||
|
||||
# Config has keep_all_failures=True by default
|
||||
assert should_prune is False
|
||||
|
||||
def test_should_prune_episode_with_lessons_protected(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test not pruning episode with lessons."""
|
||||
old_date = _utcnow() - timedelta(days=100)
|
||||
episode = make_episode(
|
||||
occurred_at=old_date,
|
||||
importance_score=0.1,
|
||||
lessons_learned=["Important lesson"],
|
||||
)
|
||||
cutoff = _utcnow() - timedelta(days=90)
|
||||
|
||||
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
|
||||
|
||||
# Config has keep_all_with_lessons=True by default
|
||||
assert should_prune is False
|
||||
|
||||
def test_should_prune_episode_high_importance_protected(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test not pruning high importance episode."""
|
||||
old_date = _utcnow() - timedelta(days=100)
|
||||
episode = make_episode(
|
||||
occurred_at=old_date,
|
||||
importance_score=0.8,
|
||||
)
|
||||
cutoff = _utcnow() - timedelta(days=90)
|
||||
|
||||
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
|
||||
|
||||
assert should_prune is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prune_old_episodes(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test episode pruning."""
|
||||
old_episode = make_episode(
|
||||
occurred_at=_utcnow() - timedelta(days=100),
|
||||
importance_score=0.1,
|
||||
outcome=Outcome.SUCCESS,
|
||||
lessons_learned=[],
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
service, "_get_episodic", new_callable=AsyncMock
|
||||
) as mock_get_episodic:
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent.return_value = [old_episode]
|
||||
mock_episodic.delete.return_value = True
|
||||
mock_get_episodic.return_value = mock_episodic
|
||||
|
||||
result = await service.prune_old_episodes(project_id=uuid4())
|
||||
|
||||
assert result.items_processed == 1
|
||||
assert result.items_pruned == 1
|
||||
|
||||
# =========================================================================
|
||||
# Nightly Consolidation Tests
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_nightly_consolidation(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test nightly consolidation workflow."""
|
||||
with (
|
||||
patch.object(
|
||||
service,
|
||||
"consolidate_episodes_to_facts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_facts,
|
||||
patch.object(
|
||||
service,
|
||||
"consolidate_episodes_to_procedures",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_procedures,
|
||||
patch.object(
|
||||
service,
|
||||
"prune_old_episodes",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_prune,
|
||||
):
|
||||
mock_facts.return_value = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
items_processed=10,
|
||||
items_created=5,
|
||||
)
|
||||
mock_procedures.return_value = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="procedural",
|
||||
items_processed=10,
|
||||
items_created=2,
|
||||
)
|
||||
mock_prune.return_value = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="pruned",
|
||||
items_pruned=3,
|
||||
)
|
||||
|
||||
result = await service.run_nightly_consolidation(project_id=uuid4())
|
||||
|
||||
assert result.completed_at is not None
|
||||
assert result.total_facts_created == 5
|
||||
assert result.total_procedures_created == 2
|
||||
assert result.total_pruned == 3
|
||||
assert result.total_episodes_processed == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_nightly_consolidation_with_errors(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test nightly consolidation handles errors."""
|
||||
with (
|
||||
patch.object(
|
||||
service,
|
||||
"consolidate_episodes_to_facts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_facts,
|
||||
patch.object(
|
||||
service,
|
||||
"consolidate_episodes_to_procedures",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_procedures,
|
||||
patch.object(
|
||||
service,
|
||||
"prune_old_episodes",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_prune,
|
||||
):
|
||||
mock_facts.return_value = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
errors=["fact error"],
|
||||
)
|
||||
mock_procedures.return_value = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="procedural",
|
||||
)
|
||||
mock_prune.return_value = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="pruned",
|
||||
)
|
||||
|
||||
result = await service.run_nightly_consolidation(project_id=uuid4())
|
||||
|
||||
assert "fact error" in result.errors
|
||||
2
backend/tests/unit/services/memory/episodic/__init__.py
Normal file
2
backend/tests/unit/services/memory/episodic/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/episodic/__init__.py
|
||||
"""Unit tests for episodic memory service."""
|
||||
359
backend/tests/unit/services/memory/episodic/test_memory.py
Normal file
359
backend/tests/unit/services/memory/episodic/test_memory.py
Normal file
@@ -0,0 +1,359 @@
|
||||
# tests/unit/services/memory/episodic/test_memory.py
|
||||
"""Unit tests for EpisodicMemory class."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.episodic.memory import EpisodicMemory
|
||||
from app.services.memory.episodic.retrieval import RetrievalStrategy
|
||||
from app.services.memory.types import EpisodeCreate, Outcome, RetrievalResult
|
||||
|
||||
|
||||
class TestEpisodicMemoryInit:
|
||||
"""Tests for EpisodicMemory initialization."""
|
||||
|
||||
def test_init_creates_recorder_and_retriever(self) -> None:
|
||||
"""Test that init creates recorder and retriever."""
|
||||
mock_session = AsyncMock()
|
||||
memory = EpisodicMemory(session=mock_session)
|
||||
|
||||
assert memory._recorder is not None
|
||||
assert memory._retriever is not None
|
||||
assert memory._session is mock_session
|
||||
|
||||
def test_init_with_embedding_generator(self) -> None:
|
||||
"""Test init with embedding generator."""
|
||||
mock_session = AsyncMock()
|
||||
mock_embedding_gen = AsyncMock()
|
||||
memory = EpisodicMemory(
|
||||
session=mock_session, embedding_generator=mock_embedding_gen
|
||||
)
|
||||
|
||||
assert memory._embedding_generator is mock_embedding_gen
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_factory_method(self) -> None:
|
||||
"""Test create factory method."""
|
||||
mock_session = AsyncMock()
|
||||
memory = await EpisodicMemory.create(session=mock_session)
|
||||
|
||||
assert memory is not None
|
||||
assert memory._session is mock_session
|
||||
|
||||
|
||||
class TestEpisodicMemoryRecording:
|
||||
"""Tests for episode recording methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.flush = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
|
||||
"""Create an EpisodicMemory instance."""
|
||||
return EpisodicMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_episode(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test recording an episode."""
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test_task",
|
||||
task_description="Test description",
|
||||
actions=[{"action": "test"}],
|
||||
context_summary="Test context",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="Success",
|
||||
duration_seconds=30.0,
|
||||
tokens_used=100,
|
||||
)
|
||||
|
||||
result = await memory.record_episode(episode_data)
|
||||
|
||||
assert result.project_id == episode_data.project_id
|
||||
assert result.task_type == "test_task"
|
||||
assert result.outcome == Outcome.SUCCESS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_success(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test convenience method for recording success."""
|
||||
project_id = uuid4()
|
||||
result = await memory.record_success(
|
||||
project_id=project_id,
|
||||
session_id="test-session",
|
||||
task_type="deployment",
|
||||
task_description="Deploy to production",
|
||||
actions=[{"step": "deploy"}],
|
||||
context_summary="Deploying v1.0",
|
||||
outcome_details="Deployed successfully",
|
||||
duration_seconds=60.0,
|
||||
tokens_used=200,
|
||||
)
|
||||
|
||||
assert result.outcome == Outcome.SUCCESS
|
||||
assert result.task_type == "deployment"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_failure(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test convenience method for recording failure."""
|
||||
project_id = uuid4()
|
||||
result = await memory.record_failure(
|
||||
project_id=project_id,
|
||||
session_id="test-session",
|
||||
task_type="deployment",
|
||||
task_description="Deploy to production",
|
||||
actions=[{"step": "deploy"}],
|
||||
context_summary="Deploying v1.0",
|
||||
error_details="Connection timeout",
|
||||
duration_seconds=30.0,
|
||||
tokens_used=100,
|
||||
)
|
||||
|
||||
assert result.outcome == Outcome.FAILURE
|
||||
assert result.outcome_details == "Connection timeout"
|
||||
|
||||
|
||||
class TestEpisodicMemoryRetrieval:
|
||||
"""Tests for episode retrieval methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value = mock_result
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
|
||||
"""Create an EpisodicMemory instance."""
|
||||
return EpisodicMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_similar(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test semantic search."""
|
||||
project_id = uuid4()
|
||||
results = await memory.search_similar(project_id, "authentication bug")
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_recent(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test getting recent episodes."""
|
||||
project_id = uuid4()
|
||||
results = await memory.get_recent(project_id, limit=5)
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_outcome(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test getting episodes by outcome."""
|
||||
project_id = uuid4()
|
||||
results = await memory.get_by_outcome(project_id, Outcome.FAILURE, limit=5)
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_task_type(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test getting episodes by task type."""
|
||||
project_id = uuid4()
|
||||
results = await memory.get_by_task_type(project_id, "code_review", limit=5)
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_important(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test getting important episodes."""
|
||||
project_id = uuid4()
|
||||
results = await memory.get_important(project_id, limit=5, min_importance=0.8)
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_full_result(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test retrieve with full result metadata."""
|
||||
project_id = uuid4()
|
||||
result = await memory.retrieve(project_id, RetrievalStrategy.RECENCY, limit=10)
|
||||
|
||||
assert isinstance(result, RetrievalResult)
|
||||
assert result.retrieval_type == "recency"
|
||||
|
||||
|
||||
class TestEpisodicMemorySummarization:
|
||||
"""Tests for episode summarization."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
|
||||
"""Create an EpisodicMemory instance."""
|
||||
return EpisodicMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_empty_list(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
) -> None:
|
||||
"""Test summarizing empty episode list."""
|
||||
summary = await memory.summarize_episodes([])
|
||||
assert "No episodes to summarize" in summary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_not_found(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test summarizing when episodes not found."""
|
||||
# Mock get_by_id to return None
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
summary = await memory.summarize_episodes([uuid4(), uuid4()])
|
||||
assert "No episodes found" in summary
|
||||
|
||||
|
||||
class TestEpisodicMemoryStats:
|
||||
"""Tests for episode statistics."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
|
||||
"""Create an EpisodicMemory instance."""
|
||||
return EpisodicMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test getting episode statistics."""
|
||||
# Mock empty result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
stats = await memory.get_stats(uuid4())
|
||||
|
||||
assert "total_count" in stats
|
||||
assert "success_count" in stats
|
||||
assert "failure_count" in stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test counting episodes."""
|
||||
# Mock result with 3 episodes
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [1, 2, 3]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
count = await memory.count(uuid4())
|
||||
assert count == 3
|
||||
|
||||
|
||||
class TestEpisodicMemoryModification:
|
||||
"""Tests for episode modification methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
|
||||
"""Create an EpisodicMemory instance."""
|
||||
return EpisodicMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_not_found(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test get_by_id returns None when not found."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await memory.get_by_id(uuid4())
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_importance_not_found(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test update_importance returns None when not found."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await memory.update_importance(uuid4(), 0.9)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_not_found(
|
||||
self,
|
||||
memory: EpisodicMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test delete returns False when not found."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await memory.delete(uuid4())
|
||||
assert result is False
|
||||
348
backend/tests/unit/services/memory/episodic/test_recorder.py
Normal file
348
backend/tests/unit/services/memory/episodic/test_recorder.py
Normal file
@@ -0,0 +1,348 @@
|
||||
# tests/unit/services/memory/episodic/test_recorder.py
|
||||
"""Unit tests for EpisodeRecorder."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.memory.enums import EpisodeOutcome
|
||||
from app.services.memory.episodic.recorder import EpisodeRecorder, _outcome_to_db
|
||||
from app.services.memory.types import EpisodeCreate, Outcome
|
||||
|
||||
|
||||
class TestOutcomeConversion:
|
||||
"""Tests for outcome conversion functions."""
|
||||
|
||||
def test_outcome_to_db_success(self) -> None:
|
||||
"""Test converting success outcome."""
|
||||
result = _outcome_to_db(Outcome.SUCCESS)
|
||||
assert result == EpisodeOutcome.SUCCESS
|
||||
|
||||
def test_outcome_to_db_failure(self) -> None:
|
||||
"""Test converting failure outcome."""
|
||||
result = _outcome_to_db(Outcome.FAILURE)
|
||||
assert result == EpisodeOutcome.FAILURE
|
||||
|
||||
def test_outcome_to_db_partial(self) -> None:
|
||||
"""Test converting partial outcome."""
|
||||
result = _outcome_to_db(Outcome.PARTIAL)
|
||||
assert result == EpisodeOutcome.PARTIAL
|
||||
|
||||
|
||||
class TestEpisodeRecorderImportanceCalculation:
|
||||
"""Tests for importance score calculation."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def recorder(self, mock_session: AsyncMock) -> EpisodeRecorder:
|
||||
"""Create a recorder with mocked session."""
|
||||
return EpisodeRecorder(session=mock_session)
|
||||
|
||||
def test_calculate_importance_success_default(
|
||||
self, recorder: EpisodeRecorder
|
||||
) -> None:
|
||||
"""Test importance for successful episode (default)."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test",
|
||||
task_description="Test task",
|
||||
actions=[],
|
||||
context_summary="Context",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=100,
|
||||
)
|
||||
score = recorder._calculate_importance(episode)
|
||||
assert 0.0 <= score <= 1.0
|
||||
assert score == 0.5 # Base score for success
|
||||
|
||||
def test_calculate_importance_failure_higher(
|
||||
self, recorder: EpisodeRecorder
|
||||
) -> None:
|
||||
"""Test that failures get higher importance."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test",
|
||||
task_description="Test task",
|
||||
actions=[],
|
||||
context_summary="Context",
|
||||
outcome=Outcome.FAILURE,
|
||||
outcome_details="Error occurred",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=100,
|
||||
)
|
||||
score = recorder._calculate_importance(episode)
|
||||
assert score >= 0.7 # Failure adds 0.2 to base 0.5
|
||||
|
||||
def test_calculate_importance_with_lessons(self, recorder: EpisodeRecorder) -> None:
|
||||
"""Test that lessons increase importance."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test",
|
||||
task_description="Test task",
|
||||
actions=[],
|
||||
context_summary="Context",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=100,
|
||||
lessons_learned=["Lesson 1", "Lesson 2"],
|
||||
)
|
||||
score = recorder._calculate_importance(episode)
|
||||
assert score > 0.5 # Lessons add to importance
|
||||
|
||||
def test_calculate_importance_long_duration(
|
||||
self, recorder: EpisodeRecorder
|
||||
) -> None:
|
||||
"""Test that longer tasks get higher importance."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test",
|
||||
task_description="Test task",
|
||||
actions=[],
|
||||
context_summary="Context",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="",
|
||||
duration_seconds=400.0, # > 300 seconds
|
||||
tokens_used=100,
|
||||
)
|
||||
score = recorder._calculate_importance(episode)
|
||||
assert score > 0.5 # Long duration adds to importance
|
||||
|
||||
def test_calculate_importance_clamped_to_max(
|
||||
self, recorder: EpisodeRecorder
|
||||
) -> None:
|
||||
"""Test that importance is clamped to 1.0 max."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test",
|
||||
task_description="Test task",
|
||||
actions=[],
|
||||
context_summary="Context",
|
||||
outcome=Outcome.FAILURE, # +0.2
|
||||
outcome_details="Error",
|
||||
duration_seconds=400.0, # +0.1
|
||||
tokens_used=2000, # +0.05
|
||||
lessons_learned=["L1", "L2", "L3", "L4", "L5"], # +0.15
|
||||
)
|
||||
score = recorder._calculate_importance(episode)
|
||||
assert score <= 1.0
|
||||
|
||||
|
||||
class TestEpisodeRecorderEmbeddingText:
|
||||
"""Tests for embedding text generation."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def recorder(self, mock_session: AsyncMock) -> EpisodeRecorder:
|
||||
"""Create a recorder with mocked session."""
|
||||
return EpisodeRecorder(session=mock_session)
|
||||
|
||||
def test_create_embedding_text_basic(self, recorder: EpisodeRecorder) -> None:
|
||||
"""Test basic embedding text creation."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="code_review",
|
||||
task_description="Review PR #123",
|
||||
actions=[],
|
||||
context_summary="Reviewing authentication changes",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="",
|
||||
duration_seconds=60.0,
|
||||
tokens_used=500,
|
||||
)
|
||||
text = recorder._create_embedding_text(episode)
|
||||
assert "code_review" in text
|
||||
assert "Review PR #123" in text
|
||||
assert "authentication" in text
|
||||
assert "success" in text
|
||||
|
||||
def test_create_embedding_text_with_details(
|
||||
self, recorder: EpisodeRecorder
|
||||
) -> None:
|
||||
"""Test embedding text includes outcome details."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="deployment",
|
||||
task_description="Deploy to production",
|
||||
actions=[],
|
||||
context_summary="Deploying v1.0.0",
|
||||
outcome=Outcome.FAILURE,
|
||||
outcome_details="Connection timeout to server",
|
||||
duration_seconds=30.0,
|
||||
tokens_used=200,
|
||||
)
|
||||
text = recorder._create_embedding_text(episode)
|
||||
assert "Connection timeout" in text
|
||||
|
||||
def test_create_embedding_text_with_lessons(
|
||||
self, recorder: EpisodeRecorder
|
||||
) -> None:
|
||||
"""Test embedding text includes lessons learned."""
|
||||
episode = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="debugging",
|
||||
task_description="Fix memory leak",
|
||||
actions=[],
|
||||
context_summary="Debugging memory issues",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="",
|
||||
duration_seconds=120.0,
|
||||
tokens_used=800,
|
||||
lessons_learned=["Always close file handles", "Use context managers"],
|
||||
)
|
||||
text = recorder._create_embedding_text(episode)
|
||||
assert "Always close file handles" in text
|
||||
assert "context managers" in text
|
||||
|
||||
|
||||
class TestEpisodeRecorderRecord:
|
||||
"""Tests for episode recording."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.flush = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def recorder(self, mock_session: AsyncMock) -> EpisodeRecorder:
|
||||
"""Create a recorder with mocked session."""
|
||||
return EpisodeRecorder(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_creates_episode(
|
||||
self,
|
||||
recorder: EpisodeRecorder,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that record creates an episode."""
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test_task",
|
||||
task_description="Test description",
|
||||
actions=[{"action": "test"}],
|
||||
context_summary="Test context",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="Success",
|
||||
duration_seconds=30.0,
|
||||
tokens_used=100,
|
||||
)
|
||||
|
||||
result = await recorder.record(episode_data)
|
||||
|
||||
# Verify session methods were called
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.flush.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
# Verify result
|
||||
assert result.project_id == episode_data.project_id
|
||||
assert result.session_id == episode_data.session_id
|
||||
assert result.task_type == episode_data.task_type
|
||||
assert result.outcome == Outcome.SUCCESS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_with_embedding_generator(
|
||||
self,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test recording with embedding generator."""
|
||||
mock_embedding_gen = AsyncMock()
|
||||
mock_embedding_gen.generate = AsyncMock(return_value=[0.1] * 1536)
|
||||
|
||||
recorder = EpisodeRecorder(
|
||||
session=mock_session, embedding_generator=mock_embedding_gen
|
||||
)
|
||||
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test_task",
|
||||
task_description="Test description",
|
||||
actions=[],
|
||||
context_summary="Test context",
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=50,
|
||||
)
|
||||
|
||||
await recorder.record(episode_data)
|
||||
|
||||
# Verify embedding generator was called
|
||||
mock_embedding_gen.generate.assert_called_once()
|
||||
|
||||
|
||||
class TestEpisodeRecorderStats:
|
||||
"""Tests for episode statistics."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def recorder(self, mock_session: AsyncMock) -> EpisodeRecorder:
|
||||
"""Create a recorder with mocked session."""
|
||||
return EpisodeRecorder(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats_empty(
|
||||
self,
|
||||
recorder: EpisodeRecorder,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test stats for project with no episodes."""
|
||||
# Mock empty result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
stats = await recorder.get_stats(uuid4())
|
||||
|
||||
assert stats["total_count"] == 0
|
||||
assert stats["success_count"] == 0
|
||||
assert stats["failure_count"] == 0
|
||||
assert stats["partial_count"] == 0
|
||||
assert stats["avg_importance"] == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_by_project(
|
||||
self,
|
||||
recorder: EpisodeRecorder,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test counting episodes by project."""
|
||||
# Mock result with 5 episodes
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [1, 2, 3, 4, 5]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
count = await recorder.count_by_project(uuid4())
|
||||
|
||||
assert count == 5
|
||||
400
backend/tests/unit/services/memory/episodic/test_retrieval.py
Normal file
400
backend/tests/unit/services/memory/episodic/test_retrieval.py
Normal file
@@ -0,0 +1,400 @@
|
||||
# tests/unit/services/memory/episodic/test_retrieval.py
|
||||
"""Unit tests for episode retrieval strategies."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.memory.enums import EpisodeOutcome
|
||||
from app.services.memory.episodic.retrieval import (
|
||||
EpisodeRetriever,
|
||||
ImportanceRetriever,
|
||||
OutcomeRetriever,
|
||||
RecencyRetriever,
|
||||
RetrievalStrategy,
|
||||
SemanticRetriever,
|
||||
TaskTypeRetriever,
|
||||
)
|
||||
from app.services.memory.types import Outcome
|
||||
|
||||
|
||||
def create_mock_episode_model(
|
||||
project_id=None,
|
||||
outcome=EpisodeOutcome.SUCCESS,
|
||||
task_type="test_task",
|
||||
importance_score=0.5,
|
||||
occurred_at=None,
|
||||
):
|
||||
"""Create a mock episode model for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = uuid4()
|
||||
mock.project_id = project_id or uuid4()
|
||||
mock.agent_instance_id = None
|
||||
mock.agent_type_id = None
|
||||
mock.session_id = "test-session"
|
||||
mock.task_type = task_type
|
||||
mock.task_description = "Test description"
|
||||
mock.actions = []
|
||||
mock.context_summary = "Test context"
|
||||
mock.outcome = outcome
|
||||
mock.outcome_details = ""
|
||||
mock.duration_seconds = 30.0
|
||||
mock.tokens_used = 100
|
||||
mock.lessons_learned = []
|
||||
mock.importance_score = importance_score
|
||||
mock.embedding = None
|
||||
mock.occurred_at = occurred_at or datetime.now(UTC)
|
||||
mock.created_at = datetime.now(UTC)
|
||||
mock.updated_at = datetime.now(UTC)
|
||||
return mock
|
||||
|
||||
|
||||
class TestRetrievalStrategy:
|
||||
"""Tests for RetrievalStrategy enum."""
|
||||
|
||||
def test_strategy_values(self) -> None:
|
||||
"""Test that strategy enum has expected values."""
|
||||
assert RetrievalStrategy.SEMANTIC == "semantic"
|
||||
assert RetrievalStrategy.RECENCY == "recency"
|
||||
assert RetrievalStrategy.OUTCOME == "outcome"
|
||||
assert RetrievalStrategy.IMPORTANCE == "importance"
|
||||
assert RetrievalStrategy.HYBRID == "hybrid"
|
||||
|
||||
|
||||
class TestRecencyRetriever:
|
||||
"""Tests for RecencyRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> RecencyRetriever:
|
||||
"""Create a recency retriever."""
|
||||
return RecencyRetriever()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_returns_episodes(
|
||||
self,
|
||||
retriever: RecencyRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that retrieve returns episodes."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(project_id=project_id),
|
||||
create_mock_episode_model(project_id=project_id),
|
||||
]
|
||||
|
||||
# Mock query result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(mock_session, project_id, limit=10)
|
||||
|
||||
assert len(result.items) == 2
|
||||
assert result.retrieval_type == "recency"
|
||||
assert result.latency_ms >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_since_filter(
|
||||
self,
|
||||
retriever: RecencyRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieve with since time filter."""
|
||||
project_id = uuid4()
|
||||
since = datetime.now(UTC) - timedelta(hours=1)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, since=since
|
||||
)
|
||||
|
||||
assert result.metadata["since"] == since.isoformat()
|
||||
|
||||
|
||||
class TestOutcomeRetriever:
|
||||
"""Tests for OutcomeRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> OutcomeRetriever:
|
||||
"""Create an outcome retriever."""
|
||||
return OutcomeRetriever()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_success(
|
||||
self,
|
||||
retriever: OutcomeRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieving successful episodes."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(
|
||||
project_id=project_id, outcome=EpisodeOutcome.SUCCESS
|
||||
),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, outcome=Outcome.SUCCESS
|
||||
)
|
||||
|
||||
assert result.retrieval_type == "outcome"
|
||||
assert result.metadata["outcome"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_failure(
|
||||
self,
|
||||
retriever: OutcomeRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieving failed episodes."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(
|
||||
project_id=project_id, outcome=EpisodeOutcome.FAILURE
|
||||
),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, outcome=Outcome.FAILURE
|
||||
)
|
||||
|
||||
assert result.metadata["outcome"] == "failure"
|
||||
|
||||
|
||||
class TestImportanceRetriever:
|
||||
"""Tests for ImportanceRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> ImportanceRetriever:
|
||||
"""Create an importance retriever."""
|
||||
return ImportanceRetriever()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_importance(
|
||||
self,
|
||||
retriever: ImportanceRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieving by importance score."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(project_id=project_id, importance_score=0.9),
|
||||
create_mock_episode_model(project_id=project_id, importance_score=0.8),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, min_importance=0.7
|
||||
)
|
||||
|
||||
assert result.retrieval_type == "importance"
|
||||
assert result.metadata["min_importance"] == 0.7
|
||||
|
||||
|
||||
class TestTaskTypeRetriever:
|
||||
"""Tests for TaskTypeRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self) -> TaskTypeRetriever:
|
||||
"""Create a task type retriever."""
|
||||
return TaskTypeRetriever()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_task_type(
|
||||
self,
|
||||
retriever: TaskTypeRetriever,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test retrieving by task type."""
|
||||
project_id = uuid4()
|
||||
mock_episodes = [
|
||||
create_mock_episode_model(project_id=project_id, task_type="code_review"),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = mock_episodes
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, task_type="code_review"
|
||||
)
|
||||
|
||||
assert result.metadata["task_type"] == "code_review"
|
||||
|
||||
|
||||
class TestSemanticRetriever:
|
||||
"""Tests for SemanticRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_falls_back_without_query(
|
||||
self,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that semantic search falls back to recency without query."""
|
||||
retriever = SemanticRetriever()
|
||||
project_id = uuid4()
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(mock_session, project_id, limit=10)
|
||||
|
||||
# Should fall back to recency
|
||||
assert result.retrieval_type == "semantic"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_embedding_generator(
|
||||
self,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test semantic retrieval with embedding generator."""
|
||||
mock_embedding_gen = AsyncMock()
|
||||
mock_embedding_gen.generate = AsyncMock(return_value=[0.1] * 1536)
|
||||
|
||||
retriever = SemanticRetriever(embedding_generator=mock_embedding_gen)
|
||||
project_id = uuid4()
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await retriever.retrieve(
|
||||
mock_session, project_id, limit=10, query_text="test query"
|
||||
)
|
||||
|
||||
assert result.retrieval_type == "semantic"
|
||||
|
||||
|
||||
class TestEpisodeRetriever:
|
||||
"""Tests for unified EpisodeRetriever."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value = mock_result
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def retriever(self, mock_session: AsyncMock) -> EpisodeRetriever:
|
||||
"""Create an episode retriever."""
|
||||
return EpisodeRetriever(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_recency_strategy(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test retrieve with recency strategy."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.retrieve(
|
||||
project_id, RetrievalStrategy.RECENCY, limit=10
|
||||
)
|
||||
assert result.retrieval_type == "recency"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_outcome_strategy(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test retrieve with outcome strategy."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.retrieve(
|
||||
project_id, RetrievalStrategy.OUTCOME, limit=10
|
||||
)
|
||||
assert result.retrieval_type == "outcome"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_recent_convenience_method(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test get_recent convenience method."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.get_recent(project_id, limit=5)
|
||||
assert result.retrieval_type == "recency"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_outcome_convenience_method(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test get_by_outcome convenience method."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.get_by_outcome(project_id, Outcome.SUCCESS, limit=5)
|
||||
assert result.retrieval_type == "outcome"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_important_convenience_method(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test get_important convenience method."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.get_important(project_id, limit=5, min_importance=0.8)
|
||||
assert result.retrieval_type == "importance"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_similar_convenience_method(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test search_similar convenience method."""
|
||||
project_id = uuid4()
|
||||
result = await retriever.search_similar(project_id, "test query", limit=5)
|
||||
assert result.retrieval_type == "semantic"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_strategy_raises_error(
|
||||
self,
|
||||
retriever: EpisodeRetriever,
|
||||
) -> None:
|
||||
"""Test that unknown strategy raises ValueError."""
|
||||
project_id = uuid4()
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown retrieval strategy"):
|
||||
await retriever.retrieve(project_id, "invalid_strategy", limit=10) # type: ignore
|
||||
2
backend/tests/unit/services/memory/indexing/__init__.py
Normal file
2
backend/tests/unit/services/memory/indexing/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/indexing/__init__.py
|
||||
"""Unit tests for memory indexing."""
|
||||
497
backend/tests/unit/services/memory/indexing/test_index.py
Normal file
497
backend/tests/unit/services/memory/indexing/test_index.py
Normal file
@@ -0,0 +1,497 @@
|
||||
# tests/unit/services/memory/indexing/test_index.py
|
||||
"""Unit tests for memory indexing."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.indexing.index import (
|
||||
EntityIndex,
|
||||
MemoryIndexer,
|
||||
OutcomeIndex,
|
||||
TemporalIndex,
|
||||
VectorIndex,
|
||||
get_memory_indexer,
|
||||
)
|
||||
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def make_episode(
|
||||
embedding: list[float] | None = None,
|
||||
outcome: Outcome = Outcome.SUCCESS,
|
||||
occurred_at: datetime | None = None,
|
||||
) -> Episode:
|
||||
"""Create a test episode."""
|
||||
return Episode(
|
||||
id=uuid4(),
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
agent_type_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test_task",
|
||||
task_description="Test task description",
|
||||
actions=[{"action": "test"}],
|
||||
context_summary="Test context",
|
||||
outcome=outcome,
|
||||
outcome_details="Test outcome",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=100,
|
||||
lessons_learned=["lesson1"],
|
||||
importance_score=0.8,
|
||||
embedding=embedding,
|
||||
occurred_at=occurred_at or _utcnow(),
|
||||
created_at=_utcnow(),
|
||||
updated_at=_utcnow(),
|
||||
)
|
||||
|
||||
|
||||
def make_fact(
|
||||
embedding: list[float] | None = None,
|
||||
subject: str = "test_subject",
|
||||
predicate: str = "has_property",
|
||||
obj: str = "test_value",
|
||||
) -> Fact:
|
||||
"""Create a test fact."""
|
||||
return Fact(
|
||||
id=uuid4(),
|
||||
project_id=uuid4(),
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
object=obj,
|
||||
confidence=0.9,
|
||||
source_episode_ids=[uuid4()],
|
||||
first_learned=_utcnow(),
|
||||
last_reinforced=_utcnow(),
|
||||
reinforcement_count=1,
|
||||
embedding=embedding,
|
||||
created_at=_utcnow(),
|
||||
updated_at=_utcnow(),
|
||||
)
|
||||
|
||||
|
||||
def make_procedure(
|
||||
embedding: list[float] | None = None,
|
||||
success_count: int = 8,
|
||||
failure_count: int = 2,
|
||||
) -> Procedure:
|
||||
"""Create a test procedure."""
|
||||
return Procedure(
|
||||
id=uuid4(),
|
||||
project_id=uuid4(),
|
||||
agent_type_id=uuid4(),
|
||||
name="test_procedure",
|
||||
trigger_pattern="test.*",
|
||||
steps=[{"step": 1, "action": "test"}],
|
||||
success_count=success_count,
|
||||
failure_count=failure_count,
|
||||
last_used=_utcnow(),
|
||||
embedding=embedding,
|
||||
created_at=_utcnow(),
|
||||
updated_at=_utcnow(),
|
||||
)
|
||||
|
||||
|
||||
class TestVectorIndex:
|
||||
"""Tests for VectorIndex."""
|
||||
|
||||
@pytest.fixture
|
||||
def index(self) -> VectorIndex[Episode]:
|
||||
"""Create a vector index."""
|
||||
return VectorIndex[Episode](dimension=4)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_item(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test adding an item to the index."""
|
||||
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
|
||||
entry = await index.add(episode)
|
||||
|
||||
assert entry.memory_id == episode.id
|
||||
assert entry.memory_type == MemoryType.EPISODIC
|
||||
assert entry.dimension == 4
|
||||
assert await index.count() == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_item(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test removing an item from the index."""
|
||||
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
await index.add(episode)
|
||||
|
||||
result = await index.remove(episode.id)
|
||||
|
||||
assert result is True
|
||||
assert await index.count() == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_nonexistent(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test removing a nonexistent item."""
|
||||
result = await index.remove(uuid4())
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_similar(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test searching for similar items."""
|
||||
# Add items with different embeddings
|
||||
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
|
||||
e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
|
||||
|
||||
await index.add(e1)
|
||||
await index.add(e2)
|
||||
await index.add(e3)
|
||||
|
||||
# Search for similar to first
|
||||
results = await index.search([1.0, 0.0, 0.0, 0.0], limit=2)
|
||||
|
||||
assert len(results) == 2
|
||||
# First result should be most similar
|
||||
assert results[0].memory_id == e1.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_min_similarity(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test minimum similarity threshold."""
|
||||
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0]) # Orthogonal
|
||||
|
||||
await index.add(e1)
|
||||
await index.add(e2)
|
||||
|
||||
# Search with high threshold
|
||||
results = await index.search([1.0, 0.0, 0.0, 0.0], min_similarity=0.9)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].memory_id == e1.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_empty_query(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test search with empty query."""
|
||||
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
await index.add(e1)
|
||||
|
||||
results = await index.search([], limit=10)
|
||||
assert len(results) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test clearing the index."""
|
||||
await index.add(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
|
||||
await index.add(make_episode(embedding=[0.0, 1.0, 0.0, 0.0]))
|
||||
|
||||
count = await index.clear()
|
||||
|
||||
assert count == 2
|
||||
assert await index.count() == 0
|
||||
|
||||
|
||||
class TestTemporalIndex:
|
||||
"""Tests for TemporalIndex."""
|
||||
|
||||
@pytest.fixture
|
||||
def index(self) -> TemporalIndex[Episode]:
|
||||
"""Create a temporal index."""
|
||||
return TemporalIndex[Episode]()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_item(self, index: TemporalIndex[Episode]) -> None:
|
||||
"""Test adding an item."""
|
||||
episode = make_episode()
|
||||
entry = await index.add(episode)
|
||||
|
||||
assert entry.memory_id == episode.id
|
||||
assert await index.count() == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_by_time_range(self, index: TemporalIndex[Episode]) -> None:
|
||||
"""Test searching by time range."""
|
||||
now = _utcnow()
|
||||
old = make_episode(occurred_at=now - timedelta(hours=2))
|
||||
recent = make_episode(occurred_at=now - timedelta(hours=1))
|
||||
newest = make_episode(occurred_at=now)
|
||||
|
||||
await index.add(old)
|
||||
await index.add(recent)
|
||||
await index.add(newest)
|
||||
|
||||
# Search last hour
|
||||
results = await index.search(
|
||||
query=None,
|
||||
start_time=now - timedelta(hours=1, minutes=30),
|
||||
end_time=now,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_recent(self, index: TemporalIndex[Episode]) -> None:
|
||||
"""Test searching for recent items."""
|
||||
now = _utcnow()
|
||||
old = make_episode(occurred_at=now - timedelta(hours=2))
|
||||
recent = make_episode(occurred_at=now - timedelta(minutes=30))
|
||||
|
||||
await index.add(old)
|
||||
await index.add(recent)
|
||||
|
||||
# Search last hour (3600 seconds)
|
||||
results = await index.search(query=None, recent_seconds=3600)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].memory_id == recent.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_order(self, index: TemporalIndex[Episode]) -> None:
|
||||
"""Test result ordering."""
|
||||
now = _utcnow()
|
||||
e1 = make_episode(occurred_at=now - timedelta(hours=2))
|
||||
e2 = make_episode(occurred_at=now - timedelta(hours=1))
|
||||
e3 = make_episode(occurred_at=now)
|
||||
|
||||
await index.add(e1)
|
||||
await index.add(e2)
|
||||
await index.add(e3)
|
||||
|
||||
# Descending order (newest first)
|
||||
results_desc = await index.search(query=None, order="desc", limit=10)
|
||||
assert results_desc[0].memory_id == e3.id
|
||||
|
||||
# Ascending order (oldest first)
|
||||
results_asc = await index.search(query=None, order="asc", limit=10)
|
||||
assert results_asc[0].memory_id == e1.id
|
||||
|
||||
|
||||
class TestEntityIndex:
|
||||
"""Tests for EntityIndex."""
|
||||
|
||||
@pytest.fixture
|
||||
def index(self) -> EntityIndex[Fact]:
|
||||
"""Create an entity index."""
|
||||
return EntityIndex[Fact]()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_item(self, index: EntityIndex[Fact]) -> None:
|
||||
"""Test adding an item."""
|
||||
fact = make_fact(subject="user", obj="admin")
|
||||
entry = await index.add(fact)
|
||||
|
||||
assert entry.memory_id == fact.id
|
||||
assert await index.count() == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_by_entity(self, index: EntityIndex[Fact]) -> None:
|
||||
"""Test searching by entity."""
|
||||
f1 = make_fact(subject="user", obj="admin")
|
||||
f2 = make_fact(subject="system", obj="config")
|
||||
|
||||
await index.add(f1)
|
||||
await index.add(f2)
|
||||
|
||||
results = await index.search(
|
||||
query=None,
|
||||
entity_type="subject",
|
||||
entity_value="user",
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].memory_id == f1.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_multiple_entities(self, index: EntityIndex[Fact]) -> None:
|
||||
"""Test searching with multiple entities."""
|
||||
f1 = make_fact(subject="user", obj="admin")
|
||||
f2 = make_fact(subject="user", obj="guest")
|
||||
|
||||
await index.add(f1)
|
||||
await index.add(f2)
|
||||
|
||||
# Search for facts about "user" subject
|
||||
results = await index.search(
|
||||
query=None,
|
||||
entities=[("subject", "user")],
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_match_all(self, index: EntityIndex[Fact]) -> None:
|
||||
"""Test matching all entities."""
|
||||
f1 = make_fact(subject="user", obj="admin")
|
||||
f2 = make_fact(subject="user", obj="guest")
|
||||
|
||||
await index.add(f1)
|
||||
await index.add(f2)
|
||||
|
||||
# Search for user+admin (match all)
|
||||
results = await index.search(
|
||||
query=None,
|
||||
entities=[("subject", "user"), ("object", "admin")],
|
||||
match_all=True,
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].memory_id == f1.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_entities(self, index: EntityIndex[Fact]) -> None:
|
||||
"""Test getting entities for a memory."""
|
||||
fact = make_fact(subject="user", obj="admin")
|
||||
await index.add(fact)
|
||||
|
||||
entities = await index.get_entities(fact.id)
|
||||
|
||||
assert ("subject", "user") in entities
|
||||
assert ("object", "admin") in entities
|
||||
|
||||
|
||||
class TestOutcomeIndex:
|
||||
"""Tests for OutcomeIndex."""
|
||||
|
||||
@pytest.fixture
|
||||
def index(self) -> OutcomeIndex[Episode]:
|
||||
"""Create an outcome index."""
|
||||
return OutcomeIndex[Episode]()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_item(self, index: OutcomeIndex[Episode]) -> None:
|
||||
"""Test adding an item."""
|
||||
episode = make_episode(outcome=Outcome.SUCCESS)
|
||||
entry = await index.add(episode)
|
||||
|
||||
assert entry.memory_id == episode.id
|
||||
assert entry.outcome == Outcome.SUCCESS
|
||||
assert await index.count() == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_by_outcome(self, index: OutcomeIndex[Episode]) -> None:
|
||||
"""Test searching by outcome."""
|
||||
success = make_episode(outcome=Outcome.SUCCESS)
|
||||
failure = make_episode(outcome=Outcome.FAILURE)
|
||||
|
||||
await index.add(success)
|
||||
await index.add(failure)
|
||||
|
||||
results = await index.search(query=None, outcome=Outcome.SUCCESS)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].memory_id == success.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_multiple_outcomes(self, index: OutcomeIndex[Episode]) -> None:
|
||||
"""Test searching with multiple outcomes."""
|
||||
success = make_episode(outcome=Outcome.SUCCESS)
|
||||
partial = make_episode(outcome=Outcome.PARTIAL)
|
||||
failure = make_episode(outcome=Outcome.FAILURE)
|
||||
|
||||
await index.add(success)
|
||||
await index.add(partial)
|
||||
await index.add(failure)
|
||||
|
||||
results = await index.search(
|
||||
query=None,
|
||||
outcomes=[Outcome.SUCCESS, Outcome.PARTIAL],
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_outcome_stats(self, index: OutcomeIndex[Episode]) -> None:
|
||||
"""Test getting outcome statistics."""
|
||||
await index.add(make_episode(outcome=Outcome.SUCCESS))
|
||||
await index.add(make_episode(outcome=Outcome.SUCCESS))
|
||||
await index.add(make_episode(outcome=Outcome.FAILURE))
|
||||
|
||||
stats = await index.get_outcome_stats()
|
||||
|
||||
assert stats[Outcome.SUCCESS] == 2
|
||||
assert stats[Outcome.FAILURE] == 1
|
||||
assert stats[Outcome.PARTIAL] == 0
|
||||
|
||||
|
||||
class TestMemoryIndexer:
|
||||
"""Tests for MemoryIndexer."""
|
||||
|
||||
@pytest.fixture
|
||||
def indexer(self) -> MemoryIndexer:
|
||||
"""Create a memory indexer."""
|
||||
return MemoryIndexer()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_index_episode(self, indexer: MemoryIndexer) -> None:
|
||||
"""Test indexing an episode."""
|
||||
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
|
||||
results = await indexer.index(episode)
|
||||
|
||||
assert "vector" in results
|
||||
assert "temporal" in results
|
||||
assert "entity" in results
|
||||
assert "outcome" in results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_index_fact(self, indexer: MemoryIndexer) -> None:
|
||||
"""Test indexing a fact."""
|
||||
fact = make_fact(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
|
||||
results = await indexer.index(fact)
|
||||
|
||||
# Facts don't have outcomes
|
||||
assert "vector" in results
|
||||
assert "temporal" in results
|
||||
assert "entity" in results
|
||||
assert "outcome" not in results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_from_all(self, indexer: MemoryIndexer) -> None:
|
||||
"""Test removing from all indices."""
|
||||
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
await indexer.index(episode)
|
||||
|
||||
results = await indexer.remove(episode.id)
|
||||
|
||||
assert results["vector"] is True
|
||||
assert results["temporal"] is True
|
||||
assert results["entity"] is True
|
||||
assert results["outcome"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_all(self, indexer: MemoryIndexer) -> None:
|
||||
"""Test clearing all indices."""
|
||||
await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
|
||||
await indexer.index(make_episode(embedding=[0.0, 1.0, 0.0, 0.0]))
|
||||
|
||||
counts = await indexer.clear_all()
|
||||
|
||||
assert counts["vector"] == 2
|
||||
assert counts["temporal"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats(self, indexer: MemoryIndexer) -> None:
|
||||
"""Test getting index statistics."""
|
||||
await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
|
||||
|
||||
stats = await indexer.get_stats()
|
||||
|
||||
assert stats["vector"] == 1
|
||||
assert stats["temporal"] == 1
|
||||
assert stats["entity"] == 1
|
||||
assert stats["outcome"] == 1
|
||||
|
||||
|
||||
class TestGetMemoryIndexer:
|
||||
"""Tests for singleton getter."""
|
||||
|
||||
def test_returns_instance(self) -> None:
|
||||
"""Test that getter returns instance."""
|
||||
indexer = get_memory_indexer()
|
||||
assert indexer is not None
|
||||
assert isinstance(indexer, MemoryIndexer)
|
||||
|
||||
def test_returns_same_instance(self) -> None:
|
||||
"""Test that getter returns same instance."""
|
||||
indexer1 = get_memory_indexer()
|
||||
indexer2 = get_memory_indexer()
|
||||
assert indexer1 is indexer2
|
||||
450
backend/tests/unit/services/memory/indexing/test_retrieval.py
Normal file
450
backend/tests/unit/services/memory/indexing/test_retrieval.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# tests/unit/services/memory/indexing/test_retrieval.py
|
||||
"""Unit tests for memory retrieval."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.indexing.index import MemoryIndexer
|
||||
from app.services.memory.indexing.retrieval import (
|
||||
RelevanceScorer,
|
||||
RetrievalCache,
|
||||
RetrievalEngine,
|
||||
RetrievalQuery,
|
||||
ScoredResult,
|
||||
get_retrieval_engine,
|
||||
)
|
||||
from app.services.memory.types import Episode, MemoryType, Outcome
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def make_episode(
|
||||
embedding: list[float] | None = None,
|
||||
outcome: Outcome = Outcome.SUCCESS,
|
||||
occurred_at: datetime | None = None,
|
||||
task_type: str = "test_task",
|
||||
) -> Episode:
|
||||
"""Create a test episode."""
|
||||
return Episode(
|
||||
id=uuid4(),
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
agent_type_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type=task_type,
|
||||
task_description="Test task description",
|
||||
actions=[{"action": "test"}],
|
||||
context_summary="Test context",
|
||||
outcome=outcome,
|
||||
outcome_details="Test outcome",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=100,
|
||||
lessons_learned=["lesson1"],
|
||||
importance_score=0.8,
|
||||
embedding=embedding,
|
||||
occurred_at=occurred_at or _utcnow(),
|
||||
created_at=_utcnow(),
|
||||
updated_at=_utcnow(),
|
||||
)
|
||||
|
||||
|
||||
class TestRetrievalQuery:
|
||||
"""Tests for RetrievalQuery."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default query values."""
|
||||
query = RetrievalQuery()
|
||||
|
||||
assert query.query_text is None
|
||||
assert query.limit == 10
|
||||
assert query.min_relevance == 0.0
|
||||
assert query.use_vector is True
|
||||
assert query.use_temporal is True
|
||||
|
||||
def test_cache_key_generation(self) -> None:
|
||||
"""Test cache key generation."""
|
||||
query1 = RetrievalQuery(query_text="test", limit=10)
|
||||
query2 = RetrievalQuery(query_text="test", limit=10)
|
||||
query3 = RetrievalQuery(query_text="different", limit=10)
|
||||
|
||||
# Same queries should have same key
|
||||
assert query1.to_cache_key() == query2.to_cache_key()
|
||||
# Different queries should have different keys
|
||||
assert query1.to_cache_key() != query3.to_cache_key()
|
||||
|
||||
|
||||
class TestScoredResult:
|
||||
"""Tests for ScoredResult."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test creating a scored result."""
|
||||
result = ScoredResult(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
relevance_score=0.85,
|
||||
score_breakdown={"vector": 0.9, "recency": 0.8},
|
||||
)
|
||||
|
||||
assert result.relevance_score == 0.85
|
||||
assert result.score_breakdown["vector"] == 0.9
|
||||
|
||||
|
||||
class TestRelevanceScorer:
|
||||
"""Tests for RelevanceScorer."""
|
||||
|
||||
@pytest.fixture
|
||||
def scorer(self) -> RelevanceScorer:
|
||||
"""Create a relevance scorer."""
|
||||
return RelevanceScorer()
|
||||
|
||||
def test_score_with_vector(self, scorer: RelevanceScorer) -> None:
|
||||
"""Test scoring with vector similarity."""
|
||||
result = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
vector_similarity=0.9,
|
||||
)
|
||||
|
||||
assert result.relevance_score > 0
|
||||
assert result.score_breakdown["vector"] == 0.9
|
||||
|
||||
def test_score_with_recency(self, scorer: RelevanceScorer) -> None:
|
||||
"""Test scoring with recency."""
|
||||
recent_result = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
timestamp=_utcnow(),
|
||||
)
|
||||
|
||||
old_result = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
timestamp=_utcnow() - timedelta(days=7),
|
||||
)
|
||||
|
||||
# Recent should have higher recency score
|
||||
assert (
|
||||
recent_result.score_breakdown["recency"]
|
||||
> old_result.score_breakdown["recency"]
|
||||
)
|
||||
|
||||
def test_score_with_outcome_preference(self, scorer: RelevanceScorer) -> None:
|
||||
"""Test scoring with outcome preference."""
|
||||
success_result = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
outcome=Outcome.SUCCESS,
|
||||
preferred_outcomes=[Outcome.SUCCESS],
|
||||
)
|
||||
|
||||
failure_result = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
outcome=Outcome.FAILURE,
|
||||
preferred_outcomes=[Outcome.SUCCESS],
|
||||
)
|
||||
|
||||
assert success_result.score_breakdown["outcome"] == 1.0
|
||||
assert failure_result.score_breakdown["outcome"] == 0.0
|
||||
|
||||
def test_score_with_entity_match(self, scorer: RelevanceScorer) -> None:
|
||||
"""Test scoring with entity matches."""
|
||||
full_match = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
entity_match_count=3,
|
||||
entity_total=3,
|
||||
)
|
||||
|
||||
partial_match = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
entity_match_count=1,
|
||||
entity_total=3,
|
||||
)
|
||||
|
||||
assert (
|
||||
full_match.score_breakdown["entity"]
|
||||
> partial_match.score_breakdown["entity"]
|
||||
)
|
||||
|
||||
|
||||
class TestRetrievalCache:
|
||||
"""Tests for RetrievalCache."""
|
||||
|
||||
@pytest.fixture
|
||||
def cache(self) -> RetrievalCache:
|
||||
"""Create a retrieval cache."""
|
||||
return RetrievalCache(max_entries=10, default_ttl_seconds=60)
|
||||
|
||||
def test_put_and_get(self, cache: RetrievalCache) -> None:
|
||||
"""Test putting and getting from cache."""
|
||||
results = [
|
||||
ScoredResult(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
relevance_score=0.8,
|
||||
)
|
||||
]
|
||||
|
||||
cache.put("test_key", results)
|
||||
cached = cache.get("test_key")
|
||||
|
||||
assert cached is not None
|
||||
assert len(cached) == 1
|
||||
|
||||
def test_get_nonexistent(self, cache: RetrievalCache) -> None:
|
||||
"""Test getting nonexistent entry."""
|
||||
result = cache.get("nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_lru_eviction(self) -> None:
|
||||
"""Test LRU eviction when at capacity."""
|
||||
cache = RetrievalCache(max_entries=2, default_ttl_seconds=60)
|
||||
|
||||
results = [
|
||||
ScoredResult(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
relevance_score=0.8,
|
||||
)
|
||||
]
|
||||
|
||||
cache.put("key1", results)
|
||||
cache.put("key2", results)
|
||||
cache.put("key3", results) # Should evict key1
|
||||
|
||||
assert cache.get("key1") is None
|
||||
assert cache.get("key2") is not None
|
||||
assert cache.get("key3") is not None
|
||||
|
||||
def test_invalidate(self, cache: RetrievalCache) -> None:
|
||||
"""Test invalidating a cache entry."""
|
||||
results = [
|
||||
ScoredResult(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
relevance_score=0.8,
|
||||
)
|
||||
]
|
||||
|
||||
cache.put("test_key", results)
|
||||
removed = cache.invalidate("test_key")
|
||||
|
||||
assert removed is True
|
||||
assert cache.get("test_key") is None
|
||||
|
||||
def test_invalidate_by_memory(self, cache: RetrievalCache) -> None:
|
||||
"""Test invalidating by memory ID."""
|
||||
memory_id = uuid4()
|
||||
results = [
|
||||
ScoredResult(
|
||||
memory_id=memory_id,
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
relevance_score=0.8,
|
||||
)
|
||||
]
|
||||
|
||||
cache.put("key1", results)
|
||||
cache.put("key2", results)
|
||||
|
||||
count = cache.invalidate_by_memory(memory_id)
|
||||
|
||||
assert count == 2
|
||||
assert cache.get("key1") is None
|
||||
assert cache.get("key2") is None
|
||||
|
||||
def test_clear(self, cache: RetrievalCache) -> None:
|
||||
"""Test clearing the cache."""
|
||||
results = [
|
||||
ScoredResult(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
relevance_score=0.8,
|
||||
)
|
||||
]
|
||||
|
||||
cache.put("key1", results)
|
||||
cache.put("key2", results)
|
||||
|
||||
count = cache.clear()
|
||||
|
||||
assert count == 2
|
||||
assert cache.get("key1") is None
|
||||
|
||||
def test_get_stats(self, cache: RetrievalCache) -> None:
|
||||
"""Test getting cache statistics."""
|
||||
stats = cache.get_stats()
|
||||
|
||||
assert "total_entries" in stats
|
||||
assert "max_entries" in stats
|
||||
assert stats["max_entries"] == 10
|
||||
|
||||
|
||||
class TestRetrievalEngine:
|
||||
"""Tests for RetrievalEngine."""
|
||||
|
||||
@pytest.fixture
|
||||
def indexer(self) -> MemoryIndexer:
|
||||
"""Create a memory indexer."""
|
||||
return MemoryIndexer()
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, indexer: MemoryIndexer) -> RetrievalEngine:
|
||||
"""Create a retrieval engine."""
|
||||
return RetrievalEngine(indexer=indexer, enable_cache=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_vector(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test retrieval by vector similarity."""
|
||||
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
|
||||
e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
|
||||
|
||||
await indexer.index(e1)
|
||||
await indexer.index(e2)
|
||||
await indexer.index(e3)
|
||||
|
||||
query = RetrievalQuery(
|
||||
query_embedding=[1.0, 0.0, 0.0, 0.0],
|
||||
limit=2,
|
||||
use_temporal=False,
|
||||
use_entity=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
|
||||
result = await engine.retrieve(query)
|
||||
|
||||
assert len(result.items) > 0
|
||||
assert result.retrieval_type == "hybrid"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_recent(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test retrieval of recent items."""
|
||||
now = _utcnow()
|
||||
old = make_episode(occurred_at=now - timedelta(hours=2))
|
||||
recent = make_episode(occurred_at=now - timedelta(minutes=30))
|
||||
|
||||
await indexer.index(old)
|
||||
await indexer.index(recent)
|
||||
|
||||
result = await engine.retrieve_recent(hours=1)
|
||||
|
||||
assert len(result.items) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_entity(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test retrieval by entity."""
|
||||
e1 = make_episode(task_type="deploy")
|
||||
e2 = make_episode(task_type="test")
|
||||
|
||||
await indexer.index(e1)
|
||||
await indexer.index(e2)
|
||||
|
||||
result = await engine.retrieve_by_entity("task_type", "deploy")
|
||||
|
||||
assert len(result.items) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_successful(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test retrieval of successful items."""
|
||||
success = make_episode(outcome=Outcome.SUCCESS)
|
||||
failure = make_episode(outcome=Outcome.FAILURE)
|
||||
|
||||
await indexer.index(success)
|
||||
await indexer.index(failure)
|
||||
|
||||
result = await engine.retrieve_successful()
|
||||
|
||||
assert len(result.items) == 1
|
||||
# Check outcome index was used
|
||||
assert result.items[0].memory_id == success.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_cache(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test that retrieval uses cache."""
|
||||
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
await indexer.index(episode)
|
||||
|
||||
query = RetrievalQuery(
|
||||
query_embedding=[1.0, 0.0, 0.0, 0.0],
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# First retrieval
|
||||
result1 = await engine.retrieve(query)
|
||||
assert result1.metadata.get("cache_hit") is False
|
||||
|
||||
# Second retrieval should be cached
|
||||
result2 = await engine.retrieve(query)
|
||||
assert result2.metadata.get("cache_hit") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalidate_cache(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test cache invalidation."""
|
||||
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
await indexer.index(episode)
|
||||
|
||||
query = RetrievalQuery(query_embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
await engine.retrieve(query)
|
||||
|
||||
count = engine.invalidate_cache()
|
||||
|
||||
assert count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_similar(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test retrieve_similar convenience method."""
|
||||
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
|
||||
|
||||
await indexer.index(e1)
|
||||
await indexer.index(e2)
|
||||
|
||||
result = await engine.retrieve_similar(
|
||||
embedding=[1.0, 0.0, 0.0, 0.0],
|
||||
limit=1,
|
||||
)
|
||||
|
||||
assert len(result.items) == 1
|
||||
|
||||
def test_get_cache_stats(self, engine: RetrievalEngine) -> None:
|
||||
"""Test getting cache statistics."""
|
||||
stats = engine.get_cache_stats()
|
||||
|
||||
assert "total_entries" in stats
|
||||
|
||||
|
||||
class TestGetRetrievalEngine:
|
||||
"""Tests for singleton getter."""
|
||||
|
||||
def test_returns_instance(self) -> None:
|
||||
"""Test that getter returns instance."""
|
||||
engine = get_retrieval_engine()
|
||||
assert engine is not None
|
||||
assert isinstance(engine, RetrievalEngine)
|
||||
|
||||
def test_returns_same_instance(self) -> None:
|
||||
"""Test that getter returns same instance."""
|
||||
engine1 = get_retrieval_engine()
|
||||
engine2 = get_retrieval_engine()
|
||||
assert engine1 is engine2
|
||||
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/integration/__init__.py
|
||||
"""Tests for memory integration module."""
|
||||
@@ -0,0 +1,319 @@
|
||||
# tests/unit/services/memory/integration/test_context_source.py
|
||||
"""Tests for MemoryContextSource service."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.types.memory import MemorySubtype
|
||||
from app.services.memory.integration.context_source import (
|
||||
MemoryContextSource,
|
||||
MemoryFetchConfig,
|
||||
MemoryFetchResult,
|
||||
get_memory_context_source,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session() -> MagicMock:
|
||||
"""Create mock database session."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def context_source(mock_session: MagicMock) -> MemoryContextSource:
|
||||
"""Create MemoryContextSource instance."""
|
||||
return MemoryContextSource(session=mock_session)
|
||||
|
||||
|
||||
class TestMemoryFetchConfig:
|
||||
"""Tests for MemoryFetchConfig."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Default config values should be set correctly."""
|
||||
config = MemoryFetchConfig()
|
||||
|
||||
assert config.working_limit == 10
|
||||
assert config.episodic_limit == 10
|
||||
assert config.semantic_limit == 15
|
||||
assert config.procedural_limit == 5
|
||||
assert config.episodic_days_back == 30
|
||||
assert config.min_relevance == 0.3
|
||||
assert config.include_working is True
|
||||
assert config.include_episodic is True
|
||||
assert config.include_semantic is True
|
||||
assert config.include_procedural is True
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
"""Custom config values should be respected."""
|
||||
config = MemoryFetchConfig(
|
||||
working_limit=5,
|
||||
include_working=False,
|
||||
)
|
||||
|
||||
assert config.working_limit == 5
|
||||
assert config.include_working is False
|
||||
|
||||
|
||||
class TestMemoryFetchResult:
|
||||
"""Tests for MemoryFetchResult."""
|
||||
|
||||
def test_stores_results(self) -> None:
|
||||
"""Result should store contexts and metadata."""
|
||||
result = MemoryFetchResult(
|
||||
contexts=[],
|
||||
by_type={"working": 0, "episodic": 5, "semantic": 3, "procedural": 0},
|
||||
fetch_time_ms=15.5,
|
||||
query="test query",
|
||||
)
|
||||
|
||||
assert result.contexts == []
|
||||
assert result.by_type["episodic"] == 5
|
||||
assert result.fetch_time_ms == 15.5
|
||||
assert result.query == "test query"
|
||||
|
||||
|
||||
class TestMemoryContextSource:
|
||||
"""Tests for MemoryContextSource service."""
|
||||
|
||||
async def test_fetch_context_empty_when_no_sources(
|
||||
self,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""fetch_context should return empty when all sources fail."""
|
||||
config = MemoryFetchConfig(
|
||||
include_working=False,
|
||||
include_episodic=False,
|
||||
include_semantic=False,
|
||||
include_procedural=False,
|
||||
)
|
||||
|
||||
result = await context_source.fetch_context(
|
||||
query="test",
|
||||
project_id=uuid4(),
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert len(result.contexts) == 0
|
||||
assert result.by_type == {
|
||||
"working": 0,
|
||||
"episodic": 0,
|
||||
"semantic": 0,
|
||||
"procedural": 0,
|
||||
}
|
||||
|
||||
@patch("app.services.memory.integration.context_source.WorkingMemory")
|
||||
async def test_fetch_working_memory(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""Should fetch working memory when session_id provided."""
|
||||
# Setup mock - both keys should match the query "task"
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["current_task", "task_state"])
|
||||
mock_working.get = AsyncMock(side_effect=lambda k: {"key": k, "value": "test"})
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
include_episodic=False,
|
||||
include_semantic=False,
|
||||
include_procedural=False,
|
||||
)
|
||||
|
||||
result = await context_source.fetch_context(
|
||||
query="task", # Both keys contain "task"
|
||||
project_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert result.by_type["working"] == 2
|
||||
assert all(c.memory_subtype == MemorySubtype.WORKING for c in result.contexts)
|
||||
|
||||
@patch("app.services.memory.integration.context_source.EpisodicMemory")
|
||||
async def test_fetch_episodic_memory(
|
||||
self,
|
||||
mock_episodic_cls: MagicMock,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""Should fetch episodic memory."""
|
||||
# Setup mock episode
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
mock_episode.task_description = "Completed login feature"
|
||||
mock_episode.task_type = "feature"
|
||||
mock_episode.outcome = MagicMock(value="success")
|
||||
mock_episode.importance_score = 0.8
|
||||
mock_episode.occurred_at = datetime.now(UTC)
|
||||
mock_episode.lessons_learned = []
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.search_similar = AsyncMock(return_value=[mock_episode])
|
||||
mock_episodic.get_recent = AsyncMock(return_value=[])
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
include_working=False,
|
||||
include_semantic=False,
|
||||
include_procedural=False,
|
||||
)
|
||||
|
||||
result = await context_source.fetch_context(
|
||||
query="login",
|
||||
project_id=uuid4(),
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert result.by_type["episodic"] == 1
|
||||
assert result.contexts[0].memory_subtype == MemorySubtype.EPISODIC
|
||||
assert "Completed login feature" in result.contexts[0].content
|
||||
|
||||
@patch("app.services.memory.integration.context_source.SemanticMemory")
|
||||
async def test_fetch_semantic_memory(
|
||||
self,
|
||||
mock_semantic_cls: MagicMock,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""Should fetch semantic memory."""
|
||||
# Setup mock fact
|
||||
mock_fact = MagicMock()
|
||||
mock_fact.id = uuid4()
|
||||
mock_fact.subject = "User"
|
||||
mock_fact.predicate = "prefers"
|
||||
mock_fact.object = "dark mode"
|
||||
mock_fact.confidence = 0.9
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.search_facts = AsyncMock(return_value=[mock_fact])
|
||||
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
include_working=False,
|
||||
include_episodic=False,
|
||||
include_procedural=False,
|
||||
)
|
||||
|
||||
result = await context_source.fetch_context(
|
||||
query="preferences",
|
||||
project_id=uuid4(),
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert result.by_type["semantic"] == 1
|
||||
assert result.contexts[0].memory_subtype == MemorySubtype.SEMANTIC
|
||||
assert "User prefers dark mode" in result.contexts[0].content
|
||||
|
||||
@patch("app.services.memory.integration.context_source.ProceduralMemory")
|
||||
async def test_fetch_procedural_memory(
|
||||
self,
|
||||
mock_procedural_cls: MagicMock,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""Should fetch procedural memory."""
|
||||
# Setup mock procedure
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.id = uuid4()
|
||||
mock_proc.name = "Deploy"
|
||||
mock_proc.trigger_pattern = "When deploying"
|
||||
mock_proc.steps = [{"action": "build"}, {"action": "test"}]
|
||||
mock_proc.success_rate = 0.9
|
||||
mock_proc.success_count = 9
|
||||
mock_proc.failure_count = 1
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.find_matching = AsyncMock(return_value=[mock_proc])
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
include_working=False,
|
||||
include_episodic=False,
|
||||
include_semantic=False,
|
||||
)
|
||||
|
||||
result = await context_source.fetch_context(
|
||||
query="deploy",
|
||||
project_id=uuid4(),
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert result.by_type["procedural"] == 1
|
||||
assert result.contexts[0].memory_subtype == MemorySubtype.PROCEDURAL
|
||||
assert "Deploy" in result.contexts[0].content
|
||||
|
||||
async def test_results_sorted_by_relevance(
|
||||
self,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""Results should be sorted by relevance score."""
|
||||
with (
|
||||
patch.object(context_source, "_fetch_episodic") as mock_ep,
|
||||
patch.object(context_source, "_fetch_semantic") as mock_sem,
|
||||
):
|
||||
# Create contexts with different relevance scores
|
||||
from app.services.context.types.memory import MemoryContext
|
||||
|
||||
ctx_low = MemoryContext(
|
||||
content="low relevance",
|
||||
source="test",
|
||||
relevance_score=0.3,
|
||||
)
|
||||
ctx_high = MemoryContext(
|
||||
content="high relevance",
|
||||
source="test",
|
||||
relevance_score=0.9,
|
||||
)
|
||||
|
||||
mock_ep.return_value = [ctx_low]
|
||||
mock_sem.return_value = [ctx_high]
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
include_working=False,
|
||||
include_procedural=False,
|
||||
)
|
||||
|
||||
result = await context_source.fetch_context(
|
||||
query="test",
|
||||
project_id=uuid4(),
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Higher relevance should come first
|
||||
assert result.contexts[0].relevance_score == 0.9
|
||||
assert result.contexts[1].relevance_score == 0.3
|
||||
|
||||
@patch("app.services.memory.integration.context_source.WorkingMemory")
|
||||
async def test_fetch_all_working(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""fetch_all_working should return all working memory items."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2", "key3"])
|
||||
mock_working.get = AsyncMock(return_value="value")
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
contexts = await context_source.fetch_all_working(
|
||||
session_id="sess-123",
|
||||
project_id=uuid4(),
|
||||
)
|
||||
|
||||
assert len(contexts) == 3
|
||||
assert all(c.memory_subtype == MemorySubtype.WORKING for c in contexts)
|
||||
|
||||
|
||||
class TestGetMemoryContextSource:
|
||||
"""Tests for factory function."""
|
||||
|
||||
async def test_creates_instance(self) -> None:
|
||||
"""Factory should create MemoryContextSource instance."""
|
||||
mock_session = MagicMock()
|
||||
|
||||
source = await get_memory_context_source(mock_session)
|
||||
|
||||
assert isinstance(source, MemoryContextSource)
|
||||
472
backend/tests/unit/services/memory/integration/test_lifecycle.py
Normal file
472
backend/tests/unit/services/memory/integration/test_lifecycle.py
Normal file
@@ -0,0 +1,472 @@
|
||||
# tests/unit/services/memory/integration/test_lifecycle.py
|
||||
"""Tests for Agent Lifecycle Hooks."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.integration.lifecycle import (
|
||||
AgentLifecycleManager,
|
||||
LifecycleEvent,
|
||||
LifecycleHooks,
|
||||
LifecycleResult,
|
||||
get_lifecycle_manager,
|
||||
)
|
||||
from app.services.memory.types import Outcome
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session() -> MagicMock:
|
||||
"""Create mock database session."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lifecycle_hooks() -> LifecycleHooks:
|
||||
"""Create lifecycle hooks instance."""
|
||||
return LifecycleHooks()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lifecycle_manager(mock_session: MagicMock) -> AgentLifecycleManager:
|
||||
"""Create lifecycle manager instance."""
|
||||
return AgentLifecycleManager(session=mock_session)
|
||||
|
||||
|
||||
class TestLifecycleEvent:
|
||||
"""Tests for LifecycleEvent dataclass."""
|
||||
|
||||
def test_creates_event(self) -> None:
|
||||
"""Should create event with required fields."""
|
||||
project_id = uuid4()
|
||||
agent_id = uuid4()
|
||||
|
||||
event = LifecycleEvent(
|
||||
event_type="spawn",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_id,
|
||||
)
|
||||
|
||||
assert event.event_type == "spawn"
|
||||
assert event.project_id == project_id
|
||||
assert event.agent_instance_id == agent_id
|
||||
assert event.timestamp is not None
|
||||
assert event.metadata == {}
|
||||
|
||||
def test_with_optional_fields(self) -> None:
|
||||
"""Should include optional fields."""
|
||||
event = LifecycleEvent(
|
||||
event_type="terminate",
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
metadata={"reason": "completed"},
|
||||
)
|
||||
|
||||
assert event.session_id == "sess-123"
|
||||
assert event.metadata["reason"] == "completed"
|
||||
|
||||
|
||||
class TestLifecycleResult:
|
||||
"""Tests for LifecycleResult dataclass."""
|
||||
|
||||
def test_success_result(self) -> None:
|
||||
"""Should create success result."""
|
||||
result = LifecycleResult(
|
||||
success=True,
|
||||
event_type="spawn",
|
||||
message="Agent spawned",
|
||||
data={"session_id": "sess-123"},
|
||||
duration_ms=10.5,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "spawn"
|
||||
assert result.data["session_id"] == "sess-123"
|
||||
|
||||
def test_failure_result(self) -> None:
|
||||
"""Should create failure result."""
|
||||
result = LifecycleResult(
|
||||
success=False,
|
||||
event_type="resume",
|
||||
message="Checkpoint not found",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert result.message == "Checkpoint not found"
|
||||
|
||||
|
||||
class TestLifecycleHooks:
|
||||
"""Tests for LifecycleHooks class."""
|
||||
|
||||
def test_register_spawn_hook(self, lifecycle_hooks: LifecycleHooks) -> None:
|
||||
"""Should register spawn hook."""
|
||||
|
||||
async def my_hook(event: LifecycleEvent) -> None:
|
||||
pass
|
||||
|
||||
result = lifecycle_hooks.on_spawn(my_hook)
|
||||
|
||||
assert result is my_hook
|
||||
assert my_hook in lifecycle_hooks._spawn_hooks
|
||||
|
||||
def test_register_all_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
|
||||
"""Should register hooks for all event types."""
|
||||
[
|
||||
lifecycle_hooks.on_spawn(AsyncMock()),
|
||||
lifecycle_hooks.on_pause(AsyncMock()),
|
||||
lifecycle_hooks.on_resume(AsyncMock()),
|
||||
lifecycle_hooks.on_terminate(AsyncMock()),
|
||||
]
|
||||
|
||||
assert len(lifecycle_hooks._spawn_hooks) == 1
|
||||
assert len(lifecycle_hooks._pause_hooks) == 1
|
||||
assert len(lifecycle_hooks._resume_hooks) == 1
|
||||
assert len(lifecycle_hooks._terminate_hooks) == 1
|
||||
|
||||
async def test_run_spawn_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
|
||||
"""Should run all spawn hooks."""
|
||||
hook1 = AsyncMock()
|
||||
hook2 = AsyncMock()
|
||||
lifecycle_hooks.on_spawn(hook1)
|
||||
lifecycle_hooks.on_spawn(hook2)
|
||||
|
||||
event = LifecycleEvent(
|
||||
event_type="spawn",
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
)
|
||||
|
||||
await lifecycle_hooks.run_spawn_hooks(event)
|
||||
|
||||
hook1.assert_called_once_with(event)
|
||||
hook2.assert_called_once_with(event)
|
||||
|
||||
async def test_hook_failure_doesnt_stop_others(
|
||||
self, lifecycle_hooks: LifecycleHooks
|
||||
) -> None:
|
||||
"""Hook failure should not stop other hooks from running."""
|
||||
hook1 = AsyncMock(side_effect=ValueError("Oops"))
|
||||
hook2 = AsyncMock()
|
||||
lifecycle_hooks.on_pause(hook1)
|
||||
lifecycle_hooks.on_pause(hook2)
|
||||
|
||||
event = LifecycleEvent(
|
||||
event_type="pause",
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
)
|
||||
|
||||
await lifecycle_hooks.run_pause_hooks(event)
|
||||
|
||||
# hook2 should still be called even though hook1 failed
|
||||
hook2.assert_called_once()
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerSpawn:
|
||||
"""Tests for AgentLifecycleManager.spawn."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_spawn_creates_working_memory(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Spawn should create working memory for session."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.spawn(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "spawn"
|
||||
mock_working_cls.for_session.assert_called_once()
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_spawn_with_initial_state(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Spawn should populate initial state."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.spawn(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
initial_state={"key1": "value1", "key2": "value2"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["initial_items"] == 2
|
||||
assert mock_working.set.call_count == 2
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_spawn_runs_hooks(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Spawn should run registered hooks."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
hook = AsyncMock()
|
||||
lifecycle_manager.hooks.on_spawn(hook)
|
||||
|
||||
await lifecycle_manager.spawn(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
)
|
||||
|
||||
hook.assert_called_once()
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerPause:
|
||||
"""Tests for AgentLifecycleManager.pause."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_pause_creates_checkpoint(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Pause should create checkpoint of working memory."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
|
||||
mock_working.get = AsyncMock(return_value={"data": "test"})
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.pause(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
checkpoint_id="ckpt-001",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "pause"
|
||||
assert result.data["checkpoint_id"] == "ckpt-001"
|
||||
assert result.data["items_saved"] == 2
|
||||
|
||||
# Should save checkpoint with state
|
||||
mock_working.set.assert_called_once()
|
||||
call_args = mock_working.set.call_args
|
||||
# Check positional arg (first arg is key)
|
||||
assert "__checkpoint__ckpt-001" in call_args[0][0]
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_pause_generates_checkpoint_id(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Pause should generate checkpoint ID if not provided."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=[])
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.pause(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert "checkpoint_id" in result.data
|
||||
assert result.data["checkpoint_id"].startswith("checkpoint_")
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerResume:
|
||||
"""Tests for AgentLifecycleManager.resume."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_resume_restores_checkpoint(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Resume should restore working memory from checkpoint."""
|
||||
checkpoint_data = {
|
||||
"state": {"key1": "value1", "key2": "value2"},
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"keys_count": 2,
|
||||
}
|
||||
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=[])
|
||||
mock_working.get = AsyncMock(return_value=checkpoint_data)
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working.delete = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.resume(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
checkpoint_id="ckpt-001",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "resume"
|
||||
assert result.data["items_restored"] == 2
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_resume_checkpoint_not_found(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Resume should fail if checkpoint not found."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.get = AsyncMock(return_value=None)
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.resume(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
checkpoint_id="nonexistent",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "not found" in result.message.lower()
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerTerminate:
|
||||
"""Tests for AgentLifecycleManager.terminate."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.EpisodicMemory")
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_terminate_consolidates_to_episodic(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
mock_episodic_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Terminate should consolidate working memory to episodic."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
|
||||
mock_working.get = AsyncMock(return_value="value")
|
||||
mock_working.delete = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
result = await lifecycle_manager.terminate(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
task_description="Completed task",
|
||||
outcome=Outcome.SUCCESS,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "terminate"
|
||||
assert result.data["episode_id"] == str(mock_episode.id)
|
||||
assert result.data["state_items_consolidated"] == 2
|
||||
mock_episodic.record_episode.assert_called_once()
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_terminate_cleans_up_working(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Terminate should clean up working memory."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
|
||||
mock_working.get = AsyncMock(return_value="value")
|
||||
mock_working.delete = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.terminate(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
consolidate_to_episodic=False,
|
||||
cleanup_working=True,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["items_cleared"] == 2
|
||||
assert mock_working.delete.call_count == 2
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerListCheckpoints:
|
||||
"""Tests for AgentLifecycleManager.list_checkpoints."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_list_checkpoints(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Should list available checkpoints."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(
|
||||
return_value=[
|
||||
"__checkpoint__ckpt-001",
|
||||
"__checkpoint__ckpt-002",
|
||||
"regular_key",
|
||||
]
|
||||
)
|
||||
mock_working.get = AsyncMock(
|
||||
return_value={
|
||||
"timestamp": "2024-01-01T00:00:00Z",
|
||||
"keys_count": 5,
|
||||
}
|
||||
)
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
checkpoints = await lifecycle_manager.list_checkpoints(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
)
|
||||
|
||||
assert len(checkpoints) == 2
|
||||
assert checkpoints[0]["checkpoint_id"] == "ckpt-001"
|
||||
assert checkpoints[0]["keys_count"] == 5
|
||||
|
||||
|
||||
class TestGetLifecycleManager:
|
||||
"""Tests for factory function."""
|
||||
|
||||
async def test_creates_instance(self) -> None:
|
||||
"""Factory should create AgentLifecycleManager instance."""
|
||||
mock_session = MagicMock()
|
||||
|
||||
manager = await get_lifecycle_manager(mock_session)
|
||||
|
||||
assert isinstance(manager, AgentLifecycleManager)
|
||||
|
||||
async def test_with_custom_hooks(self) -> None:
|
||||
"""Factory should accept custom hooks."""
|
||||
mock_session = MagicMock()
|
||||
custom_hooks = LifecycleHooks()
|
||||
|
||||
manager = await get_lifecycle_manager(mock_session, hooks=custom_hooks)
|
||||
|
||||
assert manager.hooks is custom_hooks
|
||||
2
backend/tests/unit/services/memory/mcp/__init__.py
Normal file
2
backend/tests/unit/services/memory/mcp/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/mcp/__init__.py
|
||||
"""Tests for memory MCP tools."""
|
||||
661
backend/tests/unit/services/memory/mcp/test_service.py
Normal file
661
backend/tests/unit/services/memory/mcp/test_service.py
Normal file
@@ -0,0 +1,661 @@
|
||||
# tests/unit/services/memory/mcp/test_service.py
|
||||
"""Tests for MemoryToolService."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.mcp.service import (
|
||||
MemoryToolService,
|
||||
ToolContext,
|
||||
ToolResult,
|
||||
get_memory_tool_service,
|
||||
)
|
||||
from app.services.memory.types import Outcome
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
def make_context(
|
||||
project_id: UUID | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> ToolContext:
|
||||
"""Create a test context."""
|
||||
return ToolContext(
|
||||
project_id=project_id or uuid4(),
|
||||
agent_instance_id=agent_instance_id or uuid4(),
|
||||
session_id=session_id or "test-session",
|
||||
)
|
||||
|
||||
|
||||
def make_mock_session() -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
session.execute = AsyncMock()
|
||||
session.commit = AsyncMock()
|
||||
session.flush = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
class TestToolContext:
|
||||
"""Tests for ToolContext dataclass."""
|
||||
|
||||
def test_context_creation(self) -> None:
|
||||
"""Context should be creatable with required fields."""
|
||||
project_id = uuid4()
|
||||
ctx = ToolContext(project_id=project_id)
|
||||
assert ctx.project_id == project_id
|
||||
assert ctx.agent_instance_id is None
|
||||
assert ctx.session_id is None
|
||||
|
||||
def test_context_with_all_fields(self) -> None:
|
||||
"""Context should accept all optional fields."""
|
||||
project_id = uuid4()
|
||||
agent_id = uuid4()
|
||||
ctx = ToolContext(
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_id,
|
||||
agent_type_id=uuid4(),
|
||||
session_id="session-123",
|
||||
user_id=uuid4(),
|
||||
)
|
||||
assert ctx.project_id == project_id
|
||||
assert ctx.agent_instance_id == agent_id
|
||||
assert ctx.session_id == "session-123"
|
||||
|
||||
|
||||
class TestToolResult:
|
||||
"""Tests for ToolResult dataclass."""
|
||||
|
||||
def test_success_result(self) -> None:
|
||||
"""Success result should have correct fields."""
|
||||
result = ToolResult(
|
||||
success=True,
|
||||
data={"key": "value"},
|
||||
execution_time_ms=10.5,
|
||||
)
|
||||
assert result.success is True
|
||||
assert result.data == {"key": "value"}
|
||||
assert result.error is None
|
||||
|
||||
def test_error_result(self) -> None:
|
||||
"""Error result should have correct fields."""
|
||||
result = ToolResult(
|
||||
success=False,
|
||||
error="Something went wrong",
|
||||
error_code="VALIDATION_ERROR",
|
||||
)
|
||||
assert result.success is False
|
||||
assert result.error == "Something went wrong"
|
||||
assert result.error_code == "VALIDATION_ERROR"
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Result should convert to dict correctly."""
|
||||
result = ToolResult(
|
||||
success=True,
|
||||
data={"test": 1},
|
||||
execution_time_ms=5.0,
|
||||
)
|
||||
result_dict = result.to_dict()
|
||||
assert result_dict["success"] is True
|
||||
assert result_dict["data"] == {"test": 1}
|
||||
assert result_dict["execution_time_ms"] == 5.0
|
||||
|
||||
|
||||
class TestMemoryToolService:
|
||||
"""Tests for MemoryToolService."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock session."""
|
||||
return make_mock_session()
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, mock_session: AsyncMock) -> MemoryToolService:
|
||||
"""Create a service with mock session."""
|
||||
return MemoryToolService(session=mock_session)
|
||||
|
||||
@pytest.fixture
|
||||
def context(self) -> ToolContext:
|
||||
"""Create a test context."""
|
||||
return make_context()
|
||||
|
||||
async def test_execute_unknown_tool(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Unknown tool should return error."""
|
||||
result = await service.execute_tool(
|
||||
tool_name="unknown_tool",
|
||||
arguments={},
|
||||
context=context,
|
||||
)
|
||||
assert result.success is False
|
||||
assert result.error_code == "UNKNOWN_TOOL"
|
||||
|
||||
async def test_execute_with_invalid_args(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Invalid arguments should return validation error."""
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={"memory_type": "invalid_type"},
|
||||
context=context,
|
||||
)
|
||||
assert result.success is False
|
||||
assert result.error_code == "VALIDATION_ERROR"
|
||||
|
||||
@patch("app.services.memory.mcp.service.WorkingMemory")
|
||||
async def test_remember_working_memory(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Remember should store in working memory."""
|
||||
# Setup mock
|
||||
mock_working = AsyncMock()
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"content": "Test content",
|
||||
"key": "test_key",
|
||||
"ttl_seconds": 3600,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["stored"] is True
|
||||
assert result.data["memory_type"] == "working"
|
||||
assert result.data["key"] == "test_key"
|
||||
|
||||
async def test_remember_episodic_memory(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Remember should store in episodic memory."""
|
||||
with patch(
|
||||
"app.services.memory.mcp.service.EpisodicMemory"
|
||||
) as mock_episodic_cls:
|
||||
# Setup mock
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "episodic",
|
||||
"content": "Important event happened",
|
||||
"importance": 0.8,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["stored"] is True
|
||||
assert result.data["memory_type"] == "episodic"
|
||||
assert "episode_id" in result.data
|
||||
|
||||
async def test_remember_working_without_key(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Working memory without key should fail."""
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"content": "Test content",
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "key is required" in result.error.lower()
|
||||
|
||||
async def test_remember_working_without_session(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
) -> None:
|
||||
"""Working memory without session should fail."""
|
||||
context = ToolContext(project_id=uuid4(), session_id=None)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"content": "Test content",
|
||||
"key": "test_key",
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "session id is required" in result.error.lower()
|
||||
|
||||
async def test_remember_semantic_memory(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Remember should store facts in semantic memory."""
|
||||
with patch(
|
||||
"app.services.memory.mcp.service.SemanticMemory"
|
||||
) as mock_semantic_cls:
|
||||
mock_fact = MagicMock()
|
||||
mock_fact.id = uuid4()
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.store_fact = AsyncMock(return_value=mock_fact)
|
||||
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "semantic",
|
||||
"content": "User prefers dark mode",
|
||||
"subject": "User",
|
||||
"predicate": "prefers",
|
||||
"object_value": "dark mode",
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["memory_type"] == "semantic"
|
||||
assert "fact_id" in result.data
|
||||
assert "triple" in result.data
|
||||
|
||||
async def test_remember_semantic_without_fields(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Semantic memory without subject/predicate/object should fail."""
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "semantic",
|
||||
"content": "Some content",
|
||||
"subject": "User",
|
||||
# Missing predicate and object_value
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "required" in result.error.lower()
|
||||
|
||||
async def test_remember_procedural_memory(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Remember should store procedures in procedural memory."""
|
||||
with patch(
|
||||
"app.services.memory.mcp.service.ProceduralMemory"
|
||||
) as mock_procedural_cls:
|
||||
mock_procedure = MagicMock()
|
||||
mock_procedure.id = uuid4()
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.record_procedure = AsyncMock(return_value=mock_procedure)
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "procedural",
|
||||
"content": "File creation procedure",
|
||||
"trigger": "When creating a new file",
|
||||
"steps": [
|
||||
{"action": "check_exists"},
|
||||
{"action": "create"},
|
||||
],
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["memory_type"] == "procedural"
|
||||
assert "procedure_id" in result.data
|
||||
assert result.data["steps_count"] == 2
|
||||
|
||||
@patch("app.services.memory.mcp.service.EpisodicMemory")
|
||||
@patch("app.services.memory.mcp.service.SemanticMemory")
|
||||
async def test_recall_from_multiple_types(
|
||||
self,
|
||||
mock_semantic_cls: MagicMock,
|
||||
mock_episodic_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Recall should search across multiple memory types."""
|
||||
# Mock episodic
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
mock_episode.task_description = "Test episode"
|
||||
mock_episode.outcome = Outcome.SUCCESS
|
||||
mock_episode.occurred_at = datetime.now(UTC)
|
||||
mock_episode.importance_score = 0.9
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.search_similar = AsyncMock(return_value=[mock_episode])
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
# Mock semantic
|
||||
mock_fact = MagicMock()
|
||||
mock_fact.id = uuid4()
|
||||
mock_fact.subject = "User"
|
||||
mock_fact.predicate = "prefers"
|
||||
mock_fact.object = "dark mode"
|
||||
mock_fact.confidence = 0.8
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.search_facts = AsyncMock(return_value=[mock_fact])
|
||||
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="recall",
|
||||
arguments={
|
||||
"query": "user preferences",
|
||||
"memory_types": ["episodic", "semantic"],
|
||||
"limit": 10,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["total_results"] == 2
|
||||
assert len(result.data["results"]) == 2
|
||||
|
||||
@patch("app.services.memory.mcp.service.WorkingMemory")
|
||||
async def test_forget_working_memory(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Forget should delete from working memory."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.delete = AsyncMock(return_value=True)
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="forget",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"key": "temp_key",
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["deleted"] is True
|
||||
assert result.data["deleted_count"] == 1
|
||||
|
||||
async def test_forget_pattern_requires_confirm(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Pattern deletion should require confirmation."""
|
||||
with patch("app.services.memory.mcp.service.WorkingMemory") as mock_working_cls:
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["cache_1", "cache_2"])
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="forget",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"pattern": "cache_*",
|
||||
"confirm_bulk": False,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "confirm_bulk" in result.error.lower()
|
||||
|
||||
@patch("app.services.memory.mcp.service.EpisodicMemory")
|
||||
async def test_reflect_recent_patterns(
|
||||
self,
|
||||
mock_episodic_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Reflect should analyze recent patterns."""
|
||||
# Create mock episodes
|
||||
mock_episodes = []
|
||||
for i in range(5):
|
||||
ep = MagicMock()
|
||||
ep.id = uuid4()
|
||||
ep.task_type = "code_review" if i % 2 == 0 else "deployment"
|
||||
ep.outcome = Outcome.SUCCESS if i < 3 else Outcome.FAILURE
|
||||
ep.task_description = f"Episode {i}"
|
||||
ep.lessons_learned = None
|
||||
ep.occurred_at = datetime.now(UTC)
|
||||
mock_episodes.append(ep)
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent = AsyncMock(return_value=mock_episodes)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="reflect",
|
||||
arguments={
|
||||
"analysis_type": "recent_patterns",
|
||||
"depth": 3,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["analysis_type"] == "recent_patterns"
|
||||
assert result.data["total_episodes"] == 5
|
||||
assert "top_task_types" in result.data
|
||||
assert "outcome_distribution" in result.data
|
||||
|
||||
@patch("app.services.memory.mcp.service.EpisodicMemory")
|
||||
async def test_reflect_success_factors(
|
||||
self,
|
||||
mock_episodic_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Reflect should analyze success factors."""
|
||||
mock_episodes = []
|
||||
for i in range(10):
|
||||
ep = MagicMock()
|
||||
ep.id = uuid4()
|
||||
ep.task_type = "code_review"
|
||||
ep.outcome = Outcome.SUCCESS if i < 8 else Outcome.FAILURE
|
||||
ep.task_description = f"Episode {i}"
|
||||
ep.lessons_learned = "Learned something" if i < 3 else None
|
||||
ep.occurred_at = datetime.now(UTC)
|
||||
mock_episodes.append(ep)
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent = AsyncMock(return_value=mock_episodes)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="reflect",
|
||||
arguments={
|
||||
"analysis_type": "success_factors",
|
||||
"include_examples": True,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["analysis_type"] == "success_factors"
|
||||
assert result.data["overall_success_rate"] == 0.8
|
||||
|
||||
@patch("app.services.memory.mcp.service.EpisodicMemory")
|
||||
@patch("app.services.memory.mcp.service.SemanticMemory")
|
||||
@patch("app.services.memory.mcp.service.ProceduralMemory")
|
||||
@patch("app.services.memory.mcp.service.WorkingMemory")
|
||||
async def test_get_memory_stats(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
mock_procedural_cls: MagicMock,
|
||||
mock_semantic_cls: MagicMock,
|
||||
mock_episodic_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Get memory stats should return statistics."""
|
||||
# Setup mocks
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent = AsyncMock(
|
||||
return_value=[MagicMock() for _ in range(10)]
|
||||
)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.search_facts = AsyncMock(
|
||||
return_value=[MagicMock() for _ in range(5)]
|
||||
)
|
||||
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.find_matching = AsyncMock(
|
||||
return_value=[MagicMock() for _ in range(3)]
|
||||
)
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="get_memory_stats",
|
||||
arguments={
|
||||
"include_breakdown": True,
|
||||
"include_recent_activity": False,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert "breakdown" in result.data
|
||||
breakdown = result.data["breakdown"]
|
||||
assert breakdown["working"] == 2
|
||||
assert breakdown["episodic"] == 10
|
||||
assert breakdown["semantic"] == 5
|
||||
assert breakdown["procedural"] == 3
|
||||
assert breakdown["total"] == 20
|
||||
|
||||
@patch("app.services.memory.mcp.service.ProceduralMemory")
|
||||
async def test_search_procedures(
|
||||
self,
|
||||
mock_procedural_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Search procedures should return matching procedures."""
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.id = uuid4()
|
||||
mock_proc.name = "Deployment procedure"
|
||||
mock_proc.description = "How to deploy"
|
||||
mock_proc.trigger = "When deploying"
|
||||
mock_proc.success_rate = 0.9
|
||||
mock_proc.execution_count = 10
|
||||
mock_proc.steps = [{"action": "deploy"}]
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.find_matching = AsyncMock(return_value=[mock_proc])
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="search_procedures",
|
||||
arguments={
|
||||
"trigger": "Deploying to production",
|
||||
"min_success_rate": 0.8,
|
||||
"include_steps": True,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["procedures_found"] == 1
|
||||
proc = result.data["procedures"][0]
|
||||
assert proc["name"] == "Deployment procedure"
|
||||
assert "steps" in proc
|
||||
|
||||
async def test_record_outcome(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Record outcome should store outcome and update procedure."""
|
||||
with (
|
||||
patch(
|
||||
"app.services.memory.mcp.service.EpisodicMemory"
|
||||
) as mock_episodic_cls,
|
||||
patch(
|
||||
"app.services.memory.mcp.service.ProceduralMemory"
|
||||
) as mock_procedural_cls,
|
||||
):
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.record_outcome = AsyncMock()
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="record_outcome",
|
||||
arguments={
|
||||
"task_type": "code_review",
|
||||
"outcome": "success",
|
||||
"lessons_learned": "Breaking changes caught early",
|
||||
"duration_seconds": 120.5,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["recorded"] is True
|
||||
assert result.data["outcome"] == "success"
|
||||
assert "episode_id" in result.data
|
||||
|
||||
|
||||
class TestGetMemoryToolService:
|
||||
"""Tests for get_memory_tool_service factory."""
|
||||
|
||||
async def test_creates_service(self) -> None:
|
||||
"""Factory should create a service."""
|
||||
mock_session = make_mock_session()
|
||||
service = await get_memory_tool_service(mock_session)
|
||||
assert isinstance(service, MemoryToolService)
|
||||
|
||||
async def test_accepts_embedding_generator(self) -> None:
|
||||
"""Factory should accept embedding generator."""
|
||||
mock_session = make_mock_session()
|
||||
mock_generator = MagicMock()
|
||||
service = await get_memory_tool_service(mock_session, mock_generator)
|
||||
assert service._embedding_generator is mock_generator
|
||||
424
backend/tests/unit/services/memory/mcp/test_tools.py
Normal file
424
backend/tests/unit/services/memory/mcp/test_tools.py
Normal file
@@ -0,0 +1,424 @@
|
||||
# tests/unit/services/memory/mcp/test_tools.py
|
||||
"""Tests for MCP tool definitions."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.services.memory.mcp.tools import (
|
||||
MEMORY_TOOL_DEFINITIONS,
|
||||
AnalysisType,
|
||||
ForgetArgs,
|
||||
GetMemoryStatsArgs,
|
||||
MemoryToolDefinition,
|
||||
MemoryType,
|
||||
OutcomeType,
|
||||
RecallArgs,
|
||||
RecordOutcomeArgs,
|
||||
ReflectArgs,
|
||||
RememberArgs,
|
||||
SearchProceduresArgs,
|
||||
get_all_tool_schemas,
|
||||
get_tool_definition,
|
||||
)
|
||||
|
||||
|
||||
class TestMemoryType:
|
||||
"""Tests for MemoryType enum."""
|
||||
|
||||
def test_all_types_defined(self) -> None:
|
||||
"""All memory types should be defined."""
|
||||
assert MemoryType.WORKING == "working"
|
||||
assert MemoryType.EPISODIC == "episodic"
|
||||
assert MemoryType.SEMANTIC == "semantic"
|
||||
assert MemoryType.PROCEDURAL == "procedural"
|
||||
|
||||
def test_enum_values(self) -> None:
|
||||
"""Enum values should match strings."""
|
||||
assert MemoryType.WORKING.value == "working"
|
||||
assert MemoryType("episodic") == MemoryType.EPISODIC
|
||||
|
||||
|
||||
class TestAnalysisType:
|
||||
"""Tests for AnalysisType enum."""
|
||||
|
||||
def test_all_types_defined(self) -> None:
|
||||
"""All analysis types should be defined."""
|
||||
assert AnalysisType.RECENT_PATTERNS == "recent_patterns"
|
||||
assert AnalysisType.SUCCESS_FACTORS == "success_factors"
|
||||
assert AnalysisType.FAILURE_PATTERNS == "failure_patterns"
|
||||
assert AnalysisType.COMMON_PROCEDURES == "common_procedures"
|
||||
assert AnalysisType.LEARNING_PROGRESS == "learning_progress"
|
||||
|
||||
|
||||
class TestOutcomeType:
|
||||
"""Tests for OutcomeType enum."""
|
||||
|
||||
def test_all_outcomes_defined(self) -> None:
|
||||
"""All outcome types should be defined."""
|
||||
assert OutcomeType.SUCCESS == "success"
|
||||
assert OutcomeType.PARTIAL == "partial"
|
||||
assert OutcomeType.FAILURE == "failure"
|
||||
assert OutcomeType.ABANDONED == "abandoned"
|
||||
|
||||
|
||||
class TestRememberArgs:
|
||||
"""Tests for RememberArgs validation."""
|
||||
|
||||
def test_valid_working_memory_args(self) -> None:
|
||||
"""Valid working memory args should parse."""
|
||||
args = RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test content",
|
||||
key="test_key",
|
||||
ttl_seconds=3600,
|
||||
)
|
||||
assert args.memory_type == MemoryType.WORKING
|
||||
assert args.key == "test_key"
|
||||
assert args.ttl_seconds == 3600
|
||||
|
||||
def test_valid_semantic_args(self) -> None:
|
||||
"""Valid semantic memory args should parse."""
|
||||
args = RememberArgs(
|
||||
memory_type=MemoryType.SEMANTIC,
|
||||
content="User prefers dark mode",
|
||||
subject="User",
|
||||
predicate="prefers",
|
||||
object_value="dark mode",
|
||||
)
|
||||
assert args.subject == "User"
|
||||
assert args.predicate == "prefers"
|
||||
assert args.object_value == "dark mode"
|
||||
|
||||
def test_valid_procedural_args(self) -> None:
|
||||
"""Valid procedural memory args should parse."""
|
||||
args = RememberArgs(
|
||||
memory_type=MemoryType.PROCEDURAL,
|
||||
content="File creation procedure",
|
||||
trigger="When creating a new file",
|
||||
steps=[{"action": "check_exists"}, {"action": "create"}],
|
||||
)
|
||||
assert args.trigger == "When creating a new file"
|
||||
assert len(args.steps) == 2
|
||||
|
||||
def test_importance_validation(self) -> None:
|
||||
"""Importance must be between 0 and 1."""
|
||||
args = RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test",
|
||||
importance=0.8,
|
||||
)
|
||||
assert args.importance == 0.8
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test",
|
||||
importance=1.5, # Invalid
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test",
|
||||
importance=-0.1, # Invalid
|
||||
)
|
||||
|
||||
def test_content_required(self) -> None:
|
||||
"""Content is required."""
|
||||
with pytest.raises(ValidationError):
|
||||
RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="", # Empty not allowed
|
||||
)
|
||||
|
||||
def test_ttl_validation(self) -> None:
|
||||
"""TTL must be within bounds."""
|
||||
with pytest.raises(ValidationError):
|
||||
RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test",
|
||||
ttl_seconds=0, # Too low
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test",
|
||||
ttl_seconds=86400 * 31, # Over 30 days
|
||||
)
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Default values should be set correctly."""
|
||||
args = RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test",
|
||||
)
|
||||
assert args.importance == 0.5
|
||||
assert args.ttl_seconds is None
|
||||
assert args.metadata == {}
|
||||
assert args.key is None
|
||||
|
||||
|
||||
class TestRecallArgs:
|
||||
"""Tests for RecallArgs validation."""
|
||||
|
||||
def test_valid_args(self) -> None:
|
||||
"""Valid recall args should parse."""
|
||||
args = RecallArgs(
|
||||
query="authentication errors",
|
||||
memory_types=[MemoryType.EPISODIC, MemoryType.SEMANTIC],
|
||||
limit=10,
|
||||
)
|
||||
assert args.query == "authentication errors"
|
||||
assert len(args.memory_types) == 2
|
||||
assert args.limit == 10
|
||||
|
||||
def test_default_memory_types(self) -> None:
|
||||
"""Default memory types should be episodic and semantic."""
|
||||
args = RecallArgs(query="test query")
|
||||
assert MemoryType.EPISODIC in args.memory_types
|
||||
assert MemoryType.SEMANTIC in args.memory_types
|
||||
|
||||
def test_limit_validation(self) -> None:
|
||||
"""Limit must be between 1 and 100."""
|
||||
with pytest.raises(ValidationError):
|
||||
RecallArgs(query="test", limit=0)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
RecallArgs(query="test", limit=101)
|
||||
|
||||
def test_min_relevance_validation(self) -> None:
|
||||
"""Min relevance must be between 0 and 1."""
|
||||
args = RecallArgs(query="test", min_relevance=0.5)
|
||||
assert args.min_relevance == 0.5
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
RecallArgs(query="test", min_relevance=1.5)
|
||||
|
||||
|
||||
class TestForgetArgs:
|
||||
"""Tests for ForgetArgs validation."""
|
||||
|
||||
def test_valid_key_deletion(self) -> None:
|
||||
"""Valid key deletion args should parse."""
|
||||
args = ForgetArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
key="temp_key",
|
||||
)
|
||||
assert args.memory_type == MemoryType.WORKING
|
||||
assert args.key == "temp_key"
|
||||
|
||||
def test_valid_id_deletion(self) -> None:
|
||||
"""Valid ID deletion args should parse."""
|
||||
args = ForgetArgs(
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
memory_id="12345678-1234-1234-1234-123456789012",
|
||||
)
|
||||
assert args.memory_id is not None
|
||||
|
||||
def test_pattern_deletion_requires_confirm(self) -> None:
|
||||
"""Pattern deletion should parse but service should validate confirm."""
|
||||
args = ForgetArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
pattern="cache_*",
|
||||
confirm_bulk=False,
|
||||
)
|
||||
assert args.pattern == "cache_*"
|
||||
assert args.confirm_bulk is False
|
||||
|
||||
|
||||
class TestReflectArgs:
|
||||
"""Tests for ReflectArgs validation."""
|
||||
|
||||
def test_valid_args(self) -> None:
|
||||
"""Valid reflect args should parse."""
|
||||
args = ReflectArgs(
|
||||
analysis_type=AnalysisType.SUCCESS_FACTORS,
|
||||
depth=3,
|
||||
)
|
||||
assert args.analysis_type == AnalysisType.SUCCESS_FACTORS
|
||||
assert args.depth == 3
|
||||
|
||||
def test_depth_validation(self) -> None:
|
||||
"""Depth must be between 1 and 5."""
|
||||
with pytest.raises(ValidationError):
|
||||
ReflectArgs(analysis_type=AnalysisType.SUCCESS_FACTORS, depth=0)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ReflectArgs(analysis_type=AnalysisType.SUCCESS_FACTORS, depth=6)
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Default values should be set correctly."""
|
||||
args = ReflectArgs(analysis_type=AnalysisType.RECENT_PATTERNS)
|
||||
assert args.depth == 3
|
||||
assert args.include_examples is True
|
||||
assert args.max_items == 10
|
||||
|
||||
|
||||
class TestGetMemoryStatsArgs:
|
||||
"""Tests for GetMemoryStatsArgs validation."""
|
||||
|
||||
def test_valid_args(self) -> None:
|
||||
"""Valid args should parse."""
|
||||
args = GetMemoryStatsArgs(
|
||||
include_breakdown=True,
|
||||
include_recent_activity=True,
|
||||
time_range_days=30,
|
||||
)
|
||||
assert args.include_breakdown is True
|
||||
assert args.time_range_days == 30
|
||||
|
||||
def test_time_range_validation(self) -> None:
|
||||
"""Time range must be between 1 and 90."""
|
||||
with pytest.raises(ValidationError):
|
||||
GetMemoryStatsArgs(time_range_days=0)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
GetMemoryStatsArgs(time_range_days=91)
|
||||
|
||||
|
||||
class TestSearchProceduresArgs:
|
||||
"""Tests for SearchProceduresArgs validation."""
|
||||
|
||||
def test_valid_args(self) -> None:
|
||||
"""Valid args should parse."""
|
||||
args = SearchProceduresArgs(
|
||||
trigger="Deploying to production",
|
||||
min_success_rate=0.8,
|
||||
limit=5,
|
||||
)
|
||||
assert args.trigger == "Deploying to production"
|
||||
assert args.min_success_rate == 0.8
|
||||
|
||||
def test_trigger_required(self) -> None:
|
||||
"""Trigger is required."""
|
||||
with pytest.raises(ValidationError):
|
||||
SearchProceduresArgs(trigger="")
|
||||
|
||||
def test_success_rate_validation(self) -> None:
|
||||
"""Success rate must be between 0 and 1."""
|
||||
with pytest.raises(ValidationError):
|
||||
SearchProceduresArgs(trigger="test", min_success_rate=1.5)
|
||||
|
||||
|
||||
class TestRecordOutcomeArgs:
|
||||
"""Tests for RecordOutcomeArgs validation."""
|
||||
|
||||
def test_valid_success_args(self) -> None:
|
||||
"""Valid success args should parse."""
|
||||
args = RecordOutcomeArgs(
|
||||
task_type="code_review",
|
||||
outcome=OutcomeType.SUCCESS,
|
||||
lessons_learned="Breaking changes caught early",
|
||||
)
|
||||
assert args.task_type == "code_review"
|
||||
assert args.outcome == OutcomeType.SUCCESS
|
||||
|
||||
def test_valid_failure_args(self) -> None:
|
||||
"""Valid failure args should parse."""
|
||||
args = RecordOutcomeArgs(
|
||||
task_type="deployment",
|
||||
outcome=OutcomeType.FAILURE,
|
||||
error_details="Database migration timeout",
|
||||
duration_seconds=120.5,
|
||||
)
|
||||
assert args.outcome == OutcomeType.FAILURE
|
||||
assert args.error_details is not None
|
||||
|
||||
def test_task_type_required(self) -> None:
|
||||
"""Task type is required."""
|
||||
with pytest.raises(ValidationError):
|
||||
RecordOutcomeArgs(task_type="", outcome=OutcomeType.SUCCESS)
|
||||
|
||||
|
||||
class TestMemoryToolDefinition:
|
||||
"""Tests for MemoryToolDefinition class."""
|
||||
|
||||
def test_to_mcp_format(self) -> None:
|
||||
"""Tool should convert to MCP format."""
|
||||
tool = MemoryToolDefinition(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
args_schema=RememberArgs,
|
||||
)
|
||||
|
||||
mcp_format = tool.to_mcp_format()
|
||||
|
||||
assert mcp_format["name"] == "test_tool"
|
||||
assert mcp_format["description"] == "A test tool"
|
||||
assert "inputSchema" in mcp_format
|
||||
assert "properties" in mcp_format["inputSchema"]
|
||||
|
||||
def test_validate_args(self) -> None:
|
||||
"""Tool should validate args using schema."""
|
||||
tool = MemoryToolDefinition(
|
||||
name="remember",
|
||||
description="Store in memory",
|
||||
args_schema=RememberArgs,
|
||||
)
|
||||
|
||||
# Valid args
|
||||
validated = tool.validate_args(
|
||||
{
|
||||
"memory_type": "working",
|
||||
"content": "Test content",
|
||||
}
|
||||
)
|
||||
assert isinstance(validated, RememberArgs)
|
||||
|
||||
# Invalid args
|
||||
with pytest.raises(ValidationError):
|
||||
tool.validate_args({"memory_type": "invalid"})
|
||||
|
||||
|
||||
class TestToolDefinitions:
|
||||
"""Tests for the tool definitions dictionary."""
|
||||
|
||||
def test_all_tools_defined(self) -> None:
|
||||
"""All expected tools should be defined."""
|
||||
expected_tools = [
|
||||
"remember",
|
||||
"recall",
|
||||
"forget",
|
||||
"reflect",
|
||||
"get_memory_stats",
|
||||
"search_procedures",
|
||||
"record_outcome",
|
||||
]
|
||||
|
||||
for tool_name in expected_tools:
|
||||
assert tool_name in MEMORY_TOOL_DEFINITIONS
|
||||
assert isinstance(MEMORY_TOOL_DEFINITIONS[tool_name], MemoryToolDefinition)
|
||||
|
||||
def test_get_tool_definition(self) -> None:
|
||||
"""get_tool_definition should return correct tool."""
|
||||
tool = get_tool_definition("remember")
|
||||
assert tool is not None
|
||||
assert tool.name == "remember"
|
||||
|
||||
unknown = get_tool_definition("unknown_tool")
|
||||
assert unknown is None
|
||||
|
||||
def test_get_all_tool_schemas(self) -> None:
|
||||
"""get_all_tool_schemas should return MCP-formatted schemas."""
|
||||
schemas = get_all_tool_schemas()
|
||||
|
||||
assert len(schemas) == 7
|
||||
for schema in schemas:
|
||||
assert "name" in schema
|
||||
assert "description" in schema
|
||||
assert "inputSchema" in schema
|
||||
|
||||
def test_tool_descriptions_not_empty(self) -> None:
|
||||
"""All tools should have descriptions."""
|
||||
for name, tool in MEMORY_TOOL_DEFINITIONS.items():
|
||||
assert tool.description, f"Tool {name} has empty description"
|
||||
assert len(tool.description) > 50, f"Tool {name} description too short"
|
||||
|
||||
def test_input_schemas_have_properties(self) -> None:
|
||||
"""All tool schemas should have properties defined."""
|
||||
for name, tool in MEMORY_TOOL_DEFINITIONS.items():
|
||||
schema = tool.to_mcp_format()
|
||||
assert "properties" in schema["inputSchema"], (
|
||||
f"Tool {name} missing properties"
|
||||
)
|
||||
2
backend/tests/unit/services/memory/metrics/__init__.py
Normal file
2
backend/tests/unit/services/memory/metrics/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/metrics/__init__.py
|
||||
"""Tests for Memory Metrics."""
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user